//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Validity of each node is expected to be done upon creation, and any
// validation errors should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas, fixed-width vectors are recognized for both
// shufflevector instruction and intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementValue is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementValue fields of that Node are relevant, where the ReplacementValue
// should be pre-populated.
59 // 60 //===----------------------------------------------------------------------===// 61 62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h" 63 #include "llvm/ADT/Statistic.h" 64 #include "llvm/Analysis/TargetLibraryInfo.h" 65 #include "llvm/Analysis/TargetTransformInfo.h" 66 #include "llvm/CodeGen/TargetLowering.h" 67 #include "llvm/CodeGen/TargetPassConfig.h" 68 #include "llvm/CodeGen/TargetSubtargetInfo.h" 69 #include "llvm/IR/IRBuilder.h" 70 #include "llvm/IR/PatternMatch.h" 71 #include "llvm/InitializePasses.h" 72 #include "llvm/Target/TargetMachine.h" 73 #include "llvm/Transforms/Utils/Local.h" 74 #include <algorithm> 75 76 using namespace llvm; 77 using namespace PatternMatch; 78 79 #define DEBUG_TYPE "complex-deinterleaving" 80 81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); 82 83 static cl::opt<bool> ComplexDeinterleavingEnabled( 84 "enable-complex-deinterleaving", 85 cl::desc("Enable generation of complex instructions"), cl::init(true), 86 cl::Hidden); 87 88 /// Checks the given mask, and determines whether said mask is interleaving. 89 /// 90 /// To be interleaving, a mask must alternate between `i` and `i + (Length / 91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a 92 /// 4x vector interleaving mask would be <0, 2, 1, 3>). 93 static bool isInterleavingMask(ArrayRef<int> Mask); 94 95 /// Checks the given mask, and determines whether said mask is deinterleaving. 96 /// 97 /// To be deinterleaving, a mask must increment in steps of 2, and either start 98 /// with 0 or 1. 99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or 100 /// <1, 3, 5, 7>). 101 static bool isDeinterleavingMask(ArrayRef<int> Mask); 102 103 /// Returns true if the operation is a negation of V, and it works for both 104 /// integers and floats. 105 static bool isNeg(Value *V); 106 107 /// Returns the operand for negation operation. 
108 static Value *getNegOperand(Value *V); 109 110 namespace { 111 112 class ComplexDeinterleavingLegacyPass : public FunctionPass { 113 public: 114 static char ID; 115 116 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 117 : FunctionPass(ID), TM(TM) { 118 initializeComplexDeinterleavingLegacyPassPass( 119 *PassRegistry::getPassRegistry()); 120 } 121 122 StringRef getPassName() const override { 123 return "Complex Deinterleaving Pass"; 124 } 125 126 bool runOnFunction(Function &F) override; 127 void getAnalysisUsage(AnalysisUsage &AU) const override { 128 AU.addRequired<TargetLibraryInfoWrapperPass>(); 129 AU.setPreservesCFG(); 130 } 131 132 private: 133 const TargetMachine *TM; 134 }; 135 136 class ComplexDeinterleavingGraph; 137 struct ComplexDeinterleavingCompositeNode { 138 139 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 140 Value *R, Value *I) 141 : Operation(Op), Real(R), Imag(I) {} 142 143 private: 144 friend class ComplexDeinterleavingGraph; 145 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 146 using RawNodePtr = ComplexDeinterleavingCompositeNode *; 147 148 public: 149 ComplexDeinterleavingOperation Operation; 150 Value *Real; 151 Value *Imag; 152 153 // This two members are required exclusively for generating 154 // ComplexDeinterleavingOperation::Symmetric operations. 
155 unsigned Opcode; 156 std::optional<FastMathFlags> Flags; 157 158 ComplexDeinterleavingRotation Rotation = 159 ComplexDeinterleavingRotation::Rotation_0; 160 SmallVector<RawNodePtr> Operands; 161 Value *ReplacementNode = nullptr; 162 163 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } 164 165 void dump() { dump(dbgs()); } 166 void dump(raw_ostream &OS) { 167 auto PrintValue = [&](Value *V) { 168 if (V) { 169 OS << "\""; 170 V->print(OS, true); 171 OS << "\"\n"; 172 } else 173 OS << "nullptr\n"; 174 }; 175 auto PrintNodeRef = [&](RawNodePtr Ptr) { 176 if (Ptr) 177 OS << Ptr << "\n"; 178 else 179 OS << "nullptr\n"; 180 }; 181 182 OS << "- CompositeNode: " << this << "\n"; 183 OS << " Real: "; 184 PrintValue(Real); 185 OS << " Imag: "; 186 PrintValue(Imag); 187 OS << " ReplacementNode: "; 188 PrintValue(ReplacementNode); 189 OS << " Operation: " << (int)Operation << "\n"; 190 OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 191 OS << " Operands: \n"; 192 for (const auto &Op : Operands) { 193 OS << " - "; 194 PrintNodeRef(Op); 195 } 196 } 197 }; 198 199 class ComplexDeinterleavingGraph { 200 public: 201 struct Product { 202 Value *Multiplier; 203 Value *Multiplicand; 204 bool IsPositive; 205 }; 206 207 using Addend = std::pair<Value *, bool>; 208 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 209 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 210 211 // Helper struct for holding info about potential partial multiplication 212 // candidates 213 struct PartialMulCandidate { 214 Value *Common; 215 NodePtr Node; 216 unsigned RealIdx; 217 unsigned ImagIdx; 218 bool IsNodeInverted; 219 }; 220 221 explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 222 const TargetLibraryInfo *TLI) 223 : TL(TL), TLI(TLI) {} 224 225 private: 226 const TargetLowering *TL = nullptr; 227 const TargetLibraryInfo *TLI = nullptr; 228 SmallVector<NodePtr> CompositeNodes; 229 230 SmallPtrSet<Instruction *, 16> FinalInstructions; 231 
232 /// Root instructions are instructions from which complex computation starts 233 std::map<Instruction *, NodePtr> RootToNode; 234 235 /// Topologically sorted root instructions 236 SmallVector<Instruction *, 1> OrderedRoots; 237 238 /// When examining a basic block for complex deinterleaving, if it is a simple 239 /// one-block loop, then the only incoming block is 'Incoming' and the 240 /// 'BackEdge' block is the block itself." 241 BasicBlock *BackEdge = nullptr; 242 BasicBlock *Incoming = nullptr; 243 244 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction 245 /// %OutsideUser as it is shown in the IR: 246 /// 247 /// vector.body: 248 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], 249 /// [ %ReductionOp, %vector.body ] 250 /// ... 251 /// %ReductionOp = fadd i64 ... 252 /// ... 253 /// br i1 %condition, label %vector.body, %middle.block 254 /// 255 /// middle.block: 256 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) 257 /// 258 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding 259 /// `llvm.vector.reduce.fadd` when unroll factor isn't one. 260 std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; 261 262 /// In the process of detecting a reduction, we consider a pair of 263 /// %ReductionOP, which we refer to as real and imag (or vice versa), and 264 /// traverse the use-tree to detect complex operations. As this is a reduction 265 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds 266 /// to the %ReductionOPs that we suspect to be complex. 267 /// RealPHI and ImagPHI are used by the identifyPHINode method. 268 PHINode *RealPHI = nullptr; 269 PHINode *ImagPHI = nullptr; 270 271 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction 272 /// detection. 273 bool PHIsFound = false; 274 275 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. 
276 /// The new PHINode corresponds to a vector of deinterleaved complex numbers. 277 /// This mapping is populated during 278 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then 279 /// used in the ComplexDeinterleavingOperation::ReductionOperation node 280 /// replacement process. 281 std::map<PHINode *, PHINode *> OldToNewPHI; 282 283 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 284 Value *R, Value *I) { 285 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && 286 Operation != ComplexDeinterleavingOperation::ReductionOperation) || 287 (R && I)) && 288 "Reduction related nodes must have Real and Imaginary parts"); 289 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 290 I); 291 } 292 293 NodePtr submitCompositeNode(NodePtr Node) { 294 CompositeNodes.push_back(Node); 295 return Node; 296 } 297 298 NodePtr getContainingComposite(Value *R, Value *I) { 299 for (const auto &CN : CompositeNodes) { 300 if (CN->Real == R && CN->Imag == I) 301 return CN; 302 } 303 return nullptr; 304 } 305 306 /// Identifies a complex partial multiply pattern and its rotation, based on 307 /// the following patterns 308 /// 309 /// 0: r: cr + ar * br 310 /// i: ci + ar * bi 311 /// 90: r: cr - ai * bi 312 /// i: ci + ai * br 313 /// 180: r: cr - ar * br 314 /// i: ci - ar * bi 315 /// 270: r: cr + ai * bi 316 /// i: ci - ai * br 317 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 318 319 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that 320 /// is partially known from identifyPartialMul, filling in the other half of 321 /// the complex pair. 322 NodePtr 323 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, 324 std::pair<Value *, Value *> &CommonOperandI); 325 326 /// Identifies a complex add pattern and its rotation, based on the following 327 /// patterns. 
328 /// 329 /// 90: r: ar - bi 330 /// i: ai + br 331 /// 270: r: ar + bi 332 /// i: ai - br 333 NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 334 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); 335 336 NodePtr identifyNode(Value *R, Value *I); 337 338 /// Determine if a sum of complex numbers can be formed from \p RealAddends 339 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. 340 /// Return nullptr if it is not possible to construct a complex number. 341 /// \p Flags are needed to generate symmetric Add and Sub operations. 342 NodePtr identifyAdditions(std::list<Addend> &RealAddends, 343 std::list<Addend> &ImagAddends, 344 std::optional<FastMathFlags> Flags, 345 NodePtr Accumulator); 346 347 /// Extract one addend that have both real and imaginary parts positive. 348 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, 349 std::list<Addend> &ImagAddends); 350 351 /// Determine if sum of multiplications of complex numbers can be formed from 352 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result 353 /// to it. Return nullptr if it is not possible to construct a complex number. 354 NodePtr identifyMultiplications(std::vector<Product> &RealMuls, 355 std::vector<Product> &ImagMuls, 356 NodePtr Accumulator); 357 358 /// Go through pairs of multiplication (one Real and one Imag) and find all 359 /// possible candidates for partial multiplication and put them into \p 360 /// Candidates. Returns true if all Product has pair with common operand 361 bool collectPartialMuls(const std::vector<Product> &RealMuls, 362 const std::vector<Product> &ImagMuls, 363 std::vector<PartialMulCandidate> &Candidates); 364 365 /// If the code is compiled with -Ofast or expressions have `reassoc` flag, 366 /// the order of complex computation operations may be significantly altered, 367 /// and the real and imaginary parts may not be executed in parallel. 
This 368 /// function takes this into consideration and employs a more general approach 369 /// to identify complex computations. Initially, it gathers all the addends 370 /// and multiplicands and then constructs a complex expression from them. 371 NodePtr identifyReassocNodes(Instruction *I, Instruction *J); 372 373 NodePtr identifyRoot(Instruction *I); 374 375 /// Identifies the Deinterleave operation applied to a vector containing 376 /// complex numbers. There are two ways to represent the Deinterleave 377 /// operation: 378 /// * Using two shufflevectors with even indices for /pReal instruction and 379 /// odd indices for /pImag instructions (only for fixed-width vectors) 380 /// * Using two extractvalue instructions applied to `vector.deinterleave2` 381 /// intrinsic (for both fixed and scalable vectors) 382 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); 383 384 /// identifying the operation that represents a complex number repeated in a 385 /// Splat vector. There are two possible types of splats: ConstantExpr with 386 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an 387 /// initialization mask with all values set to zero. 
388 NodePtr identifySplat(Value *Real, Value *Imag); 389 390 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); 391 392 /// Identifies SelectInsts in a loop that has reduction with predication masks 393 /// and/or predicated tail folding 394 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); 395 396 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); 397 398 /// Complete IR modifications after producing new reduction operation: 399 /// * Populate the PHINode generated for 400 /// ComplexDeinterleavingOperation::ReductionPHI 401 /// * Deinterleave the final value outside of the loop and repurpose original 402 /// reduction users 403 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); 404 405 public: 406 void dump() { dump(dbgs()); } 407 void dump(raw_ostream &OS) { 408 for (const auto &Node : CompositeNodes) 409 Node->dump(OS); 410 } 411 412 /// Returns false if the deinterleaving operation should be cancelled for the 413 /// current graph. 414 bool identifyNodes(Instruction *RootI); 415 416 /// In case \pB is one-block loop, this function seeks potential reductions 417 /// and populates ReductionInfo. Returns true if any reductions were 418 /// identified. 419 bool collectPotentialReductions(BasicBlock *B); 420 421 void identifyReductionNodes(); 422 423 /// Check that every instruction, from the roots to the leaves, has internal 424 /// uses. 425 bool checkNodes(); 426 427 /// Perform the actual replacement of the underlying instruction graph. 
428 void replaceNodes(); 429 }; 430 431 class ComplexDeinterleaving { 432 public: 433 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 434 : TL(tl), TLI(tli) {} 435 bool runOnFunction(Function &F); 436 437 private: 438 bool evaluateBasicBlock(BasicBlock *B); 439 440 const TargetLowering *TL = nullptr; 441 const TargetLibraryInfo *TLI = nullptr; 442 }; 443 444 } // namespace 445 446 char ComplexDeinterleavingLegacyPass::ID = 0; 447 448 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 449 "Complex Deinterleaving", false, false) 450 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 451 "Complex Deinterleaving", false, false) 452 453 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 454 FunctionAnalysisManager &AM) { 455 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 456 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 457 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 458 return PreservedAnalyses::all(); 459 460 PreservedAnalyses PA; 461 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 462 return PA; 463 } 464 465 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 466 return new ComplexDeinterleavingLegacyPass(TM); 467 } 468 469 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 470 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 471 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 472 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 473 } 474 475 bool ComplexDeinterleaving::runOnFunction(Function &F) { 476 if (!ComplexDeinterleavingEnabled) { 477 LLVM_DEBUG( 478 dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 479 return false; 480 } 481 482 if (!TL->isComplexDeinterleavingSupported()) { 483 LLVM_DEBUG( 484 dbgs() << "Complex deinterleaving has been disabled, target does " 485 "not support lowering of complex number operations.\n"); 486 return 
false; 487 } 488 489 bool Changed = false; 490 for (auto &B : F) 491 Changed |= evaluateBasicBlock(&B); 492 493 return Changed; 494 } 495 496 static bool isInterleavingMask(ArrayRef<int> Mask) { 497 // If the size is not even, it's not an interleaving mask 498 if ((Mask.size() & 1)) 499 return false; 500 501 int HalfNumElements = Mask.size() / 2; 502 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 503 int MaskIdx = Idx * 2; 504 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 505 return false; 506 } 507 508 return true; 509 } 510 511 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 512 int Offset = Mask[0]; 513 int HalfNumElements = Mask.size() / 2; 514 515 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 516 if (Mask[Idx] != (Idx * 2) + Offset) 517 return false; 518 } 519 520 return true; 521 } 522 523 bool isNeg(Value *V) { 524 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); 525 } 526 527 Value *getNegOperand(Value *V) { 528 assert(isNeg(V)); 529 auto *I = cast<Instruction>(V); 530 if (I->getOpcode() == Instruction::FNeg) 531 return I->getOperand(0); 532 533 return I->getOperand(1); 534 } 535 536 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 537 ComplexDeinterleavingGraph Graph(TL, TLI); 538 if (Graph.collectPotentialReductions(B)) 539 Graph.identifyReductionNodes(); 540 541 for (auto &I : *B) 542 Graph.identifyNodes(&I); 543 544 if (Graph.checkNodes()) { 545 Graph.replaceNodes(); 546 return true; 547 } 548 549 return false; 550 } 551 552 ComplexDeinterleavingGraph::NodePtr 553 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 554 Instruction *Real, Instruction *Imag, 555 std::pair<Value *, Value *> &PartialMatch) { 556 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 557 << "\n"); 558 559 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 560 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 561 return nullptr; 562 } 563 564 if ((Real->getOpcode() 
!= Instruction::FMul && 565 Real->getOpcode() != Instruction::Mul) || 566 (Imag->getOpcode() != Instruction::FMul && 567 Imag->getOpcode() != Instruction::Mul)) { 568 LLVM_DEBUG( 569 dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); 570 return nullptr; 571 } 572 573 Value *R0 = Real->getOperand(0); 574 Value *R1 = Real->getOperand(1); 575 Value *I0 = Imag->getOperand(0); 576 Value *I1 = Imag->getOperand(1); 577 578 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 579 // rotations and use the operand. 580 unsigned Negs = 0; 581 Value *Op; 582 if (match(R0, m_Neg(m_Value(Op)))) { 583 Negs |= 1; 584 R0 = Op; 585 } else if (match(R1, m_Neg(m_Value(Op)))) { 586 Negs |= 1; 587 R1 = Op; 588 } 589 590 if (isNeg(I0)) { 591 Negs |= 2; 592 Negs ^= 1; 593 I0 = Op; 594 } else if (match(I1, m_Neg(m_Value(Op)))) { 595 Negs |= 2; 596 Negs ^= 1; 597 I1 = Op; 598 } 599 600 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 601 602 Value *CommonOperand; 603 Value *UncommonRealOp; 604 Value *UncommonImagOp; 605 606 if (R0 == I0 || R0 == I1) { 607 CommonOperand = R0; 608 UncommonRealOp = R1; 609 } else if (R1 == I0 || R1 == I1) { 610 CommonOperand = R1; 611 UncommonRealOp = R0; 612 } else { 613 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 614 return nullptr; 615 } 616 617 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 618 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 619 Rotation == ComplexDeinterleavingRotation::Rotation_270) 620 std::swap(UncommonRealOp, UncommonImagOp); 621 622 // Between identifyPartialMul and here we need to have found a complete valid 623 // pair from the CommonOperand of each part. 
624 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 625 Rotation == ComplexDeinterleavingRotation::Rotation_180) 626 PartialMatch.first = CommonOperand; 627 else 628 PartialMatch.second = CommonOperand; 629 630 if (!PartialMatch.first || !PartialMatch.second) { 631 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 632 return nullptr; 633 } 634 635 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 636 if (!CommonNode) { 637 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 638 return nullptr; 639 } 640 641 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 642 if (!UncommonNode) { 643 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 644 return nullptr; 645 } 646 647 NodePtr Node = prepareCompositeNode( 648 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 649 Node->Rotation = Rotation; 650 Node->addOperand(CommonNode); 651 Node->addOperand(UncommonNode); 652 return submitCompositeNode(Node); 653 } 654 655 ComplexDeinterleavingGraph::NodePtr 656 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 657 Instruction *Imag) { 658 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 659 << "\n"); 660 // Determine rotation 661 auto IsAdd = [](unsigned Op) { 662 return Op == Instruction::FAdd || Op == Instruction::Add; 663 }; 664 auto IsSub = [](unsigned Op) { 665 return Op == Instruction::FSub || Op == Instruction::Sub; 666 }; 667 ComplexDeinterleavingRotation Rotation; 668 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 669 Rotation = ComplexDeinterleavingRotation::Rotation_0; 670 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 671 Rotation = ComplexDeinterleavingRotation::Rotation_90; 672 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) 673 Rotation = ComplexDeinterleavingRotation::Rotation_180; 674 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) 675 Rotation = ComplexDeinterleavingRotation::Rotation_270; 
676 else { 677 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 678 return nullptr; 679 } 680 681 if (isa<FPMathOperator>(Real) && 682 (!Real->getFastMathFlags().allowContract() || 683 !Imag->getFastMathFlags().allowContract())) { 684 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 685 return nullptr; 686 } 687 688 Value *CR = Real->getOperand(0); 689 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 690 if (!RealMulI) 691 return nullptr; 692 Value *CI = Imag->getOperand(0); 693 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 694 if (!ImagMulI) 695 return nullptr; 696 697 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 698 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 699 return nullptr; 700 } 701 702 Value *R0 = RealMulI->getOperand(0); 703 Value *R1 = RealMulI->getOperand(1); 704 Value *I0 = ImagMulI->getOperand(0); 705 Value *I1 = ImagMulI->getOperand(1); 706 707 Value *CommonOperand; 708 Value *UncommonRealOp; 709 Value *UncommonImagOp; 710 711 if (R0 == I0 || R0 == I1) { 712 CommonOperand = R0; 713 UncommonRealOp = R1; 714 } else if (R1 == I0 || R1 == I1) { 715 CommonOperand = R1; 716 UncommonRealOp = R0; 717 } else { 718 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 719 return nullptr; 720 } 721 722 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 723 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 724 Rotation == ComplexDeinterleavingRotation::Rotation_270) 725 std::swap(UncommonRealOp, UncommonImagOp); 726 727 std::pair<Value *, Value *> PartialMatch( 728 (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 729 Rotation == ComplexDeinterleavingRotation::Rotation_180) 730 ? CommonOperand 731 : nullptr, 732 (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 733 Rotation == ComplexDeinterleavingRotation::Rotation_270) 734 ? 
CommonOperand 735 : nullptr); 736 737 auto *CRInst = dyn_cast<Instruction>(CR); 738 auto *CIInst = dyn_cast<Instruction>(CI); 739 740 if (!CRInst || !CIInst) { 741 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); 742 return nullptr; 743 } 744 745 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); 746 if (!CNode) { 747 LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 748 return nullptr; 749 } 750 751 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 752 if (!UncommonRes) { 753 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 754 return nullptr; 755 } 756 757 assert(PartialMatch.first && PartialMatch.second); 758 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 759 if (!CommonRes) { 760 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 761 return nullptr; 762 } 763 764 NodePtr Node = prepareCompositeNode( 765 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 766 Node->Rotation = Rotation; 767 Node->addOperand(CommonRes); 768 Node->addOperand(UncommonRes); 769 Node->addOperand(CNode); 770 return submitCompositeNode(Node); 771 } 772 773 ComplexDeinterleavingGraph::NodePtr 774 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 775 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 776 777 // Determine rotation 778 ComplexDeinterleavingRotation Rotation; 779 if ((Real->getOpcode() == Instruction::FSub && 780 Imag->getOpcode() == Instruction::FAdd) || 781 (Real->getOpcode() == Instruction::Sub && 782 Imag->getOpcode() == Instruction::Add)) 783 Rotation = ComplexDeinterleavingRotation::Rotation_90; 784 else if ((Real->getOpcode() == Instruction::FAdd && 785 Imag->getOpcode() == Instruction::FSub) || 786 (Real->getOpcode() == Instruction::Add && 787 Imag->getOpcode() == Instruction::Sub)) 788 Rotation = ComplexDeinterleavingRotation::Rotation_270; 789 else { 790 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not 
assigned.\n"); 791 return nullptr; 792 } 793 794 auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 795 auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 796 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 797 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 798 799 if (!AR || !AI || !BR || !BI) { 800 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 801 return nullptr; 802 } 803 804 NodePtr ResA = identifyNode(AR, AI); 805 if (!ResA) { 806 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 807 return nullptr; 808 } 809 NodePtr ResB = identifyNode(BR, BI); 810 if (!ResB) { 811 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 812 return nullptr; 813 } 814 815 NodePtr Node = 816 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 817 Node->Rotation = Rotation; 818 Node->addOperand(ResA); 819 Node->addOperand(ResB); 820 return submitCompositeNode(Node); 821 } 822 823 static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 824 unsigned OpcA = A->getOpcode(); 825 unsigned OpcB = B->getOpcode(); 826 827 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 828 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 829 (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 830 (OpcA == Instruction::Add && OpcB == Instruction::Sub); 831 } 832 833 static bool isInstructionPairMul(Instruction *A, Instruction *B) { 834 auto Pattern = 835 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 836 837 return match(A, Pattern) && match(B, Pattern); 838 } 839 840 static bool isInstructionPotentiallySymmetric(Instruction *I) { 841 switch (I->getOpcode()) { 842 case Instruction::FAdd: 843 case Instruction::FSub: 844 case Instruction::FMul: 845 case Instruction::FNeg: 846 case Instruction::Add: 847 case Instruction::Sub: 848 case Instruction::Mul: 849 return true; 850 default: 851 return false; 852 } 853 } 854 855 
ComplexDeinterleavingGraph::NodePtr 856 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, 857 Instruction *Imag) { 858 if (Real->getOpcode() != Imag->getOpcode()) 859 return nullptr; 860 861 if (!isInstructionPotentiallySymmetric(Real) || 862 !isInstructionPotentiallySymmetric(Imag)) 863 return nullptr; 864 865 auto *R0 = Real->getOperand(0); 866 auto *I0 = Imag->getOperand(0); 867 868 NodePtr Op0 = identifyNode(R0, I0); 869 NodePtr Op1 = nullptr; 870 if (Op0 == nullptr) 871 return nullptr; 872 873 if (Real->isBinaryOp()) { 874 auto *R1 = Real->getOperand(1); 875 auto *I1 = Imag->getOperand(1); 876 Op1 = identifyNode(R1, I1); 877 if (Op1 == nullptr) 878 return nullptr; 879 } 880 881 if (isa<FPMathOperator>(Real) && 882 Real->getFastMathFlags() != Imag->getFastMathFlags()) 883 return nullptr; 884 885 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, 886 Real, Imag); 887 Node->Opcode = Real->getOpcode(); 888 if (isa<FPMathOperator>(Real)) 889 Node->Flags = Real->getFastMathFlags(); 890 891 Node->addOperand(Op0); 892 if (Real->isBinaryOp()) 893 Node->addOperand(Op1); 894 895 return submitCompositeNode(Node); 896 } 897 898 ComplexDeinterleavingGraph::NodePtr 899 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { 900 LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); 901 assert(R->getType() == I->getType() && 902 "Real and imaginary parts should not have different types"); 903 if (NodePtr CN = getContainingComposite(R, I)) { 904 LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 905 return CN; 906 } 907 908 if (NodePtr CN = identifySplat(R, I)) 909 return CN; 910 911 auto *Real = dyn_cast<Instruction>(R); 912 auto *Imag = dyn_cast<Instruction>(I); 913 if (!Real || !Imag) 914 return nullptr; 915 916 if (NodePtr CN = identifyDeinterleave(Real, Imag)) 917 return CN; 918 919 if (NodePtr CN = identifyPHINode(Real, Imag)) 920 return CN; 921 922 if (NodePtr CN = identifySelectNode(Real, 
Imag)) 923 return CN; 924 925 auto *VTy = cast<VectorType>(Real->getType()); 926 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 927 928 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( 929 ComplexDeinterleavingOperation::CMulPartial, NewVTy); 930 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( 931 ComplexDeinterleavingOperation::CAdd, NewVTy); 932 933 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { 934 if (NodePtr CN = identifyPartialMul(Real, Imag)) 935 return CN; 936 } 937 938 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { 939 if (NodePtr CN = identifyAdd(Real, Imag)) 940 return CN; 941 } 942 943 if (HasCMulSupport && HasCAddSupport) { 944 if (NodePtr CN = identifyReassocNodes(Real, Imag)) 945 return CN; 946 } 947 948 if (NodePtr CN = identifySymmetricOperation(Real, Imag)) 949 return CN; 950 951 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); 952 return nullptr; 953 } 954 955 ComplexDeinterleavingGraph::NodePtr 956 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, 957 Instruction *Imag) { 958 auto IsOperationSupported = [](unsigned Opcode) -> bool { 959 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || 960 Opcode == Instruction::FNeg || Opcode == Instruction::Add || 961 Opcode == Instruction::Sub; 962 }; 963 964 if (!IsOperationSupported(Real->getOpcode()) || 965 !IsOperationSupported(Imag->getOpcode())) 966 return nullptr; 967 968 std::optional<FastMathFlags> Flags; 969 if (isa<FPMathOperator>(Real)) { 970 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { 971 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " 972 "not identical\n"); 973 return nullptr; 974 } 975 976 Flags = Real->getFastMathFlags(); 977 if (!Flags->allowReassoc()) { 978 LLVM_DEBUG( 979 dbgs() 980 << "the 'Reassoc' attribute is missing in the FastMath flags\n"); 981 return nullptr; 982 } 983 } 984 985 // Collect multiplications and 
addend instructions from the given instruction 986 // while traversing it operands. Additionally, verify that all instructions 987 // have the same fast math flags. 988 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 989 std::list<Addend> &Addends) -> bool { 990 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 991 SmallPtrSet<Value *, 8> Visited; 992 while (!Worklist.empty()) { 993 auto [V, IsPositive] = Worklist.back(); 994 Worklist.pop_back(); 995 if (!Visited.insert(V).second) 996 continue; 997 998 Instruction *I = dyn_cast<Instruction>(V); 999 if (!I) { 1000 Addends.emplace_back(V, IsPositive); 1001 continue; 1002 } 1003 1004 // If an instruction has more than one user, it indicates that it either 1005 // has an external user, which will be later checked by the checkNodes 1006 // function, or it is a subexpression utilized by multiple expressions. In 1007 // the latter case, we will attempt to separately identify the complex 1008 // operation from here in order to create a shared 1009 // ComplexDeinterleavingCompositeNode. 
1010 if (I != Insn && I->getNumUses() > 1) { 1011 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 1012 Addends.emplace_back(I, IsPositive); 1013 continue; 1014 } 1015 switch (I->getOpcode()) { 1016 case Instruction::FAdd: 1017 case Instruction::Add: 1018 Worklist.emplace_back(I->getOperand(1), IsPositive); 1019 Worklist.emplace_back(I->getOperand(0), IsPositive); 1020 break; 1021 case Instruction::FSub: 1022 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1023 Worklist.emplace_back(I->getOperand(0), IsPositive); 1024 break; 1025 case Instruction::Sub: 1026 if (isNeg(I)) { 1027 Worklist.emplace_back(getNegOperand(I), !IsPositive); 1028 } else { 1029 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1030 Worklist.emplace_back(I->getOperand(0), IsPositive); 1031 } 1032 break; 1033 case Instruction::FMul: 1034 case Instruction::Mul: { 1035 Value *A, *B; 1036 if (isNeg(I->getOperand(0))) { 1037 A = getNegOperand(I->getOperand(0)); 1038 IsPositive = !IsPositive; 1039 } else { 1040 A = I->getOperand(0); 1041 } 1042 1043 if (isNeg(I->getOperand(1))) { 1044 B = getNegOperand(I->getOperand(1)); 1045 IsPositive = !IsPositive; 1046 } else { 1047 B = I->getOperand(1); 1048 } 1049 Muls.push_back(Product{A, B, IsPositive}); 1050 break; 1051 } 1052 case Instruction::FNeg: 1053 Worklist.emplace_back(I->getOperand(0), !IsPositive); 1054 break; 1055 default: 1056 Addends.emplace_back(I, IsPositive); 1057 continue; 1058 } 1059 1060 if (Flags && I->getFastMathFlags() != *Flags) { 1061 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 1062 "inconsistent with the root instructions' flags: " 1063 << *I << "\n"); 1064 return false; 1065 } 1066 } 1067 return true; 1068 }; 1069 1070 std::vector<Product> RealMuls, ImagMuls; 1071 std::list<Addend> RealAddends, ImagAddends; 1072 if (!Collect(Real, RealMuls, RealAddends) || 1073 !Collect(Imag, ImagMuls, ImagAddends)) 1074 return nullptr; 1075 1076 if (RealAddends.size() != ImagAddends.size()) 1077 
return nullptr; 1078 1079 NodePtr FinalNode; 1080 if (!RealMuls.empty() || !ImagMuls.empty()) { 1081 // If there are multiplicands, extract positive addend and use it as an 1082 // accumulator 1083 FinalNode = extractPositiveAddend(RealAddends, ImagAddends); 1084 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); 1085 if (!FinalNode) 1086 return nullptr; 1087 } 1088 1089 // Identify and process remaining additions 1090 if (!RealAddends.empty() || !ImagAddends.empty()) { 1091 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); 1092 if (!FinalNode) 1093 return nullptr; 1094 } 1095 assert(FinalNode && "FinalNode can not be nullptr here"); 1096 // Set the Real and Imag fields of the final node and submit it 1097 FinalNode->Real = Real; 1098 FinalNode->Imag = Imag; 1099 submitCompositeNode(FinalNode); 1100 return FinalNode; 1101 } 1102 1103 bool ComplexDeinterleavingGraph::collectPartialMuls( 1104 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, 1105 std::vector<PartialMulCandidate> &PartialMulCandidates) { 1106 // Helper function to extract a common operand from two products 1107 auto FindCommonInstruction = [](const Product &Real, 1108 const Product &Imag) -> Value * { 1109 if (Real.Multiplicand == Imag.Multiplicand || 1110 Real.Multiplicand == Imag.Multiplier) 1111 return Real.Multiplicand; 1112 1113 if (Real.Multiplier == Imag.Multiplicand || 1114 Real.Multiplier == Imag.Multiplier) 1115 return Real.Multiplier; 1116 1117 return nullptr; 1118 }; 1119 1120 // Iterating over real and imaginary multiplications to find common operands 1121 // If a common operand is found, a partial multiplication candidate is created 1122 // and added to the candidates vector The function returns false if no common 1123 // operands are found for any product 1124 for (unsigned i = 0; i < RealMuls.size(); ++i) { 1125 bool FoundCommon = false; 1126 for (unsigned j = 0; j < ImagMuls.size(); ++j) { 1127 auto *Common = 
FindCommonInstruction(RealMuls[i], ImagMuls[j]); 1128 if (!Common) 1129 continue; 1130 1131 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier 1132 : RealMuls[i].Multiplicand; 1133 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier 1134 : ImagMuls[j].Multiplicand; 1135 1136 auto Node = identifyNode(A, B); 1137 if (Node) { 1138 FoundCommon = true; 1139 PartialMulCandidates.push_back({Common, Node, i, j, false}); 1140 } 1141 1142 Node = identifyNode(B, A); 1143 if (Node) { 1144 FoundCommon = true; 1145 PartialMulCandidates.push_back({Common, Node, i, j, true}); 1146 } 1147 } 1148 if (!FoundCommon) 1149 return false; 1150 } 1151 return true; 1152 } 1153 1154 ComplexDeinterleavingGraph::NodePtr 1155 ComplexDeinterleavingGraph::identifyMultiplications( 1156 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, 1157 NodePtr Accumulator = nullptr) { 1158 if (RealMuls.size() != ImagMuls.size()) 1159 return nullptr; 1160 1161 std::vector<PartialMulCandidate> Info; 1162 if (!collectPartialMuls(RealMuls, ImagMuls, Info)) 1163 return nullptr; 1164 1165 // Map to store common instruction to node pointers 1166 std::map<Value *, NodePtr> CommonToNode; 1167 std::vector<bool> Processed(Info.size(), false); 1168 for (unsigned I = 0; I < Info.size(); ++I) { 1169 if (Processed[I]) 1170 continue; 1171 1172 PartialMulCandidate &InfoA = Info[I]; 1173 for (unsigned J = I + 1; J < Info.size(); ++J) { 1174 if (Processed[J]) 1175 continue; 1176 1177 PartialMulCandidate &InfoB = Info[J]; 1178 auto *InfoReal = &InfoA; 1179 auto *InfoImag = &InfoB; 1180 1181 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1182 if (!NodeFromCommon) { 1183 std::swap(InfoReal, InfoImag); 1184 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1185 } 1186 if (!NodeFromCommon) 1187 continue; 1188 1189 CommonToNode[InfoReal->Common] = NodeFromCommon; 1190 CommonToNode[InfoImag->Common] = NodeFromCommon; 1191 Processed[I] = true; 
1192 Processed[J] = true; 1193 } 1194 } 1195 1196 std::vector<bool> ProcessedReal(RealMuls.size(), false); 1197 std::vector<bool> ProcessedImag(ImagMuls.size(), false); 1198 NodePtr Result = Accumulator; 1199 for (auto &PMI : Info) { 1200 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 1201 continue; 1202 1203 auto It = CommonToNode.find(PMI.Common); 1204 // TODO: Process independent complex multiplications. Cases like this: 1205 // A.real() * B where both A and B are complex numbers. 1206 if (It == CommonToNode.end()) { 1207 LLVM_DEBUG({ 1208 dbgs() << "Unprocessed independent partial multiplication:\n"; 1209 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 1210 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 1211 << " multiplied by " << *Mul->Multiplicand << "\n"; 1212 }); 1213 return nullptr; 1214 } 1215 1216 auto &RealMul = RealMuls[PMI.RealIdx]; 1217 auto &ImagMul = ImagMuls[PMI.ImagIdx]; 1218 1219 auto NodeA = It->second; 1220 auto NodeB = PMI.Node; 1221 auto IsMultiplicandReal = PMI.Common == NodeA->Real; 1222 // The following table illustrates the relationship between multiplications 1223 // and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we 1224 // can see: 1225 // 1226 // Rotation | Real | Imag | 1227 // ---------+--------+--------+ 1228 // 0 | x * u | x * v | 1229 // 90 | -y * v | y * u | 1230 // 180 | -x * u | -x * v | 1231 // 270 | y * v | -y * u | 1232 // 1233 // Check if the candidate can indeed be represented by partial 1234 // multiplication 1235 // TODO: Add support for multiplication by complex one 1236 if ((IsMultiplicandReal && PMI.IsNodeInverted) || 1237 (!IsMultiplicandReal && !PMI.IsNodeInverted)) 1238 continue; 1239 1240 // Determine the rotation based on the multiplications 1241 ComplexDeinterleavingRotation Rotation; 1242 if (IsMultiplicandReal) { 1243 // Detect 0 and 180 degrees rotation 1244 if (RealMul.IsPositive && ImagMul.IsPositive) 1245 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 1246 else if (!RealMul.IsPositive && !ImagMul.IsPositive) 1247 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 1248 else 1249 continue; 1250 1251 } else { 1252 // Detect 90 and 270 degrees rotation 1253 if (!RealMul.IsPositive && ImagMul.IsPositive) 1254 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 1255 else if (RealMul.IsPositive && !ImagMul.IsPositive) 1256 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; 1257 else 1258 continue; 1259 } 1260 1261 LLVM_DEBUG({ 1262 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 1263 dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 1264 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 1265 dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 1266 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 1267 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1268 }); 1269 1270 NodePtr NodeMul = prepareCompositeNode( 1271 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 1272 NodeMul->Rotation = Rotation; 1273 NodeMul->addOperand(NodeA); 1274 NodeMul->addOperand(NodeB); 1275 if (Result) 1276 
NodeMul->addOperand(Result); 1277 submitCompositeNode(NodeMul); 1278 Result = NodeMul; 1279 ProcessedReal[PMI.RealIdx] = true; 1280 ProcessedImag[PMI.ImagIdx] = true; 1281 } 1282 1283 // Ensure all products have been processed, if not return nullptr. 1284 if (!all_of(ProcessedReal, [](bool V) { return V; }) || 1285 !all_of(ProcessedImag, [](bool V) { return V; })) { 1286 1287 // Dump debug information about which partial multiplications are not 1288 // processed. 1289 LLVM_DEBUG({ 1290 dbgs() << "Unprocessed products (Real):\n"; 1291 for (size_t i = 0; i < ProcessedReal.size(); ++i) { 1292 if (!ProcessedReal[i]) 1293 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 1294 << *RealMuls[i].Multiplier << " multiplied by " 1295 << *RealMuls[i].Multiplicand << "\n"; 1296 } 1297 dbgs() << "Unprocessed products (Imag):\n"; 1298 for (size_t i = 0; i < ProcessedImag.size(); ++i) { 1299 if (!ProcessedImag[i]) 1300 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") 1301 << *ImagMuls[i].Multiplier << " multiplied by " 1302 << *ImagMuls[i].Multiplicand << "\n"; 1303 } 1304 }); 1305 return nullptr; 1306 } 1307 1308 return Result; 1309 } 1310 1311 ComplexDeinterleavingGraph::NodePtr 1312 ComplexDeinterleavingGraph::identifyAdditions( 1313 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, 1314 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { 1315 if (RealAddends.size() != ImagAddends.size()) 1316 return nullptr; 1317 1318 NodePtr Result; 1319 // If we have accumulator use it as first addend 1320 if (Accumulator) 1321 Result = Accumulator; 1322 // Otherwise find an element with both positive real and imaginary parts. 
1323 else 1324 Result = extractPositiveAddend(RealAddends, ImagAddends); 1325 1326 if (!Result) 1327 return nullptr; 1328 1329 while (!RealAddends.empty()) { 1330 auto ItR = RealAddends.begin(); 1331 auto [R, IsPositiveR] = *ItR; 1332 1333 bool FoundImag = false; 1334 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1335 auto [I, IsPositiveI] = *ItI; 1336 ComplexDeinterleavingRotation Rotation; 1337 if (IsPositiveR && IsPositiveI) 1338 Rotation = ComplexDeinterleavingRotation::Rotation_0; 1339 else if (!IsPositiveR && IsPositiveI) 1340 Rotation = ComplexDeinterleavingRotation::Rotation_90; 1341 else if (!IsPositiveR && !IsPositiveI) 1342 Rotation = ComplexDeinterleavingRotation::Rotation_180; 1343 else 1344 Rotation = ComplexDeinterleavingRotation::Rotation_270; 1345 1346 NodePtr AddNode; 1347 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 1348 Rotation == ComplexDeinterleavingRotation::Rotation_180) { 1349 AddNode = identifyNode(R, I); 1350 } else { 1351 AddNode = identifyNode(I, R); 1352 } 1353 if (AddNode) { 1354 LLVM_DEBUG({ 1355 dbgs() << "Identified addition:\n"; 1356 dbgs().indent(4) << "X: " << *R << "\n"; 1357 dbgs().indent(4) << "Y: " << *I << "\n"; 1358 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1359 }); 1360 1361 NodePtr TmpNode; 1362 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { 1363 TmpNode = prepareCompositeNode( 1364 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1365 if (Flags) { 1366 TmpNode->Opcode = Instruction::FAdd; 1367 TmpNode->Flags = *Flags; 1368 } else { 1369 TmpNode->Opcode = Instruction::Add; 1370 } 1371 } else if (Rotation == 1372 llvm::ComplexDeinterleavingRotation::Rotation_180) { 1373 TmpNode = prepareCompositeNode( 1374 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1375 if (Flags) { 1376 TmpNode->Opcode = Instruction::FSub; 1377 TmpNode->Flags = *Flags; 1378 } else { 1379 TmpNode->Opcode = Instruction::Sub; 1380 } 1381 } 
else { 1382 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, 1383 nullptr, nullptr); 1384 TmpNode->Rotation = Rotation; 1385 } 1386 1387 TmpNode->addOperand(Result); 1388 TmpNode->addOperand(AddNode); 1389 submitCompositeNode(TmpNode); 1390 Result = TmpNode; 1391 RealAddends.erase(ItR); 1392 ImagAddends.erase(ItI); 1393 FoundImag = true; 1394 break; 1395 } 1396 } 1397 if (!FoundImag) 1398 return nullptr; 1399 } 1400 return Result; 1401 } 1402 1403 ComplexDeinterleavingGraph::NodePtr 1404 ComplexDeinterleavingGraph::extractPositiveAddend( 1405 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 1406 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 1407 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1408 auto [R, IsPositiveR] = *ItR; 1409 auto [I, IsPositiveI] = *ItI; 1410 if (IsPositiveR && IsPositiveI) { 1411 auto Result = identifyNode(R, I); 1412 if (Result) { 1413 RealAddends.erase(ItR); 1414 ImagAddends.erase(ItI); 1415 return Result; 1416 } 1417 } 1418 } 1419 } 1420 return nullptr; 1421 } 1422 1423 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 1424 // This potential root instruction might already have been recognized as 1425 // reduction. Because RootToNode maps both Real and Imaginary parts to 1426 // CompositeNode we should choose only one either Real or Imag instruction to 1427 // use as an anchor for generating complex instruction. 
1428 auto It = RootToNode.find(RootI); 1429 if (It != RootToNode.end() && It->second->Real == RootI) { 1430 OrderedRoots.push_back(RootI); 1431 return true; 1432 } 1433 1434 auto RootNode = identifyRoot(RootI); 1435 if (!RootNode) 1436 return false; 1437 1438 LLVM_DEBUG({ 1439 Function *F = RootI->getFunction(); 1440 BasicBlock *B = RootI->getParent(); 1441 dbgs() << "Complex deinterleaving graph for " << F->getName() 1442 << "::" << B->getName() << ".\n"; 1443 dump(dbgs()); 1444 dbgs() << "\n"; 1445 }); 1446 RootToNode[RootI] = RootNode; 1447 OrderedRoots.push_back(RootI); 1448 return true; 1449 } 1450 1451 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 1452 bool FoundPotentialReduction = false; 1453 1454 auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 1455 if (!Br || Br->getNumSuccessors() != 2) 1456 return false; 1457 1458 // Identify simple one-block loop 1459 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 1460 return false; 1461 1462 SmallVector<PHINode *> PHIs; 1463 for (auto &PHI : B->phis()) { 1464 if (PHI.getNumIncomingValues() != 2) 1465 continue; 1466 1467 if (!PHI.getType()->isVectorTy()) 1468 continue; 1469 1470 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 1471 if (!ReductionOp) 1472 continue; 1473 1474 // Check if final instruction is reduced outside of current block 1475 Instruction *FinalReduction = nullptr; 1476 auto NumUsers = 0u; 1477 for (auto *U : ReductionOp->users()) { 1478 ++NumUsers; 1479 if (U == &PHI) 1480 continue; 1481 FinalReduction = dyn_cast<Instruction>(U); 1482 } 1483 1484 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || 1485 isa<PHINode>(FinalReduction)) 1486 continue; 1487 1488 ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 1489 BackEdge = B; 1490 auto BackEdgeIdx = PHI.getBasicBlockIndex(B); 1491 auto IncomingIdx = BackEdgeIdx == 0 ? 
1 : 0; 1492 Incoming = PHI.getIncomingBlock(IncomingIdx); 1493 FoundPotentialReduction = true; 1494 1495 // If the initial value of PHINode is an Instruction, consider it a leaf 1496 // value of a complex deinterleaving graph. 1497 if (auto *InitPHI = 1498 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 1499 FinalInstructions.insert(InitPHI); 1500 } 1501 return FoundPotentialReduction; 1502 } 1503 1504 void ComplexDeinterleavingGraph::identifyReductionNodes() { 1505 SmallVector<bool> Processed(ReductionInfo.size(), false); 1506 SmallVector<Instruction *> OperationInstruction; 1507 for (auto &P : ReductionInfo) 1508 OperationInstruction.push_back(P.first); 1509 1510 // Identify a complex computation by evaluating two reduction operations that 1511 // potentially could be involved 1512 for (size_t i = 0; i < OperationInstruction.size(); ++i) { 1513 if (Processed[i]) 1514 continue; 1515 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 1516 if (Processed[j]) 1517 continue; 1518 1519 auto *Real = OperationInstruction[i]; 1520 auto *Imag = OperationInstruction[j]; 1521 if (Real->getType() != Imag->getType()) 1522 continue; 1523 1524 RealPHI = ReductionInfo[Real].first; 1525 ImagPHI = ReductionInfo[Imag].first; 1526 PHIsFound = false; 1527 auto Node = identifyNode(Real, Imag); 1528 if (!Node) { 1529 std::swap(Real, Imag); 1530 std::swap(RealPHI, ImagPHI); 1531 Node = identifyNode(Real, Imag); 1532 } 1533 1534 // If a node is identified and reduction PHINode is used in the chain of 1535 // operations, mark its operation instructions as used to prevent 1536 // re-identification and attach the node to the real part 1537 if (Node && PHIsFound) { 1538 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 1539 << *Real << " / " << *Imag << "\n"); 1540 Processed[i] = true; 1541 Processed[j] = true; 1542 auto RootNode = prepareCompositeNode( 1543 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); 1544 
RootNode->addOperand(Node); 1545 RootToNode[Real] = RootNode; 1546 RootToNode[Imag] = RootNode; 1547 submitCompositeNode(RootNode); 1548 break; 1549 } 1550 } 1551 } 1552 1553 RealPHI = nullptr; 1554 ImagPHI = nullptr; 1555 } 1556 1557 bool ComplexDeinterleavingGraph::checkNodes() { 1558 // Collect all instructions from roots to leaves 1559 SmallPtrSet<Instruction *, 16> AllInstructions; 1560 SmallVector<Instruction *, 8> Worklist; 1561 for (auto &Pair : RootToNode) 1562 Worklist.push_back(Pair.first); 1563 1564 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 1565 // chains 1566 while (!Worklist.empty()) { 1567 auto *I = Worklist.back(); 1568 Worklist.pop_back(); 1569 1570 if (!AllInstructions.insert(I).second) 1571 continue; 1572 1573 for (Value *Op : I->operands()) { 1574 if (auto *OpI = dyn_cast<Instruction>(Op)) { 1575 if (!FinalInstructions.count(I)) 1576 Worklist.emplace_back(OpI); 1577 } 1578 } 1579 } 1580 1581 // Find instructions that have users outside of chain 1582 SmallVector<Instruction *, 2> OuterInstructions; 1583 for (auto *I : AllInstructions) { 1584 // Skip root nodes 1585 if (RootToNode.count(I)) 1586 continue; 1587 1588 for (User *U : I->users()) { 1589 if (AllInstructions.count(cast<Instruction>(U))) 1590 continue; 1591 1592 // Found an instruction that is not used by XCMLA/XCADD chain 1593 Worklist.emplace_back(I); 1594 break; 1595 } 1596 } 1597 1598 // If any instructions are found to be used outside, find and remove roots 1599 // that somehow connect to those instructions. 1600 SmallPtrSet<Instruction *, 16> Visited; 1601 while (!Worklist.empty()) { 1602 auto *I = Worklist.back(); 1603 Worklist.pop_back(); 1604 if (!Visited.insert(I).second) 1605 continue; 1606 1607 // Found an impacted root node. 
Removing it from the nodes to be 1608 // deinterleaved 1609 if (RootToNode.count(I)) { 1610 LLVM_DEBUG(dbgs() << "Instruction " << *I 1611 << " could be deinterleaved but its chain of complex " 1612 "operations have an outside user\n"); 1613 RootToNode.erase(I); 1614 } 1615 1616 if (!AllInstructions.count(I) || FinalInstructions.count(I)) 1617 continue; 1618 1619 for (User *U : I->users()) 1620 Worklist.emplace_back(cast<Instruction>(U)); 1621 1622 for (Value *Op : I->operands()) { 1623 if (auto *OpI = dyn_cast<Instruction>(Op)) 1624 Worklist.emplace_back(OpI); 1625 } 1626 } 1627 return !RootToNode.empty(); 1628 } 1629 1630 ComplexDeinterleavingGraph::NodePtr 1631 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 1632 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 1633 if (Intrinsic->getIntrinsicID() != 1634 Intrinsic::experimental_vector_interleave2) 1635 return nullptr; 1636 1637 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 1638 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 1639 if (!Real || !Imag) 1640 return nullptr; 1641 1642 return identifyNode(Real, Imag); 1643 } 1644 1645 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 1646 if (!SVI) 1647 return nullptr; 1648 1649 // Look for a shufflevector that takes separate vectors of the real and 1650 // imaginary components and recombines them into a single vector. 
1651 if (!isInterleavingMask(SVI->getShuffleMask())) 1652 return nullptr; 1653 1654 Instruction *Real; 1655 Instruction *Imag; 1656 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 1657 return nullptr; 1658 1659 return identifyNode(Real, Imag); 1660 } 1661 1662 ComplexDeinterleavingGraph::NodePtr 1663 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 1664 Instruction *Imag) { 1665 Instruction *I = nullptr; 1666 Value *FinalValue = nullptr; 1667 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 1668 match(Imag, m_ExtractValue<1>(m_Specific(I))) && 1669 match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>( 1670 m_Value(FinalValue)))) { 1671 NodePtr PlaceholderNode = prepareCompositeNode( 1672 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 1673 PlaceholderNode->ReplacementNode = FinalValue; 1674 FinalInstructions.insert(Real); 1675 FinalInstructions.insert(Imag); 1676 return submitCompositeNode(PlaceholderNode); 1677 } 1678 1679 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 1680 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 1681 if (!RealShuffle || !ImagShuffle) { 1682 if (RealShuffle || ImagShuffle) 1683 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); 1684 return nullptr; 1685 } 1686 1687 Value *RealOp1 = RealShuffle->getOperand(1); 1688 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 1689 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 1690 return nullptr; 1691 } 1692 Value *ImagOp1 = ImagShuffle->getOperand(1); 1693 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 1694 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 1695 return nullptr; 1696 } 1697 1698 Value *RealOp0 = RealShuffle->getOperand(0); 1699 Value *ImagOp0 = ImagShuffle->getOperand(0); 1700 1701 if (RealOp0 != ImagOp0) { 1702 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 1703 return nullptr; 1704 
} 1705 1706 ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 1707 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 1708 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 1709 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 1710 return nullptr; 1711 } 1712 1713 if (RealMask[0] != 0 || ImagMask[0] != 1) { 1714 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 1715 return nullptr; 1716 } 1717 1718 // Type checking, the shuffle type should be a vector type of the same 1719 // scalar type, but half the size 1720 auto CheckType = [&](ShuffleVectorInst *Shuffle) { 1721 Value *Op = Shuffle->getOperand(0); 1722 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 1723 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1724 1725 if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 1726 return false; 1727 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 1728 return false; 1729 1730 return true; 1731 }; 1732 1733 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 1734 if (!CheckType(Shuffle)) 1735 return false; 1736 1737 ArrayRef<int> Mask = Shuffle->getShuffleMask(); 1738 int Last = *Mask.rbegin(); 1739 1740 Value *Op = Shuffle->getOperand(0); 1741 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1742 int NumElements = OpTy->getNumElements(); 1743 1744 // Ensure that the deinterleaving shuffle only pulls from the first 1745 // shuffle operand. 
1746 return Last < NumElements; 1747 }; 1748 1749 if (RealShuffle->getType() != ImagShuffle->getType()) { 1750 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 1751 return nullptr; 1752 } 1753 if (!CheckDeinterleavingShuffle(RealShuffle)) { 1754 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 1755 return nullptr; 1756 } 1757 if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1758 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1759 return nullptr; 1760 } 1761 1762 NodePtr PlaceholderNode = 1763 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1764 RealShuffle, ImagShuffle); 1765 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1766 FinalInstructions.insert(RealShuffle); 1767 FinalInstructions.insert(ImagShuffle); 1768 return submitCompositeNode(PlaceholderNode); 1769 } 1770 1771 ComplexDeinterleavingGraph::NodePtr 1772 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { 1773 auto IsSplat = [](Value *V) -> bool { 1774 // Fixed-width vector with constants 1775 if (isa<ConstantDataVector>(V)) 1776 return true; 1777 1778 VectorType *VTy; 1779 ArrayRef<int> Mask; 1780 // Splats are represented differently depending on whether the repeated 1781 // value is a constant or an Instruction 1782 if (auto *Const = dyn_cast<ConstantExpr>(V)) { 1783 if (Const->getOpcode() != Instruction::ShuffleVector) 1784 return false; 1785 VTy = cast<VectorType>(Const->getType()); 1786 Mask = Const->getShuffleMask(); 1787 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) { 1788 VTy = Shuf->getType(); 1789 Mask = Shuf->getShuffleMask(); 1790 } else { 1791 return false; 1792 } 1793 1794 // When the data type is <1 x Type>, it's not possible to differentiate 1795 // between the ComplexDeinterleaving::Deinterleave and 1796 // ComplexDeinterleaving::Splat operations. 
1797 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) 1798 return false; 1799 1800 return all_equal(Mask) && Mask[0] == 0; 1801 }; 1802 1803 if (!IsSplat(R) || !IsSplat(I)) 1804 return nullptr; 1805 1806 auto *Real = dyn_cast<Instruction>(R); 1807 auto *Imag = dyn_cast<Instruction>(I); 1808 if ((!Real && Imag) || (Real && !Imag)) 1809 return nullptr; 1810 1811 if (Real && Imag) { 1812 // Non-constant splats should be in the same basic block 1813 if (Real->getParent() != Imag->getParent()) 1814 return nullptr; 1815 1816 FinalInstructions.insert(Real); 1817 FinalInstructions.insert(Imag); 1818 } 1819 NodePtr PlaceholderNode = 1820 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); 1821 return submitCompositeNode(PlaceholderNode); 1822 } 1823 1824 ComplexDeinterleavingGraph::NodePtr 1825 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, 1826 Instruction *Imag) { 1827 if (Real != RealPHI || Imag != ImagPHI) 1828 return nullptr; 1829 1830 PHIsFound = true; 1831 NodePtr PlaceholderNode = prepareCompositeNode( 1832 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); 1833 return submitCompositeNode(PlaceholderNode); 1834 } 1835 1836 ComplexDeinterleavingGraph::NodePtr 1837 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, 1838 Instruction *Imag) { 1839 auto *SelectReal = dyn_cast<SelectInst>(Real); 1840 auto *SelectImag = dyn_cast<SelectInst>(Imag); 1841 if (!SelectReal || !SelectImag) 1842 return nullptr; 1843 1844 Instruction *MaskA, *MaskB; 1845 Instruction *AR, *AI, *RA, *BI; 1846 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), 1847 m_Instruction(RA))) || 1848 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), 1849 m_Instruction(BI)))) 1850 return nullptr; 1851 1852 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) 1853 return nullptr; 1854 1855 if (!MaskA->getType()->isVectorTy()) 1856 return nullptr; 1857 1858 auto NodeA = identifyNode(AR, AI); 1859 if 
(!NodeA) 1860 return nullptr; 1861 1862 auto NodeB = identifyNode(RA, BI); 1863 if (!NodeB) 1864 return nullptr; 1865 1866 NodePtr PlaceholderNode = prepareCompositeNode( 1867 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); 1868 PlaceholderNode->addOperand(NodeA); 1869 PlaceholderNode->addOperand(NodeB); 1870 FinalInstructions.insert(MaskA); 1871 FinalInstructions.insert(MaskB); 1872 return submitCompositeNode(PlaceholderNode); 1873 } 1874 1875 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, 1876 std::optional<FastMathFlags> Flags, 1877 Value *InputA, Value *InputB) { 1878 Value *I; 1879 switch (Opcode) { 1880 case Instruction::FNeg: 1881 I = B.CreateFNeg(InputA); 1882 break; 1883 case Instruction::FAdd: 1884 I = B.CreateFAdd(InputA, InputB); 1885 break; 1886 case Instruction::Add: 1887 I = B.CreateAdd(InputA, InputB); 1888 break; 1889 case Instruction::FSub: 1890 I = B.CreateFSub(InputA, InputB); 1891 break; 1892 case Instruction::Sub: 1893 I = B.CreateSub(InputA, InputB); 1894 break; 1895 case Instruction::FMul: 1896 I = B.CreateFMul(InputA, InputB); 1897 break; 1898 case Instruction::Mul: 1899 I = B.CreateMul(InputA, InputB); 1900 break; 1901 default: 1902 llvm_unreachable("Incorrect symmetric opcode"); 1903 } 1904 if (Flags) 1905 cast<Instruction>(I)->setFastMathFlags(*Flags); 1906 return I; 1907 } 1908 1909 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 1910 RawNodePtr Node) { 1911 if (Node->ReplacementNode) 1912 return Node->ReplacementNode; 1913 1914 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { 1915 return Node->Operands.size() > Idx 1916 ? 
replaceNode(Builder, Node->Operands[Idx]) 1917 : nullptr; 1918 }; 1919 1920 Value *ReplacementNode; 1921 switch (Node->Operation) { 1922 case ComplexDeinterleavingOperation::CAdd: 1923 case ComplexDeinterleavingOperation::CMulPartial: 1924 case ComplexDeinterleavingOperation::Symmetric: { 1925 Value *Input0 = ReplaceOperandIfExist(Node, 0); 1926 Value *Input1 = ReplaceOperandIfExist(Node, 1); 1927 Value *Accumulator = ReplaceOperandIfExist(Node, 2); 1928 assert(!Input1 || (Input0->getType() == Input1->getType() && 1929 "Node inputs need to be of the same type")); 1930 assert(!Accumulator || 1931 (Input0->getType() == Accumulator->getType() && 1932 "Accumulator and input need to be of the same type")); 1933 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 1934 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, 1935 Input0, Input1); 1936 else 1937 ReplacementNode = TL->createComplexDeinterleavingIR( 1938 Builder, Node->Operation, Node->Rotation, Input0, Input1, 1939 Accumulator); 1940 break; 1941 } 1942 case ComplexDeinterleavingOperation::Deinterleave: 1943 llvm_unreachable("Deinterleave node should already have ReplacementNode"); 1944 break; 1945 case ComplexDeinterleavingOperation::Splat: { 1946 auto *NewTy = VectorType::getDoubleElementsVectorType( 1947 cast<VectorType>(Node->Real->getType())); 1948 auto *R = dyn_cast<Instruction>(Node->Real); 1949 auto *I = dyn_cast<Instruction>(Node->Imag); 1950 if (R && I) { 1951 // Splats that are not constant are interleaved where they are located 1952 Instruction *InsertPoint = (I->comesBefore(R) ? 
R : I)->getNextNode(); 1953 IRBuilder<> IRB(InsertPoint); 1954 ReplacementNode = 1955 IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy, 1956 {Node->Real, Node->Imag}); 1957 } else { 1958 ReplacementNode = 1959 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, 1960 NewTy, {Node->Real, Node->Imag}); 1961 } 1962 break; 1963 } 1964 case ComplexDeinterleavingOperation::ReductionPHI: { 1965 // If Operation is ReductionPHI, a new empty PHINode is created. 1966 // It is filled later when the ReductionOperation is processed. 1967 auto *VTy = cast<VectorType>(Node->Real->getType()); 1968 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1969 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI()); 1970 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; 1971 ReplacementNode = NewPHI; 1972 break; 1973 } 1974 case ComplexDeinterleavingOperation::ReductionOperation: 1975 ReplacementNode = replaceNode(Builder, Node->Operands[0]); 1976 processReductionOperation(ReplacementNode, Node); 1977 break; 1978 case ComplexDeinterleavingOperation::ReductionSelect: { 1979 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0); 1980 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0); 1981 auto *A = replaceNode(Builder, Node->Operands[0]); 1982 auto *B = replaceNode(Builder, Node->Operands[1]); 1983 auto *NewMaskTy = VectorType::getDoubleElementsVectorType( 1984 cast<VectorType>(MaskReal->getType())); 1985 auto *NewMask = 1986 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, 1987 NewMaskTy, {MaskReal, MaskImag}); 1988 ReplacementNode = Builder.CreateSelect(NewMask, A, B); 1989 break; 1990 } 1991 } 1992 1993 assert(ReplacementNode && "Target failed to create Intrinsic call."); 1994 NumComplexTransformations += 1; 1995 Node->ReplacementNode = ReplacementNode; 1996 return ReplacementNode; 1997 } 1998 1999 void ComplexDeinterleavingGraph::processReductionOperation( 2000 Value 
*OperationReplacement, RawNodePtr Node) { 2001 auto *Real = cast<Instruction>(Node->Real); 2002 auto *Imag = cast<Instruction>(Node->Imag); 2003 auto *OldPHIReal = ReductionInfo[Real].first; 2004 auto *OldPHIImag = ReductionInfo[Imag].first; 2005 auto *NewPHI = OldToNewPHI[OldPHIReal]; 2006 2007 auto *VTy = cast<VectorType>(Real->getType()); 2008 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 2009 2010 // We have to interleave initial origin values coming from IncomingBlock 2011 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); 2012 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); 2013 2014 IRBuilder<> Builder(Incoming->getTerminator()); 2015 auto *NewInit = Builder.CreateIntrinsic( 2016 Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag}); 2017 2018 NewPHI->addIncoming(NewInit, Incoming); 2019 NewPHI->addIncoming(OperationReplacement, BackEdge); 2020 2021 // Deinterleave complex vector outside of loop so that it can be finally 2022 // reduced 2023 auto *FinalReductionReal = ReductionInfo[Real].second; 2024 auto *FinalReductionImag = ReductionInfo[Imag].second; 2025 2026 Builder.SetInsertPoint( 2027 &*FinalReductionReal->getParent()->getFirstInsertionPt()); 2028 auto *Deinterleave = Builder.CreateIntrinsic( 2029 Intrinsic::experimental_vector_deinterleave2, 2030 OperationReplacement->getType(), OperationReplacement); 2031 2032 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); 2033 FinalReductionReal->replaceUsesOfWith(Real, NewReal); 2034 2035 Builder.SetInsertPoint(FinalReductionImag); 2036 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); 2037 FinalReductionImag->replaceUsesOfWith(Imag, NewImag); 2038 } 2039 2040 void ComplexDeinterleavingGraph::replaceNodes() { 2041 SmallVector<Instruction *, 16> DeadInstrRoots; 2042 for (auto *RootInstruction : OrderedRoots) { 2043 // Check if this potential root went through check process and we can 2044 // deinterleave it 2045 
if (!RootToNode.count(RootInstruction)) 2046 continue; 2047 2048 IRBuilder<> Builder(RootInstruction); 2049 auto RootNode = RootToNode[RootInstruction]; 2050 Value *R = replaceNode(Builder, RootNode.get()); 2051 2052 if (RootNode->Operation == 2053 ComplexDeinterleavingOperation::ReductionOperation) { 2054 auto *RootReal = cast<Instruction>(RootNode->Real); 2055 auto *RootImag = cast<Instruction>(RootNode->Imag); 2056 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); 2057 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); 2058 DeadInstrRoots.push_back(cast<Instruction>(RootReal)); 2059 DeadInstrRoots.push_back(cast<Instruction>(RootImag)); 2060 } else { 2061 assert(R && "Unable to find replacement for RootInstruction"); 2062 DeadInstrRoots.push_back(RootInstruction); 2063 RootInstruction->replaceAllUsesWith(R); 2064 } 2065 } 2066 2067 for (auto *I : DeadInstrRoots) 2068 RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 2069 } 2070