1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Identification: 10 // This step is responsible for finding the patterns that can be lowered to 11 // complex instructions, and building a graph to represent the complex 12 // structures. Starting from the "Converging Shuffle" (a shuffle that 13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the 14 // operands are evaluated and identified as "Composite Nodes" (collections of 15 // instructions that can potentially be lowered to a single complex 16 // instruction). This is performed by checking the real and imaginary components 17 // and tracking the data flow for each component while following the operand 18 // pairs. Validity of each node is expected to be done upon creation, and any 19 // validation errors should halt traversal and prevent further graph 20 // construction. 21 // Instead of relying on Shuffle operations, vector interleaving and 22 // deinterleaving can be represented by vector.interleave2 and 23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by 24 // these intrinsics, whereas, fixed-width vectors are recognized for both 25 // shufflevector instruction and intrinsics. 26 // 27 // Replacement: 28 // This step traverses the graph built up by identification, delegating to the 29 // target to validate and generate the correct intrinsics, and plumbs them 30 // together connecting each end of the new intrinsics graph to the existing 31 // use-def chain. This step is assumed to finish successfully, as all 32 // information is expected to be correct by this point. 
33 // 34 // 35 // Internal data structure: 36 // ComplexDeinterleavingGraph: 37 // Keeps references to all the valid CompositeNodes formed as part of the 38 // transformation, and every Instruction contained within said nodes. It also 39 // holds onto a reference to the root Instruction, and the root node that should 40 // replace it. 41 // 42 // ComplexDeinterleavingCompositeNode: 43 // A CompositeNode represents a single transformation point; each node should 44 // transform into a single complex instruction (ignoring vector splitting, which 45 // would generate more instructions per node). They are identified in a 46 // depth-first manner, traversing and identifying the operands of each 47 // instruction in the order they appear in the IR. 48 // Each node maintains a reference to its Real and Imaginary instructions, 49 // as well as any additional instructions that make up the identified operation 50 // (Internal instructions should only have uses within their containing node). 51 // A Node also contains the rotation and operation type that it represents. 52 // Operands contains pointers to other CompositeNodes, acting as the edges in 53 // the graph. ReplacementValue is the transformed Value* that has been emitted 54 // to the IR. 55 // 56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and 57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue 58 // should be pre-populated. 
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "complex-deinterleaving"

// Pass statistic (reported under DEBUG_TYPE). Its increment site is not part
// of this declaration block.
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Hidden command-line switch, enabled by default. When it is turned off,
// ComplexDeinterleaving::runOnFunction returns early without touching the IR.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);

/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>). Masks with an odd
/// number of elements are never interleaving.
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2, and either start
/// with 0 or 1.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if \p V is a negation. Both the floating-point form (`fneg`,
/// matched by m_FNeg) and the integer form (matched by m_Neg) are recognized.
static bool isNeg(Value *V);

/// Returns the value negated by \p V. \p V must satisfy isNeg().
108 static Value *getNegOperand(Value *V); 109 110 namespace { 111 112 class ComplexDeinterleavingLegacyPass : public FunctionPass { 113 public: 114 static char ID; 115 116 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 117 : FunctionPass(ID), TM(TM) { 118 initializeComplexDeinterleavingLegacyPassPass( 119 *PassRegistry::getPassRegistry()); 120 } 121 122 StringRef getPassName() const override { 123 return "Complex Deinterleaving Pass"; 124 } 125 126 bool runOnFunction(Function &F) override; 127 void getAnalysisUsage(AnalysisUsage &AU) const override { 128 AU.addRequired<TargetLibraryInfoWrapperPass>(); 129 AU.setPreservesCFG(); 130 } 131 132 private: 133 const TargetMachine *TM; 134 }; 135 136 class ComplexDeinterleavingGraph; 137 struct ComplexDeinterleavingCompositeNode { 138 139 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 140 Value *R, Value *I) 141 : Operation(Op), Real(R), Imag(I) {} 142 143 private: 144 friend class ComplexDeinterleavingGraph; 145 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 146 using RawNodePtr = ComplexDeinterleavingCompositeNode *; 147 148 public: 149 ComplexDeinterleavingOperation Operation; 150 Value *Real; 151 Value *Imag; 152 153 // This two members are required exclusively for generating 154 // ComplexDeinterleavingOperation::Symmetric operations. 
155 unsigned Opcode; 156 std::optional<FastMathFlags> Flags; 157 158 ComplexDeinterleavingRotation Rotation = 159 ComplexDeinterleavingRotation::Rotation_0; 160 SmallVector<RawNodePtr> Operands; 161 Value *ReplacementNode = nullptr; 162 163 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } 164 165 void dump() { dump(dbgs()); } 166 void dump(raw_ostream &OS) { 167 auto PrintValue = [&](Value *V) { 168 if (V) { 169 OS << "\""; 170 V->print(OS, true); 171 OS << "\"\n"; 172 } else 173 OS << "nullptr\n"; 174 }; 175 auto PrintNodeRef = [&](RawNodePtr Ptr) { 176 if (Ptr) 177 OS << Ptr << "\n"; 178 else 179 OS << "nullptr\n"; 180 }; 181 182 OS << "- CompositeNode: " << this << "\n"; 183 OS << " Real: "; 184 PrintValue(Real); 185 OS << " Imag: "; 186 PrintValue(Imag); 187 OS << " ReplacementNode: "; 188 PrintValue(ReplacementNode); 189 OS << " Operation: " << (int)Operation << "\n"; 190 OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 191 OS << " Operands: \n"; 192 for (const auto &Op : Operands) { 193 OS << " - "; 194 PrintNodeRef(Op); 195 } 196 } 197 }; 198 199 class ComplexDeinterleavingGraph { 200 public: 201 struct Product { 202 Value *Multiplier; 203 Value *Multiplicand; 204 bool IsPositive; 205 }; 206 207 using Addend = std::pair<Value *, bool>; 208 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 209 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 210 211 // Helper struct for holding info about potential partial multiplication 212 // candidates 213 struct PartialMulCandidate { 214 Value *Common; 215 NodePtr Node; 216 unsigned RealIdx; 217 unsigned ImagIdx; 218 bool IsNodeInverted; 219 }; 220 221 explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 222 const TargetLibraryInfo *TLI) 223 : TL(TL), TLI(TLI) {} 224 225 private: 226 const TargetLowering *TL = nullptr; 227 const TargetLibraryInfo *TLI = nullptr; 228 SmallVector<NodePtr> CompositeNodes; 229 DenseMap<std::pair<Value *, Value *>, NodePtr> 
CachedResult; 230 231 SmallPtrSet<Instruction *, 16> FinalInstructions; 232 233 /// Root instructions are instructions from which complex computation starts 234 std::map<Instruction *, NodePtr> RootToNode; 235 236 /// Topologically sorted root instructions 237 SmallVector<Instruction *, 1> OrderedRoots; 238 239 /// When examining a basic block for complex deinterleaving, if it is a simple 240 /// one-block loop, then the only incoming block is 'Incoming' and the 241 /// 'BackEdge' block is the block itself." 242 BasicBlock *BackEdge = nullptr; 243 BasicBlock *Incoming = nullptr; 244 245 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction 246 /// %OutsideUser as it is shown in the IR: 247 /// 248 /// vector.body: 249 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], 250 /// [ %ReductionOp, %vector.body ] 251 /// ... 252 /// %ReductionOp = fadd i64 ... 253 /// ... 254 /// br i1 %condition, label %vector.body, %middle.block 255 /// 256 /// middle.block: 257 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) 258 /// 259 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding 260 /// `llvm.vector.reduce.fadd` when unroll factor isn't one. 261 std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; 262 263 /// In the process of detecting a reduction, we consider a pair of 264 /// %ReductionOP, which we refer to as real and imag (or vice versa), and 265 /// traverse the use-tree to detect complex operations. As this is a reduction 266 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds 267 /// to the %ReductionOPs that we suspect to be complex. 268 /// RealPHI and ImagPHI are used by the identifyPHINode method. 269 PHINode *RealPHI = nullptr; 270 PHINode *ImagPHI = nullptr; 271 272 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction 273 /// detection. 
274 bool PHIsFound = false; 275 276 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. 277 /// The new PHINode corresponds to a vector of deinterleaved complex numbers. 278 /// This mapping is populated during 279 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then 280 /// used in the ComplexDeinterleavingOperation::ReductionOperation node 281 /// replacement process. 282 std::map<PHINode *, PHINode *> OldToNewPHI; 283 284 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 285 Value *R, Value *I) { 286 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && 287 Operation != ComplexDeinterleavingOperation::ReductionOperation) || 288 (R && I)) && 289 "Reduction related nodes must have Real and Imaginary parts"); 290 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 291 I); 292 } 293 294 NodePtr submitCompositeNode(NodePtr Node) { 295 CompositeNodes.push_back(Node); 296 if (Node->Real && Node->Imag) 297 CachedResult[{Node->Real, Node->Imag}] = Node; 298 return Node; 299 } 300 301 /// Identifies a complex partial multiply pattern and its rotation, based on 302 /// the following patterns 303 /// 304 /// 0: r: cr + ar * br 305 /// i: ci + ar * bi 306 /// 90: r: cr - ai * bi 307 /// i: ci + ai * br 308 /// 180: r: cr - ar * br 309 /// i: ci - ar * bi 310 /// 270: r: cr + ai * bi 311 /// i: ci - ai * br 312 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 313 314 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that 315 /// is partially known from identifyPartialMul, filling in the other half of 316 /// the complex pair. 317 NodePtr 318 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, 319 std::pair<Value *, Value *> &CommonOperandI); 320 321 /// Identifies a complex add pattern and its rotation, based on the following 322 /// patterns. 
323 /// 324 /// 90: r: ar - bi 325 /// i: ai + br 326 /// 270: r: ar + bi 327 /// i: ai - br 328 NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 329 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); 330 331 NodePtr identifyNode(Value *R, Value *I); 332 333 /// Determine if a sum of complex numbers can be formed from \p RealAddends 334 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. 335 /// Return nullptr if it is not possible to construct a complex number. 336 /// \p Flags are needed to generate symmetric Add and Sub operations. 337 NodePtr identifyAdditions(std::list<Addend> &RealAddends, 338 std::list<Addend> &ImagAddends, 339 std::optional<FastMathFlags> Flags, 340 NodePtr Accumulator); 341 342 /// Extract one addend that have both real and imaginary parts positive. 343 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, 344 std::list<Addend> &ImagAddends); 345 346 /// Determine if sum of multiplications of complex numbers can be formed from 347 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result 348 /// to it. Return nullptr if it is not possible to construct a complex number. 349 NodePtr identifyMultiplications(std::vector<Product> &RealMuls, 350 std::vector<Product> &ImagMuls, 351 NodePtr Accumulator); 352 353 /// Go through pairs of multiplication (one Real and one Imag) and find all 354 /// possible candidates for partial multiplication and put them into \p 355 /// Candidates. Returns true if all Product has pair with common operand 356 bool collectPartialMuls(const std::vector<Product> &RealMuls, 357 const std::vector<Product> &ImagMuls, 358 std::vector<PartialMulCandidate> &Candidates); 359 360 /// If the code is compiled with -Ofast or expressions have `reassoc` flag, 361 /// the order of complex computation operations may be significantly altered, 362 /// and the real and imaginary parts may not be executed in parallel. 
This 363 /// function takes this into consideration and employs a more general approach 364 /// to identify complex computations. Initially, it gathers all the addends 365 /// and multiplicands and then constructs a complex expression from them. 366 NodePtr identifyReassocNodes(Instruction *I, Instruction *J); 367 368 NodePtr identifyRoot(Instruction *I); 369 370 /// Identifies the Deinterleave operation applied to a vector containing 371 /// complex numbers. There are two ways to represent the Deinterleave 372 /// operation: 373 /// * Using two shufflevectors with even indices for /pReal instruction and 374 /// odd indices for /pImag instructions (only for fixed-width vectors) 375 /// * Using two extractvalue instructions applied to `vector.deinterleave2` 376 /// intrinsic (for both fixed and scalable vectors) 377 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); 378 379 /// identifying the operation that represents a complex number repeated in a 380 /// Splat vector. There are two possible types of splats: ConstantExpr with 381 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an 382 /// initialization mask with all values set to zero. 
383 NodePtr identifySplat(Value *Real, Value *Imag); 384 385 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); 386 387 /// Identifies SelectInsts in a loop that has reduction with predication masks 388 /// and/or predicated tail folding 389 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); 390 391 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); 392 393 /// Complete IR modifications after producing new reduction operation: 394 /// * Populate the PHINode generated for 395 /// ComplexDeinterleavingOperation::ReductionPHI 396 /// * Deinterleave the final value outside of the loop and repurpose original 397 /// reduction users 398 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); 399 400 public: 401 void dump() { dump(dbgs()); } 402 void dump(raw_ostream &OS) { 403 for (const auto &Node : CompositeNodes) 404 Node->dump(OS); 405 } 406 407 /// Returns false if the deinterleaving operation should be cancelled for the 408 /// current graph. 409 bool identifyNodes(Instruction *RootI); 410 411 /// In case \pB is one-block loop, this function seeks potential reductions 412 /// and populates ReductionInfo. Returns true if any reductions were 413 /// identified. 414 bool collectPotentialReductions(BasicBlock *B); 415 416 void identifyReductionNodes(); 417 418 /// Check that every instruction, from the roots to the leaves, has internal 419 /// uses. 420 bool checkNodes(); 421 422 /// Perform the actual replacement of the underlying instruction graph. 
423 void replaceNodes(); 424 }; 425 426 class ComplexDeinterleaving { 427 public: 428 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 429 : TL(tl), TLI(tli) {} 430 bool runOnFunction(Function &F); 431 432 private: 433 bool evaluateBasicBlock(BasicBlock *B); 434 435 const TargetLowering *TL = nullptr; 436 const TargetLibraryInfo *TLI = nullptr; 437 }; 438 439 } // namespace 440 441 char ComplexDeinterleavingLegacyPass::ID = 0; 442 443 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 444 "Complex Deinterleaving", false, false) 445 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 446 "Complex Deinterleaving", false, false) 447 448 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 449 FunctionAnalysisManager &AM) { 450 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 451 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 452 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 453 return PreservedAnalyses::all(); 454 455 PreservedAnalyses PA; 456 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 457 return PA; 458 } 459 460 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 461 return new ComplexDeinterleavingLegacyPass(TM); 462 } 463 464 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 465 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 466 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 467 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 468 } 469 470 bool ComplexDeinterleaving::runOnFunction(Function &F) { 471 if (!ComplexDeinterleavingEnabled) { 472 LLVM_DEBUG( 473 dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 474 return false; 475 } 476 477 if (!TL->isComplexDeinterleavingSupported()) { 478 LLVM_DEBUG( 479 dbgs() << "Complex deinterleaving has been disabled, target does " 480 "not support lowering of complex number operations.\n"); 481 return 
false; 482 } 483 484 bool Changed = false; 485 for (auto &B : F) 486 Changed |= evaluateBasicBlock(&B); 487 488 return Changed; 489 } 490 491 static bool isInterleavingMask(ArrayRef<int> Mask) { 492 // If the size is not even, it's not an interleaving mask 493 if ((Mask.size() & 1)) 494 return false; 495 496 int HalfNumElements = Mask.size() / 2; 497 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 498 int MaskIdx = Idx * 2; 499 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 500 return false; 501 } 502 503 return true; 504 } 505 506 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 507 int Offset = Mask[0]; 508 int HalfNumElements = Mask.size() / 2; 509 510 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 511 if (Mask[Idx] != (Idx * 2) + Offset) 512 return false; 513 } 514 515 return true; 516 } 517 518 bool isNeg(Value *V) { 519 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); 520 } 521 522 Value *getNegOperand(Value *V) { 523 assert(isNeg(V)); 524 auto *I = cast<Instruction>(V); 525 if (I->getOpcode() == Instruction::FNeg) 526 return I->getOperand(0); 527 528 return I->getOperand(1); 529 } 530 531 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 532 ComplexDeinterleavingGraph Graph(TL, TLI); 533 if (Graph.collectPotentialReductions(B)) 534 Graph.identifyReductionNodes(); 535 536 for (auto &I : *B) 537 Graph.identifyNodes(&I); 538 539 if (Graph.checkNodes()) { 540 Graph.replaceNodes(); 541 return true; 542 } 543 544 return false; 545 } 546 547 ComplexDeinterleavingGraph::NodePtr 548 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 549 Instruction *Real, Instruction *Imag, 550 std::pair<Value *, Value *> &PartialMatch) { 551 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 552 << "\n"); 553 554 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 555 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 556 return nullptr; 557 } 558 559 if ((Real->getOpcode() 
!= Instruction::FMul && 560 Real->getOpcode() != Instruction::Mul) || 561 (Imag->getOpcode() != Instruction::FMul && 562 Imag->getOpcode() != Instruction::Mul)) { 563 LLVM_DEBUG( 564 dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); 565 return nullptr; 566 } 567 568 Value *R0 = Real->getOperand(0); 569 Value *R1 = Real->getOperand(1); 570 Value *I0 = Imag->getOperand(0); 571 Value *I1 = Imag->getOperand(1); 572 573 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 574 // rotations and use the operand. 575 unsigned Negs = 0; 576 Value *Op; 577 if (match(R0, m_Neg(m_Value(Op)))) { 578 Negs |= 1; 579 R0 = Op; 580 } else if (match(R1, m_Neg(m_Value(Op)))) { 581 Negs |= 1; 582 R1 = Op; 583 } 584 585 if (isNeg(I0)) { 586 Negs |= 2; 587 Negs ^= 1; 588 I0 = Op; 589 } else if (match(I1, m_Neg(m_Value(Op)))) { 590 Negs |= 2; 591 Negs ^= 1; 592 I1 = Op; 593 } 594 595 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 596 597 Value *CommonOperand; 598 Value *UncommonRealOp; 599 Value *UncommonImagOp; 600 601 if (R0 == I0 || R0 == I1) { 602 CommonOperand = R0; 603 UncommonRealOp = R1; 604 } else if (R1 == I0 || R1 == I1) { 605 CommonOperand = R1; 606 UncommonRealOp = R0; 607 } else { 608 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 609 return nullptr; 610 } 611 612 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 613 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 614 Rotation == ComplexDeinterleavingRotation::Rotation_270) 615 std::swap(UncommonRealOp, UncommonImagOp); 616 617 // Between identifyPartialMul and here we need to have found a complete valid 618 // pair from the CommonOperand of each part. 
619 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 620 Rotation == ComplexDeinterleavingRotation::Rotation_180) 621 PartialMatch.first = CommonOperand; 622 else 623 PartialMatch.second = CommonOperand; 624 625 if (!PartialMatch.first || !PartialMatch.second) { 626 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 627 return nullptr; 628 } 629 630 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 631 if (!CommonNode) { 632 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 633 return nullptr; 634 } 635 636 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 637 if (!UncommonNode) { 638 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 639 return nullptr; 640 } 641 642 NodePtr Node = prepareCompositeNode( 643 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 644 Node->Rotation = Rotation; 645 Node->addOperand(CommonNode); 646 Node->addOperand(UncommonNode); 647 return submitCompositeNode(Node); 648 } 649 650 ComplexDeinterleavingGraph::NodePtr 651 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 652 Instruction *Imag) { 653 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 654 << "\n"); 655 // Determine rotation 656 auto IsAdd = [](unsigned Op) { 657 return Op == Instruction::FAdd || Op == Instruction::Add; 658 }; 659 auto IsSub = [](unsigned Op) { 660 return Op == Instruction::FSub || Op == Instruction::Sub; 661 }; 662 ComplexDeinterleavingRotation Rotation; 663 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 664 Rotation = ComplexDeinterleavingRotation::Rotation_0; 665 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 666 Rotation = ComplexDeinterleavingRotation::Rotation_90; 667 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) 668 Rotation = ComplexDeinterleavingRotation::Rotation_180; 669 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) 670 Rotation = ComplexDeinterleavingRotation::Rotation_270; 
671 else { 672 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 673 return nullptr; 674 } 675 676 if (isa<FPMathOperator>(Real) && 677 (!Real->getFastMathFlags().allowContract() || 678 !Imag->getFastMathFlags().allowContract())) { 679 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 680 return nullptr; 681 } 682 683 Value *CR = Real->getOperand(0); 684 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 685 if (!RealMulI) 686 return nullptr; 687 Value *CI = Imag->getOperand(0); 688 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 689 if (!ImagMulI) 690 return nullptr; 691 692 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 693 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 694 return nullptr; 695 } 696 697 Value *R0 = RealMulI->getOperand(0); 698 Value *R1 = RealMulI->getOperand(1); 699 Value *I0 = ImagMulI->getOperand(0); 700 Value *I1 = ImagMulI->getOperand(1); 701 702 Value *CommonOperand; 703 Value *UncommonRealOp; 704 Value *UncommonImagOp; 705 706 if (R0 == I0 || R0 == I1) { 707 CommonOperand = R0; 708 UncommonRealOp = R1; 709 } else if (R1 == I0 || R1 == I1) { 710 CommonOperand = R1; 711 UncommonRealOp = R0; 712 } else { 713 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 714 return nullptr; 715 } 716 717 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 718 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 719 Rotation == ComplexDeinterleavingRotation::Rotation_270) 720 std::swap(UncommonRealOp, UncommonImagOp); 721 722 std::pair<Value *, Value *> PartialMatch( 723 (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 724 Rotation == ComplexDeinterleavingRotation::Rotation_180) 725 ? CommonOperand 726 : nullptr, 727 (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 728 Rotation == ComplexDeinterleavingRotation::Rotation_270) 729 ? 
CommonOperand 730 : nullptr); 731 732 auto *CRInst = dyn_cast<Instruction>(CR); 733 auto *CIInst = dyn_cast<Instruction>(CI); 734 735 if (!CRInst || !CIInst) { 736 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); 737 return nullptr; 738 } 739 740 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); 741 if (!CNode) { 742 LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 743 return nullptr; 744 } 745 746 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 747 if (!UncommonRes) { 748 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 749 return nullptr; 750 } 751 752 assert(PartialMatch.first && PartialMatch.second); 753 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 754 if (!CommonRes) { 755 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 756 return nullptr; 757 } 758 759 NodePtr Node = prepareCompositeNode( 760 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 761 Node->Rotation = Rotation; 762 Node->addOperand(CommonRes); 763 Node->addOperand(UncommonRes); 764 Node->addOperand(CNode); 765 return submitCompositeNode(Node); 766 } 767 768 ComplexDeinterleavingGraph::NodePtr 769 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 770 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 771 772 // Determine rotation 773 ComplexDeinterleavingRotation Rotation; 774 if ((Real->getOpcode() == Instruction::FSub && 775 Imag->getOpcode() == Instruction::FAdd) || 776 (Real->getOpcode() == Instruction::Sub && 777 Imag->getOpcode() == Instruction::Add)) 778 Rotation = ComplexDeinterleavingRotation::Rotation_90; 779 else if ((Real->getOpcode() == Instruction::FAdd && 780 Imag->getOpcode() == Instruction::FSub) || 781 (Real->getOpcode() == Instruction::Add && 782 Imag->getOpcode() == Instruction::Sub)) 783 Rotation = ComplexDeinterleavingRotation::Rotation_270; 784 else { 785 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not 
assigned.\n"); 786 return nullptr; 787 } 788 789 auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 790 auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 791 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 792 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 793 794 if (!AR || !AI || !BR || !BI) { 795 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 796 return nullptr; 797 } 798 799 NodePtr ResA = identifyNode(AR, AI); 800 if (!ResA) { 801 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 802 return nullptr; 803 } 804 NodePtr ResB = identifyNode(BR, BI); 805 if (!ResB) { 806 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 807 return nullptr; 808 } 809 810 NodePtr Node = 811 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 812 Node->Rotation = Rotation; 813 Node->addOperand(ResA); 814 Node->addOperand(ResB); 815 return submitCompositeNode(Node); 816 } 817 818 static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 819 unsigned OpcA = A->getOpcode(); 820 unsigned OpcB = B->getOpcode(); 821 822 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 823 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 824 (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 825 (OpcA == Instruction::Add && OpcB == Instruction::Sub); 826 } 827 828 static bool isInstructionPairMul(Instruction *A, Instruction *B) { 829 auto Pattern = 830 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 831 832 return match(A, Pattern) && match(B, Pattern); 833 } 834 835 static bool isInstructionPotentiallySymmetric(Instruction *I) { 836 switch (I->getOpcode()) { 837 case Instruction::FAdd: 838 case Instruction::FSub: 839 case Instruction::FMul: 840 case Instruction::FNeg: 841 case Instruction::Add: 842 case Instruction::Sub: 843 case Instruction::Mul: 844 return true; 845 default: 846 return false; 847 } 848 } 849 850 
ComplexDeinterleavingGraph::NodePtr 851 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, 852 Instruction *Imag) { 853 if (Real->getOpcode() != Imag->getOpcode()) 854 return nullptr; 855 856 if (!isInstructionPotentiallySymmetric(Real) || 857 !isInstructionPotentiallySymmetric(Imag)) 858 return nullptr; 859 860 auto *R0 = Real->getOperand(0); 861 auto *I0 = Imag->getOperand(0); 862 863 NodePtr Op0 = identifyNode(R0, I0); 864 NodePtr Op1 = nullptr; 865 if (Op0 == nullptr) 866 return nullptr; 867 868 if (Real->isBinaryOp()) { 869 auto *R1 = Real->getOperand(1); 870 auto *I1 = Imag->getOperand(1); 871 Op1 = identifyNode(R1, I1); 872 if (Op1 == nullptr) 873 return nullptr; 874 } 875 876 if (isa<FPMathOperator>(Real) && 877 Real->getFastMathFlags() != Imag->getFastMathFlags()) 878 return nullptr; 879 880 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, 881 Real, Imag); 882 Node->Opcode = Real->getOpcode(); 883 if (isa<FPMathOperator>(Real)) 884 Node->Flags = Real->getFastMathFlags(); 885 886 Node->addOperand(Op0); 887 if (Real->isBinaryOp()) 888 Node->addOperand(Op1); 889 890 return submitCompositeNode(Node); 891 } 892 893 ComplexDeinterleavingGraph::NodePtr 894 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { 895 LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); 896 assert(R->getType() == I->getType() && 897 "Real and imaginary parts should not have different types"); 898 899 auto It = CachedResult.find({R, I}); 900 if (It != CachedResult.end()) { 901 LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 902 return It->second; 903 } 904 905 if (NodePtr CN = identifySplat(R, I)) 906 return CN; 907 908 auto *Real = dyn_cast<Instruction>(R); 909 auto *Imag = dyn_cast<Instruction>(I); 910 if (!Real || !Imag) 911 return nullptr; 912 913 if (NodePtr CN = identifyDeinterleave(Real, Imag)) 914 return CN; 915 916 if (NodePtr CN = identifyPHINode(Real, Imag)) 917 return CN; 918 919 if (NodePtr 
CN = identifySelectNode(Real, Imag)) 920 return CN; 921 922 auto *VTy = cast<VectorType>(Real->getType()); 923 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 924 925 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( 926 ComplexDeinterleavingOperation::CMulPartial, NewVTy); 927 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( 928 ComplexDeinterleavingOperation::CAdd, NewVTy); 929 930 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { 931 if (NodePtr CN = identifyPartialMul(Real, Imag)) 932 return CN; 933 } 934 935 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { 936 if (NodePtr CN = identifyAdd(Real, Imag)) 937 return CN; 938 } 939 940 if (HasCMulSupport && HasCAddSupport) { 941 if (NodePtr CN = identifyReassocNodes(Real, Imag)) 942 return CN; 943 } 944 945 if (NodePtr CN = identifySymmetricOperation(Real, Imag)) 946 return CN; 947 948 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); 949 CachedResult[{R, I}] = nullptr; 950 return nullptr; 951 } 952 953 ComplexDeinterleavingGraph::NodePtr 954 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, 955 Instruction *Imag) { 956 auto IsOperationSupported = [](unsigned Opcode) -> bool { 957 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || 958 Opcode == Instruction::FNeg || Opcode == Instruction::Add || 959 Opcode == Instruction::Sub; 960 }; 961 962 if (!IsOperationSupported(Real->getOpcode()) || 963 !IsOperationSupported(Imag->getOpcode())) 964 return nullptr; 965 966 std::optional<FastMathFlags> Flags; 967 if (isa<FPMathOperator>(Real)) { 968 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { 969 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " 970 "not identical\n"); 971 return nullptr; 972 } 973 974 Flags = Real->getFastMathFlags(); 975 if (!Flags->allowReassoc()) { 976 LLVM_DEBUG( 977 dbgs() 978 << "the 'Reassoc' attribute is missing in the FastMath flags\n"); 979 
return nullptr; 980 } 981 } 982 983 // Collect multiplications and addend instructions from the given instruction 984 // while traversing it operands. Additionally, verify that all instructions 985 // have the same fast math flags. 986 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 987 std::list<Addend> &Addends) -> bool { 988 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 989 SmallPtrSet<Value *, 8> Visited; 990 while (!Worklist.empty()) { 991 auto [V, IsPositive] = Worklist.back(); 992 Worklist.pop_back(); 993 if (!Visited.insert(V).second) 994 continue; 995 996 Instruction *I = dyn_cast<Instruction>(V); 997 if (!I) { 998 Addends.emplace_back(V, IsPositive); 999 continue; 1000 } 1001 1002 // If an instruction has more than one user, it indicates that it either 1003 // has an external user, which will be later checked by the checkNodes 1004 // function, or it is a subexpression utilized by multiple expressions. In 1005 // the latter case, we will attempt to separately identify the complex 1006 // operation from here in order to create a shared 1007 // ComplexDeinterleavingCompositeNode. 
1008 if (I != Insn && I->getNumUses() > 1) { 1009 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 1010 Addends.emplace_back(I, IsPositive); 1011 continue; 1012 } 1013 switch (I->getOpcode()) { 1014 case Instruction::FAdd: 1015 case Instruction::Add: 1016 Worklist.emplace_back(I->getOperand(1), IsPositive); 1017 Worklist.emplace_back(I->getOperand(0), IsPositive); 1018 break; 1019 case Instruction::FSub: 1020 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1021 Worklist.emplace_back(I->getOperand(0), IsPositive); 1022 break; 1023 case Instruction::Sub: 1024 if (isNeg(I)) { 1025 Worklist.emplace_back(getNegOperand(I), !IsPositive); 1026 } else { 1027 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1028 Worklist.emplace_back(I->getOperand(0), IsPositive); 1029 } 1030 break; 1031 case Instruction::FMul: 1032 case Instruction::Mul: { 1033 Value *A, *B; 1034 if (isNeg(I->getOperand(0))) { 1035 A = getNegOperand(I->getOperand(0)); 1036 IsPositive = !IsPositive; 1037 } else { 1038 A = I->getOperand(0); 1039 } 1040 1041 if (isNeg(I->getOperand(1))) { 1042 B = getNegOperand(I->getOperand(1)); 1043 IsPositive = !IsPositive; 1044 } else { 1045 B = I->getOperand(1); 1046 } 1047 Muls.push_back(Product{A, B, IsPositive}); 1048 break; 1049 } 1050 case Instruction::FNeg: 1051 Worklist.emplace_back(I->getOperand(0), !IsPositive); 1052 break; 1053 default: 1054 Addends.emplace_back(I, IsPositive); 1055 continue; 1056 } 1057 1058 if (Flags && I->getFastMathFlags() != *Flags) { 1059 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 1060 "inconsistent with the root instructions' flags: " 1061 << *I << "\n"); 1062 return false; 1063 } 1064 } 1065 return true; 1066 }; 1067 1068 std::vector<Product> RealMuls, ImagMuls; 1069 std::list<Addend> RealAddends, ImagAddends; 1070 if (!Collect(Real, RealMuls, RealAddends) || 1071 !Collect(Imag, ImagMuls, ImagAddends)) 1072 return nullptr; 1073 1074 if (RealAddends.size() != ImagAddends.size()) 1075 
return nullptr; 1076 1077 NodePtr FinalNode; 1078 if (!RealMuls.empty() || !ImagMuls.empty()) { 1079 // If there are multiplicands, extract positive addend and use it as an 1080 // accumulator 1081 FinalNode = extractPositiveAddend(RealAddends, ImagAddends); 1082 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); 1083 if (!FinalNode) 1084 return nullptr; 1085 } 1086 1087 // Identify and process remaining additions 1088 if (!RealAddends.empty() || !ImagAddends.empty()) { 1089 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); 1090 if (!FinalNode) 1091 return nullptr; 1092 } 1093 assert(FinalNode && "FinalNode can not be nullptr here"); 1094 // Set the Real and Imag fields of the final node and submit it 1095 FinalNode->Real = Real; 1096 FinalNode->Imag = Imag; 1097 submitCompositeNode(FinalNode); 1098 return FinalNode; 1099 } 1100 1101 bool ComplexDeinterleavingGraph::collectPartialMuls( 1102 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, 1103 std::vector<PartialMulCandidate> &PartialMulCandidates) { 1104 // Helper function to extract a common operand from two products 1105 auto FindCommonInstruction = [](const Product &Real, 1106 const Product &Imag) -> Value * { 1107 if (Real.Multiplicand == Imag.Multiplicand || 1108 Real.Multiplicand == Imag.Multiplier) 1109 return Real.Multiplicand; 1110 1111 if (Real.Multiplier == Imag.Multiplicand || 1112 Real.Multiplier == Imag.Multiplier) 1113 return Real.Multiplier; 1114 1115 return nullptr; 1116 }; 1117 1118 // Iterating over real and imaginary multiplications to find common operands 1119 // If a common operand is found, a partial multiplication candidate is created 1120 // and added to the candidates vector The function returns false if no common 1121 // operands are found for any product 1122 for (unsigned i = 0; i < RealMuls.size(); ++i) { 1123 bool FoundCommon = false; 1124 for (unsigned j = 0; j < ImagMuls.size(); ++j) { 1125 auto *Common = 
FindCommonInstruction(RealMuls[i], ImagMuls[j]); 1126 if (!Common) 1127 continue; 1128 1129 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier 1130 : RealMuls[i].Multiplicand; 1131 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier 1132 : ImagMuls[j].Multiplicand; 1133 1134 auto Node = identifyNode(A, B); 1135 if (Node) { 1136 FoundCommon = true; 1137 PartialMulCandidates.push_back({Common, Node, i, j, false}); 1138 } 1139 1140 Node = identifyNode(B, A); 1141 if (Node) { 1142 FoundCommon = true; 1143 PartialMulCandidates.push_back({Common, Node, i, j, true}); 1144 } 1145 } 1146 if (!FoundCommon) 1147 return false; 1148 } 1149 return true; 1150 } 1151 1152 ComplexDeinterleavingGraph::NodePtr 1153 ComplexDeinterleavingGraph::identifyMultiplications( 1154 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, 1155 NodePtr Accumulator = nullptr) { 1156 if (RealMuls.size() != ImagMuls.size()) 1157 return nullptr; 1158 1159 std::vector<PartialMulCandidate> Info; 1160 if (!collectPartialMuls(RealMuls, ImagMuls, Info)) 1161 return nullptr; 1162 1163 // Map to store common instruction to node pointers 1164 std::map<Value *, NodePtr> CommonToNode; 1165 std::vector<bool> Processed(Info.size(), false); 1166 for (unsigned I = 0; I < Info.size(); ++I) { 1167 if (Processed[I]) 1168 continue; 1169 1170 PartialMulCandidate &InfoA = Info[I]; 1171 for (unsigned J = I + 1; J < Info.size(); ++J) { 1172 if (Processed[J]) 1173 continue; 1174 1175 PartialMulCandidate &InfoB = Info[J]; 1176 auto *InfoReal = &InfoA; 1177 auto *InfoImag = &InfoB; 1178 1179 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1180 if (!NodeFromCommon) { 1181 std::swap(InfoReal, InfoImag); 1182 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1183 } 1184 if (!NodeFromCommon) 1185 continue; 1186 1187 CommonToNode[InfoReal->Common] = NodeFromCommon; 1188 CommonToNode[InfoImag->Common] = NodeFromCommon; 1189 Processed[I] = true; 
1190 Processed[J] = true; 1191 } 1192 } 1193 1194 std::vector<bool> ProcessedReal(RealMuls.size(), false); 1195 std::vector<bool> ProcessedImag(ImagMuls.size(), false); 1196 NodePtr Result = Accumulator; 1197 for (auto &PMI : Info) { 1198 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 1199 continue; 1200 1201 auto It = CommonToNode.find(PMI.Common); 1202 // TODO: Process independent complex multiplications. Cases like this: 1203 // A.real() * B where both A and B are complex numbers. 1204 if (It == CommonToNode.end()) { 1205 LLVM_DEBUG({ 1206 dbgs() << "Unprocessed independent partial multiplication:\n"; 1207 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 1208 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 1209 << " multiplied by " << *Mul->Multiplicand << "\n"; 1210 }); 1211 return nullptr; 1212 } 1213 1214 auto &RealMul = RealMuls[PMI.RealIdx]; 1215 auto &ImagMul = ImagMuls[PMI.ImagIdx]; 1216 1217 auto NodeA = It->second; 1218 auto NodeB = PMI.Node; 1219 auto IsMultiplicandReal = PMI.Common == NodeA->Real; 1220 // The following table illustrates the relationship between multiplications 1221 // and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we 1222 // can see: 1223 // 1224 // Rotation | Real | Imag | 1225 // ---------+--------+--------+ 1226 // 0 | x * u | x * v | 1227 // 90 | -y * v | y * u | 1228 // 180 | -x * u | -x * v | 1229 // 270 | y * v | -y * u | 1230 // 1231 // Check if the candidate can indeed be represented by partial 1232 // multiplication 1233 // TODO: Add support for multiplication by complex one 1234 if ((IsMultiplicandReal && PMI.IsNodeInverted) || 1235 (!IsMultiplicandReal && !PMI.IsNodeInverted)) 1236 continue; 1237 1238 // Determine the rotation based on the multiplications 1239 ComplexDeinterleavingRotation Rotation; 1240 if (IsMultiplicandReal) { 1241 // Detect 0 and 180 degrees rotation 1242 if (RealMul.IsPositive && ImagMul.IsPositive) 1243 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 1244 else if (!RealMul.IsPositive && !ImagMul.IsPositive) 1245 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 1246 else 1247 continue; 1248 1249 } else { 1250 // Detect 90 and 270 degrees rotation 1251 if (!RealMul.IsPositive && ImagMul.IsPositive) 1252 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 1253 else if (RealMul.IsPositive && !ImagMul.IsPositive) 1254 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; 1255 else 1256 continue; 1257 } 1258 1259 LLVM_DEBUG({ 1260 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 1261 dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 1262 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 1263 dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 1264 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 1265 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1266 }); 1267 1268 NodePtr NodeMul = prepareCompositeNode( 1269 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 1270 NodeMul->Rotation = Rotation; 1271 NodeMul->addOperand(NodeA); 1272 NodeMul->addOperand(NodeB); 1273 if (Result) 1274 
NodeMul->addOperand(Result); 1275 submitCompositeNode(NodeMul); 1276 Result = NodeMul; 1277 ProcessedReal[PMI.RealIdx] = true; 1278 ProcessedImag[PMI.ImagIdx] = true; 1279 } 1280 1281 // Ensure all products have been processed, if not return nullptr. 1282 if (!all_of(ProcessedReal, [](bool V) { return V; }) || 1283 !all_of(ProcessedImag, [](bool V) { return V; })) { 1284 1285 // Dump debug information about which partial multiplications are not 1286 // processed. 1287 LLVM_DEBUG({ 1288 dbgs() << "Unprocessed products (Real):\n"; 1289 for (size_t i = 0; i < ProcessedReal.size(); ++i) { 1290 if (!ProcessedReal[i]) 1291 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 1292 << *RealMuls[i].Multiplier << " multiplied by " 1293 << *RealMuls[i].Multiplicand << "\n"; 1294 } 1295 dbgs() << "Unprocessed products (Imag):\n"; 1296 for (size_t i = 0; i < ProcessedImag.size(); ++i) { 1297 if (!ProcessedImag[i]) 1298 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") 1299 << *ImagMuls[i].Multiplier << " multiplied by " 1300 << *ImagMuls[i].Multiplicand << "\n"; 1301 } 1302 }); 1303 return nullptr; 1304 } 1305 1306 return Result; 1307 } 1308 1309 ComplexDeinterleavingGraph::NodePtr 1310 ComplexDeinterleavingGraph::identifyAdditions( 1311 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, 1312 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { 1313 if (RealAddends.size() != ImagAddends.size()) 1314 return nullptr; 1315 1316 NodePtr Result; 1317 // If we have accumulator use it as first addend 1318 if (Accumulator) 1319 Result = Accumulator; 1320 // Otherwise find an element with both positive real and imaginary parts. 
1321 else 1322 Result = extractPositiveAddend(RealAddends, ImagAddends); 1323 1324 if (!Result) 1325 return nullptr; 1326 1327 while (!RealAddends.empty()) { 1328 auto ItR = RealAddends.begin(); 1329 auto [R, IsPositiveR] = *ItR; 1330 1331 bool FoundImag = false; 1332 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1333 auto [I, IsPositiveI] = *ItI; 1334 ComplexDeinterleavingRotation Rotation; 1335 if (IsPositiveR && IsPositiveI) 1336 Rotation = ComplexDeinterleavingRotation::Rotation_0; 1337 else if (!IsPositiveR && IsPositiveI) 1338 Rotation = ComplexDeinterleavingRotation::Rotation_90; 1339 else if (!IsPositiveR && !IsPositiveI) 1340 Rotation = ComplexDeinterleavingRotation::Rotation_180; 1341 else 1342 Rotation = ComplexDeinterleavingRotation::Rotation_270; 1343 1344 NodePtr AddNode; 1345 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 1346 Rotation == ComplexDeinterleavingRotation::Rotation_180) { 1347 AddNode = identifyNode(R, I); 1348 } else { 1349 AddNode = identifyNode(I, R); 1350 } 1351 if (AddNode) { 1352 LLVM_DEBUG({ 1353 dbgs() << "Identified addition:\n"; 1354 dbgs().indent(4) << "X: " << *R << "\n"; 1355 dbgs().indent(4) << "Y: " << *I << "\n"; 1356 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1357 }); 1358 1359 NodePtr TmpNode; 1360 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { 1361 TmpNode = prepareCompositeNode( 1362 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1363 if (Flags) { 1364 TmpNode->Opcode = Instruction::FAdd; 1365 TmpNode->Flags = *Flags; 1366 } else { 1367 TmpNode->Opcode = Instruction::Add; 1368 } 1369 } else if (Rotation == 1370 llvm::ComplexDeinterleavingRotation::Rotation_180) { 1371 TmpNode = prepareCompositeNode( 1372 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1373 if (Flags) { 1374 TmpNode->Opcode = Instruction::FSub; 1375 TmpNode->Flags = *Flags; 1376 } else { 1377 TmpNode->Opcode = Instruction::Sub; 1378 } 1379 } 
else { 1380 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, 1381 nullptr, nullptr); 1382 TmpNode->Rotation = Rotation; 1383 } 1384 1385 TmpNode->addOperand(Result); 1386 TmpNode->addOperand(AddNode); 1387 submitCompositeNode(TmpNode); 1388 Result = TmpNode; 1389 RealAddends.erase(ItR); 1390 ImagAddends.erase(ItI); 1391 FoundImag = true; 1392 break; 1393 } 1394 } 1395 if (!FoundImag) 1396 return nullptr; 1397 } 1398 return Result; 1399 } 1400 1401 ComplexDeinterleavingGraph::NodePtr 1402 ComplexDeinterleavingGraph::extractPositiveAddend( 1403 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 1404 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 1405 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1406 auto [R, IsPositiveR] = *ItR; 1407 auto [I, IsPositiveI] = *ItI; 1408 if (IsPositiveR && IsPositiveI) { 1409 auto Result = identifyNode(R, I); 1410 if (Result) { 1411 RealAddends.erase(ItR); 1412 ImagAddends.erase(ItI); 1413 return Result; 1414 } 1415 } 1416 } 1417 } 1418 return nullptr; 1419 } 1420 1421 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 1422 // This potential root instruction might already have been recognized as 1423 // reduction. Because RootToNode maps both Real and Imaginary parts to 1424 // CompositeNode we should choose only one either Real or Imag instruction to 1425 // use as an anchor for generating complex instruction. 1426 auto It = RootToNode.find(RootI); 1427 if (It != RootToNode.end()) { 1428 auto RootNode = It->second; 1429 assert(RootNode->Operation == 1430 ComplexDeinterleavingOperation::ReductionOperation); 1431 // Find out which part, Real or Imag, comes later, and only if we come to 1432 // the latest part, add it to OrderedRoots. 1433 auto *R = cast<Instruction>(RootNode->Real); 1434 auto *I = cast<Instruction>(RootNode->Imag); 1435 auto *ReplacementAnchor = R->comesBefore(I) ? 
I : R; 1436 if (ReplacementAnchor != RootI) 1437 return false; 1438 OrderedRoots.push_back(RootI); 1439 return true; 1440 } 1441 1442 auto RootNode = identifyRoot(RootI); 1443 if (!RootNode) 1444 return false; 1445 1446 LLVM_DEBUG({ 1447 Function *F = RootI->getFunction(); 1448 BasicBlock *B = RootI->getParent(); 1449 dbgs() << "Complex deinterleaving graph for " << F->getName() 1450 << "::" << B->getName() << ".\n"; 1451 dump(dbgs()); 1452 dbgs() << "\n"; 1453 }); 1454 RootToNode[RootI] = RootNode; 1455 OrderedRoots.push_back(RootI); 1456 return true; 1457 } 1458 1459 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 1460 bool FoundPotentialReduction = false; 1461 1462 auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 1463 if (!Br || Br->getNumSuccessors() != 2) 1464 return false; 1465 1466 // Identify simple one-block loop 1467 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 1468 return false; 1469 1470 SmallVector<PHINode *> PHIs; 1471 for (auto &PHI : B->phis()) { 1472 if (PHI.getNumIncomingValues() != 2) 1473 continue; 1474 1475 if (!PHI.getType()->isVectorTy()) 1476 continue; 1477 1478 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 1479 if (!ReductionOp) 1480 continue; 1481 1482 // Check if final instruction is reduced outside of current block 1483 Instruction *FinalReduction = nullptr; 1484 auto NumUsers = 0u; 1485 for (auto *U : ReductionOp->users()) { 1486 ++NumUsers; 1487 if (U == &PHI) 1488 continue; 1489 FinalReduction = dyn_cast<Instruction>(U); 1490 } 1491 1492 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || 1493 isa<PHINode>(FinalReduction)) 1494 continue; 1495 1496 ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 1497 BackEdge = B; 1498 auto BackEdgeIdx = PHI.getBasicBlockIndex(B); 1499 auto IncomingIdx = BackEdgeIdx == 0 ? 
1 : 0; 1500 Incoming = PHI.getIncomingBlock(IncomingIdx); 1501 FoundPotentialReduction = true; 1502 1503 // If the initial value of PHINode is an Instruction, consider it a leaf 1504 // value of a complex deinterleaving graph. 1505 if (auto *InitPHI = 1506 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 1507 FinalInstructions.insert(InitPHI); 1508 } 1509 return FoundPotentialReduction; 1510 } 1511 1512 void ComplexDeinterleavingGraph::identifyReductionNodes() { 1513 SmallVector<bool> Processed(ReductionInfo.size(), false); 1514 SmallVector<Instruction *> OperationInstruction; 1515 for (auto &P : ReductionInfo) 1516 OperationInstruction.push_back(P.first); 1517 1518 // Identify a complex computation by evaluating two reduction operations that 1519 // potentially could be involved 1520 for (size_t i = 0; i < OperationInstruction.size(); ++i) { 1521 if (Processed[i]) 1522 continue; 1523 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 1524 if (Processed[j]) 1525 continue; 1526 1527 auto *Real = OperationInstruction[i]; 1528 auto *Imag = OperationInstruction[j]; 1529 if (Real->getType() != Imag->getType()) 1530 continue; 1531 1532 RealPHI = ReductionInfo[Real].first; 1533 ImagPHI = ReductionInfo[Imag].first; 1534 PHIsFound = false; 1535 auto Node = identifyNode(Real, Imag); 1536 if (!Node) { 1537 std::swap(Real, Imag); 1538 std::swap(RealPHI, ImagPHI); 1539 Node = identifyNode(Real, Imag); 1540 } 1541 1542 // If a node is identified and reduction PHINode is used in the chain of 1543 // operations, mark its operation instructions as used to prevent 1544 // re-identification and attach the node to the real part 1545 if (Node && PHIsFound) { 1546 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 1547 << *Real << " / " << *Imag << "\n"); 1548 Processed[i] = true; 1549 Processed[j] = true; 1550 auto RootNode = prepareCompositeNode( 1551 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); 1552 
RootNode->addOperand(Node); 1553 RootToNode[Real] = RootNode; 1554 RootToNode[Imag] = RootNode; 1555 submitCompositeNode(RootNode); 1556 break; 1557 } 1558 } 1559 } 1560 1561 RealPHI = nullptr; 1562 ImagPHI = nullptr; 1563 } 1564 1565 bool ComplexDeinterleavingGraph::checkNodes() { 1566 // Collect all instructions from roots to leaves 1567 SmallPtrSet<Instruction *, 16> AllInstructions; 1568 SmallVector<Instruction *, 8> Worklist; 1569 for (auto &Pair : RootToNode) 1570 Worklist.push_back(Pair.first); 1571 1572 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 1573 // chains 1574 while (!Worklist.empty()) { 1575 auto *I = Worklist.back(); 1576 Worklist.pop_back(); 1577 1578 if (!AllInstructions.insert(I).second) 1579 continue; 1580 1581 for (Value *Op : I->operands()) { 1582 if (auto *OpI = dyn_cast<Instruction>(Op)) { 1583 if (!FinalInstructions.count(I)) 1584 Worklist.emplace_back(OpI); 1585 } 1586 } 1587 } 1588 1589 // Find instructions that have users outside of chain 1590 SmallVector<Instruction *, 2> OuterInstructions; 1591 for (auto *I : AllInstructions) { 1592 // Skip root nodes 1593 if (RootToNode.count(I)) 1594 continue; 1595 1596 for (User *U : I->users()) { 1597 if (AllInstructions.count(cast<Instruction>(U))) 1598 continue; 1599 1600 // Found an instruction that is not used by XCMLA/XCADD chain 1601 Worklist.emplace_back(I); 1602 break; 1603 } 1604 } 1605 1606 // If any instructions are found to be used outside, find and remove roots 1607 // that somehow connect to those instructions. 1608 SmallPtrSet<Instruction *, 16> Visited; 1609 while (!Worklist.empty()) { 1610 auto *I = Worklist.back(); 1611 Worklist.pop_back(); 1612 if (!Visited.insert(I).second) 1613 continue; 1614 1615 // Found an impacted root node. 
Removing it from the nodes to be 1616 // deinterleaved 1617 if (RootToNode.count(I)) { 1618 LLVM_DEBUG(dbgs() << "Instruction " << *I 1619 << " could be deinterleaved but its chain of complex " 1620 "operations have an outside user\n"); 1621 RootToNode.erase(I); 1622 } 1623 1624 if (!AllInstructions.count(I) || FinalInstructions.count(I)) 1625 continue; 1626 1627 for (User *U : I->users()) 1628 Worklist.emplace_back(cast<Instruction>(U)); 1629 1630 for (Value *Op : I->operands()) { 1631 if (auto *OpI = dyn_cast<Instruction>(Op)) 1632 Worklist.emplace_back(OpI); 1633 } 1634 } 1635 return !RootToNode.empty(); 1636 } 1637 1638 ComplexDeinterleavingGraph::NodePtr 1639 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 1640 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 1641 if (Intrinsic->getIntrinsicID() != 1642 Intrinsic::experimental_vector_interleave2) 1643 return nullptr; 1644 1645 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 1646 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 1647 if (!Real || !Imag) 1648 return nullptr; 1649 1650 return identifyNode(Real, Imag); 1651 } 1652 1653 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 1654 if (!SVI) 1655 return nullptr; 1656 1657 // Look for a shufflevector that takes separate vectors of the real and 1658 // imaginary components and recombines them into a single vector. 
1659 if (!isInterleavingMask(SVI->getShuffleMask())) 1660 return nullptr; 1661 1662 Instruction *Real; 1663 Instruction *Imag; 1664 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 1665 return nullptr; 1666 1667 return identifyNode(Real, Imag); 1668 } 1669 1670 ComplexDeinterleavingGraph::NodePtr 1671 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 1672 Instruction *Imag) { 1673 Instruction *I = nullptr; 1674 Value *FinalValue = nullptr; 1675 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 1676 match(Imag, m_ExtractValue<1>(m_Specific(I))) && 1677 match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>( 1678 m_Value(FinalValue)))) { 1679 NodePtr PlaceholderNode = prepareCompositeNode( 1680 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 1681 PlaceholderNode->ReplacementNode = FinalValue; 1682 FinalInstructions.insert(Real); 1683 FinalInstructions.insert(Imag); 1684 return submitCompositeNode(PlaceholderNode); 1685 } 1686 1687 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 1688 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 1689 if (!RealShuffle || !ImagShuffle) { 1690 if (RealShuffle || ImagShuffle) 1691 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); 1692 return nullptr; 1693 } 1694 1695 Value *RealOp1 = RealShuffle->getOperand(1); 1696 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 1697 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 1698 return nullptr; 1699 } 1700 Value *ImagOp1 = ImagShuffle->getOperand(1); 1701 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 1702 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 1703 return nullptr; 1704 } 1705 1706 Value *RealOp0 = RealShuffle->getOperand(0); 1707 Value *ImagOp0 = ImagShuffle->getOperand(0); 1708 1709 if (RealOp0 != ImagOp0) { 1710 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 1711 return nullptr; 1712 
} 1713 1714 ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 1715 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 1716 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 1717 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 1718 return nullptr; 1719 } 1720 1721 if (RealMask[0] != 0 || ImagMask[0] != 1) { 1722 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 1723 return nullptr; 1724 } 1725 1726 // Type checking, the shuffle type should be a vector type of the same 1727 // scalar type, but half the size 1728 auto CheckType = [&](ShuffleVectorInst *Shuffle) { 1729 Value *Op = Shuffle->getOperand(0); 1730 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 1731 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1732 1733 if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 1734 return false; 1735 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 1736 return false; 1737 1738 return true; 1739 }; 1740 1741 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 1742 if (!CheckType(Shuffle)) 1743 return false; 1744 1745 ArrayRef<int> Mask = Shuffle->getShuffleMask(); 1746 int Last = *Mask.rbegin(); 1747 1748 Value *Op = Shuffle->getOperand(0); 1749 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1750 int NumElements = OpTy->getNumElements(); 1751 1752 // Ensure that the deinterleaving shuffle only pulls from the first 1753 // shuffle operand. 
1754 return Last < NumElements; 1755 }; 1756 1757 if (RealShuffle->getType() != ImagShuffle->getType()) { 1758 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 1759 return nullptr; 1760 } 1761 if (!CheckDeinterleavingShuffle(RealShuffle)) { 1762 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 1763 return nullptr; 1764 } 1765 if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1766 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1767 return nullptr; 1768 } 1769 1770 NodePtr PlaceholderNode = 1771 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1772 RealShuffle, ImagShuffle); 1773 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1774 FinalInstructions.insert(RealShuffle); 1775 FinalInstructions.insert(ImagShuffle); 1776 return submitCompositeNode(PlaceholderNode); 1777 } 1778 1779 ComplexDeinterleavingGraph::NodePtr 1780 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { 1781 auto IsSplat = [](Value *V) -> bool { 1782 // Fixed-width vector with constants 1783 if (isa<ConstantDataVector>(V)) 1784 return true; 1785 1786 VectorType *VTy; 1787 ArrayRef<int> Mask; 1788 // Splats are represented differently depending on whether the repeated 1789 // value is a constant or an Instruction 1790 if (auto *Const = dyn_cast<ConstantExpr>(V)) { 1791 if (Const->getOpcode() != Instruction::ShuffleVector) 1792 return false; 1793 VTy = cast<VectorType>(Const->getType()); 1794 Mask = Const->getShuffleMask(); 1795 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) { 1796 VTy = Shuf->getType(); 1797 Mask = Shuf->getShuffleMask(); 1798 } else { 1799 return false; 1800 } 1801 1802 // When the data type is <1 x Type>, it's not possible to differentiate 1803 // between the ComplexDeinterleaving::Deinterleave and 1804 // ComplexDeinterleaving::Splat operations. 
1805 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) 1806 return false; 1807 1808 return all_equal(Mask) && Mask[0] == 0; 1809 }; 1810 1811 if (!IsSplat(R) || !IsSplat(I)) 1812 return nullptr; 1813 1814 auto *Real = dyn_cast<Instruction>(R); 1815 auto *Imag = dyn_cast<Instruction>(I); 1816 if ((!Real && Imag) || (Real && !Imag)) 1817 return nullptr; 1818 1819 if (Real && Imag) { 1820 // Non-constant splats should be in the same basic block 1821 if (Real->getParent() != Imag->getParent()) 1822 return nullptr; 1823 1824 FinalInstructions.insert(Real); 1825 FinalInstructions.insert(Imag); 1826 } 1827 NodePtr PlaceholderNode = 1828 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); 1829 return submitCompositeNode(PlaceholderNode); 1830 } 1831 1832 ComplexDeinterleavingGraph::NodePtr 1833 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, 1834 Instruction *Imag) { 1835 if (Real != RealPHI || Imag != ImagPHI) 1836 return nullptr; 1837 1838 PHIsFound = true; 1839 NodePtr PlaceholderNode = prepareCompositeNode( 1840 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); 1841 return submitCompositeNode(PlaceholderNode); 1842 } 1843 1844 ComplexDeinterleavingGraph::NodePtr 1845 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, 1846 Instruction *Imag) { 1847 auto *SelectReal = dyn_cast<SelectInst>(Real); 1848 auto *SelectImag = dyn_cast<SelectInst>(Imag); 1849 if (!SelectReal || !SelectImag) 1850 return nullptr; 1851 1852 Instruction *MaskA, *MaskB; 1853 Instruction *AR, *AI, *RA, *BI; 1854 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), 1855 m_Instruction(RA))) || 1856 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), 1857 m_Instruction(BI)))) 1858 return nullptr; 1859 1860 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) 1861 return nullptr; 1862 1863 if (!MaskA->getType()->isVectorTy()) 1864 return nullptr; 1865 1866 auto NodeA = identifyNode(AR, AI); 1867 if 
(!NodeA) 1868 return nullptr; 1869 1870 auto NodeB = identifyNode(RA, BI); 1871 if (!NodeB) 1872 return nullptr; 1873 1874 NodePtr PlaceholderNode = prepareCompositeNode( 1875 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); 1876 PlaceholderNode->addOperand(NodeA); 1877 PlaceholderNode->addOperand(NodeB); 1878 FinalInstructions.insert(MaskA); 1879 FinalInstructions.insert(MaskB); 1880 return submitCompositeNode(PlaceholderNode); 1881 } 1882 1883 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, 1884 std::optional<FastMathFlags> Flags, 1885 Value *InputA, Value *InputB) { 1886 Value *I; 1887 switch (Opcode) { 1888 case Instruction::FNeg: 1889 I = B.CreateFNeg(InputA); 1890 break; 1891 case Instruction::FAdd: 1892 I = B.CreateFAdd(InputA, InputB); 1893 break; 1894 case Instruction::Add: 1895 I = B.CreateAdd(InputA, InputB); 1896 break; 1897 case Instruction::FSub: 1898 I = B.CreateFSub(InputA, InputB); 1899 break; 1900 case Instruction::Sub: 1901 I = B.CreateSub(InputA, InputB); 1902 break; 1903 case Instruction::FMul: 1904 I = B.CreateFMul(InputA, InputB); 1905 break; 1906 case Instruction::Mul: 1907 I = B.CreateMul(InputA, InputB); 1908 break; 1909 default: 1910 llvm_unreachable("Incorrect symmetric opcode"); 1911 } 1912 if (Flags) 1913 cast<Instruction>(I)->setFastMathFlags(*Flags); 1914 return I; 1915 } 1916 1917 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 1918 RawNodePtr Node) { 1919 if (Node->ReplacementNode) 1920 return Node->ReplacementNode; 1921 1922 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { 1923 return Node->Operands.size() > Idx 1924 ? 
replaceNode(Builder, Node->Operands[Idx]) 1925 : nullptr; 1926 }; 1927 1928 Value *ReplacementNode; 1929 switch (Node->Operation) { 1930 case ComplexDeinterleavingOperation::CAdd: 1931 case ComplexDeinterleavingOperation::CMulPartial: 1932 case ComplexDeinterleavingOperation::Symmetric: { 1933 Value *Input0 = ReplaceOperandIfExist(Node, 0); 1934 Value *Input1 = ReplaceOperandIfExist(Node, 1); 1935 Value *Accumulator = ReplaceOperandIfExist(Node, 2); 1936 assert(!Input1 || (Input0->getType() == Input1->getType() && 1937 "Node inputs need to be of the same type")); 1938 assert(!Accumulator || 1939 (Input0->getType() == Accumulator->getType() && 1940 "Accumulator and input need to be of the same type")); 1941 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 1942 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, 1943 Input0, Input1); 1944 else 1945 ReplacementNode = TL->createComplexDeinterleavingIR( 1946 Builder, Node->Operation, Node->Rotation, Input0, Input1, 1947 Accumulator); 1948 break; 1949 } 1950 case ComplexDeinterleavingOperation::Deinterleave: 1951 llvm_unreachable("Deinterleave node should already have ReplacementNode"); 1952 break; 1953 case ComplexDeinterleavingOperation::Splat: { 1954 auto *NewTy = VectorType::getDoubleElementsVectorType( 1955 cast<VectorType>(Node->Real->getType())); 1956 auto *R = dyn_cast<Instruction>(Node->Real); 1957 auto *I = dyn_cast<Instruction>(Node->Imag); 1958 if (R && I) { 1959 // Splats that are not constant are interleaved where they are located 1960 Instruction *InsertPoint = (I->comesBefore(R) ? 
R : I)->getNextNode(); 1961 IRBuilder<> IRB(InsertPoint); 1962 ReplacementNode = 1963 IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy, 1964 {Node->Real, Node->Imag}); 1965 } else { 1966 ReplacementNode = 1967 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, 1968 NewTy, {Node->Real, Node->Imag}); 1969 } 1970 break; 1971 } 1972 case ComplexDeinterleavingOperation::ReductionPHI: { 1973 // If Operation is ReductionPHI, a new empty PHINode is created. 1974 // It is filled later when the ReductionOperation is processed. 1975 auto *VTy = cast<VectorType>(Node->Real->getType()); 1976 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1977 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI()); 1978 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; 1979 ReplacementNode = NewPHI; 1980 break; 1981 } 1982 case ComplexDeinterleavingOperation::ReductionOperation: 1983 ReplacementNode = replaceNode(Builder, Node->Operands[0]); 1984 processReductionOperation(ReplacementNode, Node); 1985 break; 1986 case ComplexDeinterleavingOperation::ReductionSelect: { 1987 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0); 1988 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0); 1989 auto *A = replaceNode(Builder, Node->Operands[0]); 1990 auto *B = replaceNode(Builder, Node->Operands[1]); 1991 auto *NewMaskTy = VectorType::getDoubleElementsVectorType( 1992 cast<VectorType>(MaskReal->getType())); 1993 auto *NewMask = 1994 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, 1995 NewMaskTy, {MaskReal, MaskImag}); 1996 ReplacementNode = Builder.CreateSelect(NewMask, A, B); 1997 break; 1998 } 1999 } 2000 2001 assert(ReplacementNode && "Target failed to create Intrinsic call."); 2002 NumComplexTransformations += 1; 2003 Node->ReplacementNode = ReplacementNode; 2004 return ReplacementNode; 2005 } 2006 2007 void ComplexDeinterleavingGraph::processReductionOperation( 2008 Value 
*OperationReplacement, RawNodePtr Node) { 2009 auto *Real = cast<Instruction>(Node->Real); 2010 auto *Imag = cast<Instruction>(Node->Imag); 2011 auto *OldPHIReal = ReductionInfo[Real].first; 2012 auto *OldPHIImag = ReductionInfo[Imag].first; 2013 auto *NewPHI = OldToNewPHI[OldPHIReal]; 2014 2015 auto *VTy = cast<VectorType>(Real->getType()); 2016 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 2017 2018 // We have to interleave initial origin values coming from IncomingBlock 2019 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); 2020 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); 2021 2022 IRBuilder<> Builder(Incoming->getTerminator()); 2023 auto *NewInit = Builder.CreateIntrinsic( 2024 Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag}); 2025 2026 NewPHI->addIncoming(NewInit, Incoming); 2027 NewPHI->addIncoming(OperationReplacement, BackEdge); 2028 2029 // Deinterleave complex vector outside of loop so that it can be finally 2030 // reduced 2031 auto *FinalReductionReal = ReductionInfo[Real].second; 2032 auto *FinalReductionImag = ReductionInfo[Imag].second; 2033 2034 Builder.SetInsertPoint( 2035 &*FinalReductionReal->getParent()->getFirstInsertionPt()); 2036 auto *Deinterleave = Builder.CreateIntrinsic( 2037 Intrinsic::experimental_vector_deinterleave2, 2038 OperationReplacement->getType(), OperationReplacement); 2039 2040 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); 2041 FinalReductionReal->replaceUsesOfWith(Real, NewReal); 2042 2043 Builder.SetInsertPoint(FinalReductionImag); 2044 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); 2045 FinalReductionImag->replaceUsesOfWith(Imag, NewImag); 2046 } 2047 2048 void ComplexDeinterleavingGraph::replaceNodes() { 2049 SmallVector<Instruction *, 16> DeadInstrRoots; 2050 for (auto *RootInstruction : OrderedRoots) { 2051 // Check if this potential root went through check process and we can 2052 // deinterleave it 2053 
if (!RootToNode.count(RootInstruction)) 2054 continue; 2055 2056 IRBuilder<> Builder(RootInstruction); 2057 auto RootNode = RootToNode[RootInstruction]; 2058 Value *R = replaceNode(Builder, RootNode.get()); 2059 2060 if (RootNode->Operation == 2061 ComplexDeinterleavingOperation::ReductionOperation) { 2062 auto *RootReal = cast<Instruction>(RootNode->Real); 2063 auto *RootImag = cast<Instruction>(RootNode->Imag); 2064 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); 2065 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); 2066 DeadInstrRoots.push_back(cast<Instruction>(RootReal)); 2067 DeadInstrRoots.push_back(cast<Instruction>(RootImag)); 2068 } else { 2069 assert(R && "Unable to find replacement for RootInstruction"); 2070 DeadInstrRoots.push_back(RootInstruction); 2071 RootInstruction->replaceAllUsesWith(R); 2072 } 2073 } 2074 2075 for (auto *I : DeadInstrRoots) 2076 RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 2077 } 2078