//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized in both forms:
// as shufflevector instructions and as intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together, connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds references to the root Instructions and to the root nodes that should
// replace them.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
59 // 60 //===----------------------------------------------------------------------===// 61 62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h" 63 #include "llvm/ADT/MapVector.h" 64 #include "llvm/ADT/Statistic.h" 65 #include "llvm/Analysis/TargetLibraryInfo.h" 66 #include "llvm/Analysis/TargetTransformInfo.h" 67 #include "llvm/CodeGen/TargetLowering.h" 68 #include "llvm/CodeGen/TargetPassConfig.h" 69 #include "llvm/CodeGen/TargetSubtargetInfo.h" 70 #include "llvm/IR/IRBuilder.h" 71 #include "llvm/IR/PatternMatch.h" 72 #include "llvm/InitializePasses.h" 73 #include "llvm/Target/TargetMachine.h" 74 #include "llvm/Transforms/Utils/Local.h" 75 #include <algorithm> 76 77 using namespace llvm; 78 using namespace PatternMatch; 79 80 #define DEBUG_TYPE "complex-deinterleaving" 81 82 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); 83 84 static cl::opt<bool> ComplexDeinterleavingEnabled( 85 "enable-complex-deinterleaving", 86 cl::desc("Enable generation of complex instructions"), cl::init(true), 87 cl::Hidden); 88 89 /// Checks the given mask, and determines whether said mask is interleaving. 90 /// 91 /// To be interleaving, a mask must alternate between `i` and `i + (Length / 92 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a 93 /// 4x vector interleaving mask would be <0, 2, 1, 3>). 94 static bool isInterleavingMask(ArrayRef<int> Mask); 95 96 /// Checks the given mask, and determines whether said mask is deinterleaving. 97 /// 98 /// To be deinterleaving, a mask must increment in steps of 2, and either start 99 /// with 0 or 1. 100 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or 101 /// <1, 3, 5, 7>). 102 static bool isDeinterleavingMask(ArrayRef<int> Mask); 103 104 /// Returns true if the operation is a negation of V, and it works for both 105 /// integers and floats. 106 static bool isNeg(Value *V); 107 108 /// Returns the operand for negation operation. 
109 static Value *getNegOperand(Value *V); 110 111 namespace { 112 113 class ComplexDeinterleavingLegacyPass : public FunctionPass { 114 public: 115 static char ID; 116 117 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 118 : FunctionPass(ID), TM(TM) { 119 initializeComplexDeinterleavingLegacyPassPass( 120 *PassRegistry::getPassRegistry()); 121 } 122 123 StringRef getPassName() const override { 124 return "Complex Deinterleaving Pass"; 125 } 126 127 bool runOnFunction(Function &F) override; 128 void getAnalysisUsage(AnalysisUsage &AU) const override { 129 AU.addRequired<TargetLibraryInfoWrapperPass>(); 130 AU.setPreservesCFG(); 131 } 132 133 private: 134 const TargetMachine *TM; 135 }; 136 137 class ComplexDeinterleavingGraph; 138 struct ComplexDeinterleavingCompositeNode { 139 140 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 141 Value *R, Value *I) 142 : Operation(Op), Real(R), Imag(I) {} 143 144 private: 145 friend class ComplexDeinterleavingGraph; 146 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 147 using RawNodePtr = ComplexDeinterleavingCompositeNode *; 148 149 public: 150 ComplexDeinterleavingOperation Operation; 151 Value *Real; 152 Value *Imag; 153 154 // This two members are required exclusively for generating 155 // ComplexDeinterleavingOperation::Symmetric operations. 
156 unsigned Opcode; 157 std::optional<FastMathFlags> Flags; 158 159 ComplexDeinterleavingRotation Rotation = 160 ComplexDeinterleavingRotation::Rotation_0; 161 SmallVector<RawNodePtr> Operands; 162 Value *ReplacementNode = nullptr; 163 164 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } 165 166 void dump() { dump(dbgs()); } 167 void dump(raw_ostream &OS) { 168 auto PrintValue = [&](Value *V) { 169 if (V) { 170 OS << "\""; 171 V->print(OS, true); 172 OS << "\"\n"; 173 } else 174 OS << "nullptr\n"; 175 }; 176 auto PrintNodeRef = [&](RawNodePtr Ptr) { 177 if (Ptr) 178 OS << Ptr << "\n"; 179 else 180 OS << "nullptr\n"; 181 }; 182 183 OS << "- CompositeNode: " << this << "\n"; 184 OS << " Real: "; 185 PrintValue(Real); 186 OS << " Imag: "; 187 PrintValue(Imag); 188 OS << " ReplacementNode: "; 189 PrintValue(ReplacementNode); 190 OS << " Operation: " << (int)Operation << "\n"; 191 OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 192 OS << " Operands: \n"; 193 for (const auto &Op : Operands) { 194 OS << " - "; 195 PrintNodeRef(Op); 196 } 197 } 198 }; 199 200 class ComplexDeinterleavingGraph { 201 public: 202 struct Product { 203 Value *Multiplier; 204 Value *Multiplicand; 205 bool IsPositive; 206 }; 207 208 using Addend = std::pair<Value *, bool>; 209 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 210 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 211 212 // Helper struct for holding info about potential partial multiplication 213 // candidates 214 struct PartialMulCandidate { 215 Value *Common; 216 NodePtr Node; 217 unsigned RealIdx; 218 unsigned ImagIdx; 219 bool IsNodeInverted; 220 }; 221 222 explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 223 const TargetLibraryInfo *TLI) 224 : TL(TL), TLI(TLI) {} 225 226 private: 227 const TargetLowering *TL = nullptr; 228 const TargetLibraryInfo *TLI = nullptr; 229 SmallVector<NodePtr> CompositeNodes; 230 DenseMap<std::pair<Value *, Value *>, NodePtr> 
CachedResult; 231 232 SmallPtrSet<Instruction *, 16> FinalInstructions; 233 234 /// Root instructions are instructions from which complex computation starts 235 std::map<Instruction *, NodePtr> RootToNode; 236 237 /// Topologically sorted root instructions 238 SmallVector<Instruction *, 1> OrderedRoots; 239 240 /// When examining a basic block for complex deinterleaving, if it is a simple 241 /// one-block loop, then the only incoming block is 'Incoming' and the 242 /// 'BackEdge' block is the block itself." 243 BasicBlock *BackEdge = nullptr; 244 BasicBlock *Incoming = nullptr; 245 246 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction 247 /// %OutsideUser as it is shown in the IR: 248 /// 249 /// vector.body: 250 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], 251 /// [ %ReductionOp, %vector.body ] 252 /// ... 253 /// %ReductionOp = fadd i64 ... 254 /// ... 255 /// br i1 %condition, label %vector.body, %middle.block 256 /// 257 /// middle.block: 258 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) 259 /// 260 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding 261 /// `llvm.vector.reduce.fadd` when unroll factor isn't one. 262 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; 263 264 /// In the process of detecting a reduction, we consider a pair of 265 /// %ReductionOP, which we refer to as real and imag (or vice versa), and 266 /// traverse the use-tree to detect complex operations. As this is a reduction 267 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds 268 /// to the %ReductionOPs that we suspect to be complex. 269 /// RealPHI and ImagPHI are used by the identifyPHINode method. 270 PHINode *RealPHI = nullptr; 271 PHINode *ImagPHI = nullptr; 272 273 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction 274 /// detection. 
275 bool PHIsFound = false; 276 277 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. 278 /// The new PHINode corresponds to a vector of deinterleaved complex numbers. 279 /// This mapping is populated during 280 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then 281 /// used in the ComplexDeinterleavingOperation::ReductionOperation node 282 /// replacement process. 283 std::map<PHINode *, PHINode *> OldToNewPHI; 284 285 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 286 Value *R, Value *I) { 287 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && 288 Operation != ComplexDeinterleavingOperation::ReductionOperation) || 289 (R && I)) && 290 "Reduction related nodes must have Real and Imaginary parts"); 291 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 292 I); 293 } 294 295 NodePtr submitCompositeNode(NodePtr Node) { 296 CompositeNodes.push_back(Node); 297 if (Node->Real && Node->Imag) 298 CachedResult[{Node->Real, Node->Imag}] = Node; 299 return Node; 300 } 301 302 /// Identifies a complex partial multiply pattern and its rotation, based on 303 /// the following patterns 304 /// 305 /// 0: r: cr + ar * br 306 /// i: ci + ar * bi 307 /// 90: r: cr - ai * bi 308 /// i: ci + ai * br 309 /// 180: r: cr - ar * br 310 /// i: ci - ar * bi 311 /// 270: r: cr + ai * bi 312 /// i: ci - ai * br 313 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 314 315 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that 316 /// is partially known from identifyPartialMul, filling in the other half of 317 /// the complex pair. 318 NodePtr 319 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, 320 std::pair<Value *, Value *> &CommonOperandI); 321 322 /// Identifies a complex add pattern and its rotation, based on the following 323 /// patterns. 
324 /// 325 /// 90: r: ar - bi 326 /// i: ai + br 327 /// 270: r: ar + bi 328 /// i: ai - br 329 NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 330 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); 331 332 NodePtr identifyNode(Value *R, Value *I); 333 334 /// Determine if a sum of complex numbers can be formed from \p RealAddends 335 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. 336 /// Return nullptr if it is not possible to construct a complex number. 337 /// \p Flags are needed to generate symmetric Add and Sub operations. 338 NodePtr identifyAdditions(std::list<Addend> &RealAddends, 339 std::list<Addend> &ImagAddends, 340 std::optional<FastMathFlags> Flags, 341 NodePtr Accumulator); 342 343 /// Extract one addend that have both real and imaginary parts positive. 344 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, 345 std::list<Addend> &ImagAddends); 346 347 /// Determine if sum of multiplications of complex numbers can be formed from 348 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result 349 /// to it. Return nullptr if it is not possible to construct a complex number. 350 NodePtr identifyMultiplications(std::vector<Product> &RealMuls, 351 std::vector<Product> &ImagMuls, 352 NodePtr Accumulator); 353 354 /// Go through pairs of multiplication (one Real and one Imag) and find all 355 /// possible candidates for partial multiplication and put them into \p 356 /// Candidates. Returns true if all Product has pair with common operand 357 bool collectPartialMuls(const std::vector<Product> &RealMuls, 358 const std::vector<Product> &ImagMuls, 359 std::vector<PartialMulCandidate> &Candidates); 360 361 /// If the code is compiled with -Ofast or expressions have `reassoc` flag, 362 /// the order of complex computation operations may be significantly altered, 363 /// and the real and imaginary parts may not be executed in parallel. 
This 364 /// function takes this into consideration and employs a more general approach 365 /// to identify complex computations. Initially, it gathers all the addends 366 /// and multiplicands and then constructs a complex expression from them. 367 NodePtr identifyReassocNodes(Instruction *I, Instruction *J); 368 369 NodePtr identifyRoot(Instruction *I); 370 371 /// Identifies the Deinterleave operation applied to a vector containing 372 /// complex numbers. There are two ways to represent the Deinterleave 373 /// operation: 374 /// * Using two shufflevectors with even indices for /pReal instruction and 375 /// odd indices for /pImag instructions (only for fixed-width vectors) 376 /// * Using two extractvalue instructions applied to `vector.deinterleave2` 377 /// intrinsic (for both fixed and scalable vectors) 378 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); 379 380 /// identifying the operation that represents a complex number repeated in a 381 /// Splat vector. There are two possible types of splats: ConstantExpr with 382 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an 383 /// initialization mask with all values set to zero. 
384 NodePtr identifySplat(Value *Real, Value *Imag); 385 386 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); 387 388 /// Identifies SelectInsts in a loop that has reduction with predication masks 389 /// and/or predicated tail folding 390 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); 391 392 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); 393 394 /// Complete IR modifications after producing new reduction operation: 395 /// * Populate the PHINode generated for 396 /// ComplexDeinterleavingOperation::ReductionPHI 397 /// * Deinterleave the final value outside of the loop and repurpose original 398 /// reduction users 399 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); 400 401 public: 402 void dump() { dump(dbgs()); } 403 void dump(raw_ostream &OS) { 404 for (const auto &Node : CompositeNodes) 405 Node->dump(OS); 406 } 407 408 /// Returns false if the deinterleaving operation should be cancelled for the 409 /// current graph. 410 bool identifyNodes(Instruction *RootI); 411 412 /// In case \pB is one-block loop, this function seeks potential reductions 413 /// and populates ReductionInfo. Returns true if any reductions were 414 /// identified. 415 bool collectPotentialReductions(BasicBlock *B); 416 417 void identifyReductionNodes(); 418 419 /// Check that every instruction, from the roots to the leaves, has internal 420 /// uses. 421 bool checkNodes(); 422 423 /// Perform the actual replacement of the underlying instruction graph. 
424 void replaceNodes(); 425 }; 426 427 class ComplexDeinterleaving { 428 public: 429 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 430 : TL(tl), TLI(tli) {} 431 bool runOnFunction(Function &F); 432 433 private: 434 bool evaluateBasicBlock(BasicBlock *B); 435 436 const TargetLowering *TL = nullptr; 437 const TargetLibraryInfo *TLI = nullptr; 438 }; 439 440 } // namespace 441 442 char ComplexDeinterleavingLegacyPass::ID = 0; 443 444 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 445 "Complex Deinterleaving", false, false) 446 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 447 "Complex Deinterleaving", false, false) 448 449 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 450 FunctionAnalysisManager &AM) { 451 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 452 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 453 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 454 return PreservedAnalyses::all(); 455 456 PreservedAnalyses PA; 457 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 458 return PA; 459 } 460 461 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 462 return new ComplexDeinterleavingLegacyPass(TM); 463 } 464 465 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 466 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 467 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 468 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 469 } 470 471 bool ComplexDeinterleaving::runOnFunction(Function &F) { 472 if (!ComplexDeinterleavingEnabled) { 473 LLVM_DEBUG( 474 dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 475 return false; 476 } 477 478 if (!TL->isComplexDeinterleavingSupported()) { 479 LLVM_DEBUG( 480 dbgs() << "Complex deinterleaving has been disabled, target does " 481 "not support lowering of complex number operations.\n"); 482 return 
false; 483 } 484 485 bool Changed = false; 486 for (auto &B : F) 487 Changed |= evaluateBasicBlock(&B); 488 489 return Changed; 490 } 491 492 static bool isInterleavingMask(ArrayRef<int> Mask) { 493 // If the size is not even, it's not an interleaving mask 494 if ((Mask.size() & 1)) 495 return false; 496 497 int HalfNumElements = Mask.size() / 2; 498 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 499 int MaskIdx = Idx * 2; 500 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 501 return false; 502 } 503 504 return true; 505 } 506 507 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 508 int Offset = Mask[0]; 509 int HalfNumElements = Mask.size() / 2; 510 511 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 512 if (Mask[Idx] != (Idx * 2) + Offset) 513 return false; 514 } 515 516 return true; 517 } 518 519 bool isNeg(Value *V) { 520 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); 521 } 522 523 Value *getNegOperand(Value *V) { 524 assert(isNeg(V)); 525 auto *I = cast<Instruction>(V); 526 if (I->getOpcode() == Instruction::FNeg) 527 return I->getOperand(0); 528 529 return I->getOperand(1); 530 } 531 532 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 533 ComplexDeinterleavingGraph Graph(TL, TLI); 534 if (Graph.collectPotentialReductions(B)) 535 Graph.identifyReductionNodes(); 536 537 for (auto &I : *B) 538 Graph.identifyNodes(&I); 539 540 if (Graph.checkNodes()) { 541 Graph.replaceNodes(); 542 return true; 543 } 544 545 return false; 546 } 547 548 ComplexDeinterleavingGraph::NodePtr 549 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 550 Instruction *Real, Instruction *Imag, 551 std::pair<Value *, Value *> &PartialMatch) { 552 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 553 << "\n"); 554 555 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 556 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 557 return nullptr; 558 } 559 560 if ((Real->getOpcode() 
!= Instruction::FMul && 561 Real->getOpcode() != Instruction::Mul) || 562 (Imag->getOpcode() != Instruction::FMul && 563 Imag->getOpcode() != Instruction::Mul)) { 564 LLVM_DEBUG( 565 dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); 566 return nullptr; 567 } 568 569 Value *R0 = Real->getOperand(0); 570 Value *R1 = Real->getOperand(1); 571 Value *I0 = Imag->getOperand(0); 572 Value *I1 = Imag->getOperand(1); 573 574 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 575 // rotations and use the operand. 576 unsigned Negs = 0; 577 Value *Op; 578 if (match(R0, m_Neg(m_Value(Op)))) { 579 Negs |= 1; 580 R0 = Op; 581 } else if (match(R1, m_Neg(m_Value(Op)))) { 582 Negs |= 1; 583 R1 = Op; 584 } 585 586 if (isNeg(I0)) { 587 Negs |= 2; 588 Negs ^= 1; 589 I0 = Op; 590 } else if (match(I1, m_Neg(m_Value(Op)))) { 591 Negs |= 2; 592 Negs ^= 1; 593 I1 = Op; 594 } 595 596 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 597 598 Value *CommonOperand; 599 Value *UncommonRealOp; 600 Value *UncommonImagOp; 601 602 if (R0 == I0 || R0 == I1) { 603 CommonOperand = R0; 604 UncommonRealOp = R1; 605 } else if (R1 == I0 || R1 == I1) { 606 CommonOperand = R1; 607 UncommonRealOp = R0; 608 } else { 609 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 610 return nullptr; 611 } 612 613 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 614 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 615 Rotation == ComplexDeinterleavingRotation::Rotation_270) 616 std::swap(UncommonRealOp, UncommonImagOp); 617 618 // Between identifyPartialMul and here we need to have found a complete valid 619 // pair from the CommonOperand of each part. 
620 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 621 Rotation == ComplexDeinterleavingRotation::Rotation_180) 622 PartialMatch.first = CommonOperand; 623 else 624 PartialMatch.second = CommonOperand; 625 626 if (!PartialMatch.first || !PartialMatch.second) { 627 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 628 return nullptr; 629 } 630 631 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 632 if (!CommonNode) { 633 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 634 return nullptr; 635 } 636 637 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 638 if (!UncommonNode) { 639 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 640 return nullptr; 641 } 642 643 NodePtr Node = prepareCompositeNode( 644 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 645 Node->Rotation = Rotation; 646 Node->addOperand(CommonNode); 647 Node->addOperand(UncommonNode); 648 return submitCompositeNode(Node); 649 } 650 651 ComplexDeinterleavingGraph::NodePtr 652 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 653 Instruction *Imag) { 654 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 655 << "\n"); 656 // Determine rotation 657 auto IsAdd = [](unsigned Op) { 658 return Op == Instruction::FAdd || Op == Instruction::Add; 659 }; 660 auto IsSub = [](unsigned Op) { 661 return Op == Instruction::FSub || Op == Instruction::Sub; 662 }; 663 ComplexDeinterleavingRotation Rotation; 664 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 665 Rotation = ComplexDeinterleavingRotation::Rotation_0; 666 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 667 Rotation = ComplexDeinterleavingRotation::Rotation_90; 668 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) 669 Rotation = ComplexDeinterleavingRotation::Rotation_180; 670 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) 671 Rotation = ComplexDeinterleavingRotation::Rotation_270; 
672 else { 673 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 674 return nullptr; 675 } 676 677 if (isa<FPMathOperator>(Real) && 678 (!Real->getFastMathFlags().allowContract() || 679 !Imag->getFastMathFlags().allowContract())) { 680 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 681 return nullptr; 682 } 683 684 Value *CR = Real->getOperand(0); 685 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 686 if (!RealMulI) 687 return nullptr; 688 Value *CI = Imag->getOperand(0); 689 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 690 if (!ImagMulI) 691 return nullptr; 692 693 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 694 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 695 return nullptr; 696 } 697 698 Value *R0 = RealMulI->getOperand(0); 699 Value *R1 = RealMulI->getOperand(1); 700 Value *I0 = ImagMulI->getOperand(0); 701 Value *I1 = ImagMulI->getOperand(1); 702 703 Value *CommonOperand; 704 Value *UncommonRealOp; 705 Value *UncommonImagOp; 706 707 if (R0 == I0 || R0 == I1) { 708 CommonOperand = R0; 709 UncommonRealOp = R1; 710 } else if (R1 == I0 || R1 == I1) { 711 CommonOperand = R1; 712 UncommonRealOp = R0; 713 } else { 714 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 715 return nullptr; 716 } 717 718 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 719 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 720 Rotation == ComplexDeinterleavingRotation::Rotation_270) 721 std::swap(UncommonRealOp, UncommonImagOp); 722 723 std::pair<Value *, Value *> PartialMatch( 724 (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 725 Rotation == ComplexDeinterleavingRotation::Rotation_180) 726 ? CommonOperand 727 : nullptr, 728 (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 729 Rotation == ComplexDeinterleavingRotation::Rotation_270) 730 ? 
CommonOperand 731 : nullptr); 732 733 auto *CRInst = dyn_cast<Instruction>(CR); 734 auto *CIInst = dyn_cast<Instruction>(CI); 735 736 if (!CRInst || !CIInst) { 737 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); 738 return nullptr; 739 } 740 741 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); 742 if (!CNode) { 743 LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 744 return nullptr; 745 } 746 747 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 748 if (!UncommonRes) { 749 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 750 return nullptr; 751 } 752 753 assert(PartialMatch.first && PartialMatch.second); 754 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 755 if (!CommonRes) { 756 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 757 return nullptr; 758 } 759 760 NodePtr Node = prepareCompositeNode( 761 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 762 Node->Rotation = Rotation; 763 Node->addOperand(CommonRes); 764 Node->addOperand(UncommonRes); 765 Node->addOperand(CNode); 766 return submitCompositeNode(Node); 767 } 768 769 ComplexDeinterleavingGraph::NodePtr 770 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 771 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 772 773 // Determine rotation 774 ComplexDeinterleavingRotation Rotation; 775 if ((Real->getOpcode() == Instruction::FSub && 776 Imag->getOpcode() == Instruction::FAdd) || 777 (Real->getOpcode() == Instruction::Sub && 778 Imag->getOpcode() == Instruction::Add)) 779 Rotation = ComplexDeinterleavingRotation::Rotation_90; 780 else if ((Real->getOpcode() == Instruction::FAdd && 781 Imag->getOpcode() == Instruction::FSub) || 782 (Real->getOpcode() == Instruction::Add && 783 Imag->getOpcode() == Instruction::Sub)) 784 Rotation = ComplexDeinterleavingRotation::Rotation_270; 785 else { 786 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not 
assigned.\n"); 787 return nullptr; 788 } 789 790 auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 791 auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 792 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 793 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 794 795 if (!AR || !AI || !BR || !BI) { 796 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 797 return nullptr; 798 } 799 800 NodePtr ResA = identifyNode(AR, AI); 801 if (!ResA) { 802 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 803 return nullptr; 804 } 805 NodePtr ResB = identifyNode(BR, BI); 806 if (!ResB) { 807 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 808 return nullptr; 809 } 810 811 NodePtr Node = 812 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 813 Node->Rotation = Rotation; 814 Node->addOperand(ResA); 815 Node->addOperand(ResB); 816 return submitCompositeNode(Node); 817 } 818 819 static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 820 unsigned OpcA = A->getOpcode(); 821 unsigned OpcB = B->getOpcode(); 822 823 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 824 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 825 (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 826 (OpcA == Instruction::Add && OpcB == Instruction::Sub); 827 } 828 829 static bool isInstructionPairMul(Instruction *A, Instruction *B) { 830 auto Pattern = 831 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 832 833 return match(A, Pattern) && match(B, Pattern); 834 } 835 836 static bool isInstructionPotentiallySymmetric(Instruction *I) { 837 switch (I->getOpcode()) { 838 case Instruction::FAdd: 839 case Instruction::FSub: 840 case Instruction::FMul: 841 case Instruction::FNeg: 842 case Instruction::Add: 843 case Instruction::Sub: 844 case Instruction::Mul: 845 return true; 846 default: 847 return false; 848 } 849 } 850 851 
ComplexDeinterleavingGraph::NodePtr 852 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, 853 Instruction *Imag) { 854 if (Real->getOpcode() != Imag->getOpcode()) 855 return nullptr; 856 857 if (!isInstructionPotentiallySymmetric(Real) || 858 !isInstructionPotentiallySymmetric(Imag)) 859 return nullptr; 860 861 auto *R0 = Real->getOperand(0); 862 auto *I0 = Imag->getOperand(0); 863 864 NodePtr Op0 = identifyNode(R0, I0); 865 NodePtr Op1 = nullptr; 866 if (Op0 == nullptr) 867 return nullptr; 868 869 if (Real->isBinaryOp()) { 870 auto *R1 = Real->getOperand(1); 871 auto *I1 = Imag->getOperand(1); 872 Op1 = identifyNode(R1, I1); 873 if (Op1 == nullptr) 874 return nullptr; 875 } 876 877 if (isa<FPMathOperator>(Real) && 878 Real->getFastMathFlags() != Imag->getFastMathFlags()) 879 return nullptr; 880 881 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, 882 Real, Imag); 883 Node->Opcode = Real->getOpcode(); 884 if (isa<FPMathOperator>(Real)) 885 Node->Flags = Real->getFastMathFlags(); 886 887 Node->addOperand(Op0); 888 if (Real->isBinaryOp()) 889 Node->addOperand(Op1); 890 891 return submitCompositeNode(Node); 892 } 893 894 ComplexDeinterleavingGraph::NodePtr 895 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { 896 LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); 897 assert(R->getType() == I->getType() && 898 "Real and imaginary parts should not have different types"); 899 900 auto It = CachedResult.find({R, I}); 901 if (It != CachedResult.end()) { 902 LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 903 return It->second; 904 } 905 906 if (NodePtr CN = identifySplat(R, I)) 907 return CN; 908 909 auto *Real = dyn_cast<Instruction>(R); 910 auto *Imag = dyn_cast<Instruction>(I); 911 if (!Real || !Imag) 912 return nullptr; 913 914 if (NodePtr CN = identifyDeinterleave(Real, Imag)) 915 return CN; 916 917 if (NodePtr CN = identifyPHINode(Real, Imag)) 918 return CN; 919 920 if (NodePtr 
CN = identifySelectNode(Real, Imag)) 921 return CN; 922 923 auto *VTy = cast<VectorType>(Real->getType()); 924 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 925 926 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( 927 ComplexDeinterleavingOperation::CMulPartial, NewVTy); 928 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( 929 ComplexDeinterleavingOperation::CAdd, NewVTy); 930 931 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { 932 if (NodePtr CN = identifyPartialMul(Real, Imag)) 933 return CN; 934 } 935 936 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { 937 if (NodePtr CN = identifyAdd(Real, Imag)) 938 return CN; 939 } 940 941 if (HasCMulSupport && HasCAddSupport) { 942 if (NodePtr CN = identifyReassocNodes(Real, Imag)) 943 return CN; 944 } 945 946 if (NodePtr CN = identifySymmetricOperation(Real, Imag)) 947 return CN; 948 949 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); 950 CachedResult[{R, I}] = nullptr; 951 return nullptr; 952 } 953 954 ComplexDeinterleavingGraph::NodePtr 955 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, 956 Instruction *Imag) { 957 auto IsOperationSupported = [](unsigned Opcode) -> bool { 958 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || 959 Opcode == Instruction::FNeg || Opcode == Instruction::Add || 960 Opcode == Instruction::Sub; 961 }; 962 963 if (!IsOperationSupported(Real->getOpcode()) || 964 !IsOperationSupported(Imag->getOpcode())) 965 return nullptr; 966 967 std::optional<FastMathFlags> Flags; 968 if (isa<FPMathOperator>(Real)) { 969 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { 970 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " 971 "not identical\n"); 972 return nullptr; 973 } 974 975 Flags = Real->getFastMathFlags(); 976 if (!Flags->allowReassoc()) { 977 LLVM_DEBUG( 978 dbgs() 979 << "the 'Reassoc' attribute is missing in the FastMath flags\n"); 980 
return nullptr; 981 } 982 } 983 984 // Collect multiplications and addend instructions from the given instruction 985 // while traversing it operands. Additionally, verify that all instructions 986 // have the same fast math flags. 987 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 988 std::list<Addend> &Addends) -> bool { 989 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 990 SmallPtrSet<Value *, 8> Visited; 991 while (!Worklist.empty()) { 992 auto [V, IsPositive] = Worklist.back(); 993 Worklist.pop_back(); 994 if (!Visited.insert(V).second) 995 continue; 996 997 Instruction *I = dyn_cast<Instruction>(V); 998 if (!I) { 999 Addends.emplace_back(V, IsPositive); 1000 continue; 1001 } 1002 1003 // If an instruction has more than one user, it indicates that it either 1004 // has an external user, which will be later checked by the checkNodes 1005 // function, or it is a subexpression utilized by multiple expressions. In 1006 // the latter case, we will attempt to separately identify the complex 1007 // operation from here in order to create a shared 1008 // ComplexDeinterleavingCompositeNode. 
1009 if (I != Insn && I->getNumUses() > 1) { 1010 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 1011 Addends.emplace_back(I, IsPositive); 1012 continue; 1013 } 1014 switch (I->getOpcode()) { 1015 case Instruction::FAdd: 1016 case Instruction::Add: 1017 Worklist.emplace_back(I->getOperand(1), IsPositive); 1018 Worklist.emplace_back(I->getOperand(0), IsPositive); 1019 break; 1020 case Instruction::FSub: 1021 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1022 Worklist.emplace_back(I->getOperand(0), IsPositive); 1023 break; 1024 case Instruction::Sub: 1025 if (isNeg(I)) { 1026 Worklist.emplace_back(getNegOperand(I), !IsPositive); 1027 } else { 1028 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1029 Worklist.emplace_back(I->getOperand(0), IsPositive); 1030 } 1031 break; 1032 case Instruction::FMul: 1033 case Instruction::Mul: { 1034 Value *A, *B; 1035 if (isNeg(I->getOperand(0))) { 1036 A = getNegOperand(I->getOperand(0)); 1037 IsPositive = !IsPositive; 1038 } else { 1039 A = I->getOperand(0); 1040 } 1041 1042 if (isNeg(I->getOperand(1))) { 1043 B = getNegOperand(I->getOperand(1)); 1044 IsPositive = !IsPositive; 1045 } else { 1046 B = I->getOperand(1); 1047 } 1048 Muls.push_back(Product{A, B, IsPositive}); 1049 break; 1050 } 1051 case Instruction::FNeg: 1052 Worklist.emplace_back(I->getOperand(0), !IsPositive); 1053 break; 1054 default: 1055 Addends.emplace_back(I, IsPositive); 1056 continue; 1057 } 1058 1059 if (Flags && I->getFastMathFlags() != *Flags) { 1060 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 1061 "inconsistent with the root instructions' flags: " 1062 << *I << "\n"); 1063 return false; 1064 } 1065 } 1066 return true; 1067 }; 1068 1069 std::vector<Product> RealMuls, ImagMuls; 1070 std::list<Addend> RealAddends, ImagAddends; 1071 if (!Collect(Real, RealMuls, RealAddends) || 1072 !Collect(Imag, ImagMuls, ImagAddends)) 1073 return nullptr; 1074 1075 if (RealAddends.size() != ImagAddends.size()) 1076 
return nullptr; 1077 1078 NodePtr FinalNode; 1079 if (!RealMuls.empty() || !ImagMuls.empty()) { 1080 // If there are multiplicands, extract positive addend and use it as an 1081 // accumulator 1082 FinalNode = extractPositiveAddend(RealAddends, ImagAddends); 1083 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); 1084 if (!FinalNode) 1085 return nullptr; 1086 } 1087 1088 // Identify and process remaining additions 1089 if (!RealAddends.empty() || !ImagAddends.empty()) { 1090 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); 1091 if (!FinalNode) 1092 return nullptr; 1093 } 1094 assert(FinalNode && "FinalNode can not be nullptr here"); 1095 // Set the Real and Imag fields of the final node and submit it 1096 FinalNode->Real = Real; 1097 FinalNode->Imag = Imag; 1098 submitCompositeNode(FinalNode); 1099 return FinalNode; 1100 } 1101 1102 bool ComplexDeinterleavingGraph::collectPartialMuls( 1103 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, 1104 std::vector<PartialMulCandidate> &PartialMulCandidates) { 1105 // Helper function to extract a common operand from two products 1106 auto FindCommonInstruction = [](const Product &Real, 1107 const Product &Imag) -> Value * { 1108 if (Real.Multiplicand == Imag.Multiplicand || 1109 Real.Multiplicand == Imag.Multiplier) 1110 return Real.Multiplicand; 1111 1112 if (Real.Multiplier == Imag.Multiplicand || 1113 Real.Multiplier == Imag.Multiplier) 1114 return Real.Multiplier; 1115 1116 return nullptr; 1117 }; 1118 1119 // Iterating over real and imaginary multiplications to find common operands 1120 // If a common operand is found, a partial multiplication candidate is created 1121 // and added to the candidates vector The function returns false if no common 1122 // operands are found for any product 1123 for (unsigned i = 0; i < RealMuls.size(); ++i) { 1124 bool FoundCommon = false; 1125 for (unsigned j = 0; j < ImagMuls.size(); ++j) { 1126 auto *Common = 
FindCommonInstruction(RealMuls[i], ImagMuls[j]); 1127 if (!Common) 1128 continue; 1129 1130 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier 1131 : RealMuls[i].Multiplicand; 1132 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier 1133 : ImagMuls[j].Multiplicand; 1134 1135 auto Node = identifyNode(A, B); 1136 if (Node) { 1137 FoundCommon = true; 1138 PartialMulCandidates.push_back({Common, Node, i, j, false}); 1139 } 1140 1141 Node = identifyNode(B, A); 1142 if (Node) { 1143 FoundCommon = true; 1144 PartialMulCandidates.push_back({Common, Node, i, j, true}); 1145 } 1146 } 1147 if (!FoundCommon) 1148 return false; 1149 } 1150 return true; 1151 } 1152 1153 ComplexDeinterleavingGraph::NodePtr 1154 ComplexDeinterleavingGraph::identifyMultiplications( 1155 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, 1156 NodePtr Accumulator = nullptr) { 1157 if (RealMuls.size() != ImagMuls.size()) 1158 return nullptr; 1159 1160 std::vector<PartialMulCandidate> Info; 1161 if (!collectPartialMuls(RealMuls, ImagMuls, Info)) 1162 return nullptr; 1163 1164 // Map to store common instruction to node pointers 1165 std::map<Value *, NodePtr> CommonToNode; 1166 std::vector<bool> Processed(Info.size(), false); 1167 for (unsigned I = 0; I < Info.size(); ++I) { 1168 if (Processed[I]) 1169 continue; 1170 1171 PartialMulCandidate &InfoA = Info[I]; 1172 for (unsigned J = I + 1; J < Info.size(); ++J) { 1173 if (Processed[J]) 1174 continue; 1175 1176 PartialMulCandidate &InfoB = Info[J]; 1177 auto *InfoReal = &InfoA; 1178 auto *InfoImag = &InfoB; 1179 1180 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1181 if (!NodeFromCommon) { 1182 std::swap(InfoReal, InfoImag); 1183 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1184 } 1185 if (!NodeFromCommon) 1186 continue; 1187 1188 CommonToNode[InfoReal->Common] = NodeFromCommon; 1189 CommonToNode[InfoImag->Common] = NodeFromCommon; 1190 Processed[I] = true; 
1191 Processed[J] = true; 1192 } 1193 } 1194 1195 std::vector<bool> ProcessedReal(RealMuls.size(), false); 1196 std::vector<bool> ProcessedImag(ImagMuls.size(), false); 1197 NodePtr Result = Accumulator; 1198 for (auto &PMI : Info) { 1199 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 1200 continue; 1201 1202 auto It = CommonToNode.find(PMI.Common); 1203 // TODO: Process independent complex multiplications. Cases like this: 1204 // A.real() * B where both A and B are complex numbers. 1205 if (It == CommonToNode.end()) { 1206 LLVM_DEBUG({ 1207 dbgs() << "Unprocessed independent partial multiplication:\n"; 1208 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 1209 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 1210 << " multiplied by " << *Mul->Multiplicand << "\n"; 1211 }); 1212 return nullptr; 1213 } 1214 1215 auto &RealMul = RealMuls[PMI.RealIdx]; 1216 auto &ImagMul = ImagMuls[PMI.ImagIdx]; 1217 1218 auto NodeA = It->second; 1219 auto NodeB = PMI.Node; 1220 auto IsMultiplicandReal = PMI.Common == NodeA->Real; 1221 // The following table illustrates the relationship between multiplications 1222 // and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we 1223 // can see: 1224 // 1225 // Rotation | Real | Imag | 1226 // ---------+--------+--------+ 1227 // 0 | x * u | x * v | 1228 // 90 | -y * v | y * u | 1229 // 180 | -x * u | -x * v | 1230 // 270 | y * v | -y * u | 1231 // 1232 // Check if the candidate can indeed be represented by partial 1233 // multiplication 1234 // TODO: Add support for multiplication by complex one 1235 if ((IsMultiplicandReal && PMI.IsNodeInverted) || 1236 (!IsMultiplicandReal && !PMI.IsNodeInverted)) 1237 continue; 1238 1239 // Determine the rotation based on the multiplications 1240 ComplexDeinterleavingRotation Rotation; 1241 if (IsMultiplicandReal) { 1242 // Detect 0 and 180 degrees rotation 1243 if (RealMul.IsPositive && ImagMul.IsPositive) 1244 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 1245 else if (!RealMul.IsPositive && !ImagMul.IsPositive) 1246 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 1247 else 1248 continue; 1249 1250 } else { 1251 // Detect 90 and 270 degrees rotation 1252 if (!RealMul.IsPositive && ImagMul.IsPositive) 1253 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 1254 else if (RealMul.IsPositive && !ImagMul.IsPositive) 1255 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; 1256 else 1257 continue; 1258 } 1259 1260 LLVM_DEBUG({ 1261 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 1262 dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 1263 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 1264 dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 1265 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 1266 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1267 }); 1268 1269 NodePtr NodeMul = prepareCompositeNode( 1270 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 1271 NodeMul->Rotation = Rotation; 1272 NodeMul->addOperand(NodeA); 1273 NodeMul->addOperand(NodeB); 1274 if (Result) 1275 
NodeMul->addOperand(Result); 1276 submitCompositeNode(NodeMul); 1277 Result = NodeMul; 1278 ProcessedReal[PMI.RealIdx] = true; 1279 ProcessedImag[PMI.ImagIdx] = true; 1280 } 1281 1282 // Ensure all products have been processed, if not return nullptr. 1283 if (!all_of(ProcessedReal, [](bool V) { return V; }) || 1284 !all_of(ProcessedImag, [](bool V) { return V; })) { 1285 1286 // Dump debug information about which partial multiplications are not 1287 // processed. 1288 LLVM_DEBUG({ 1289 dbgs() << "Unprocessed products (Real):\n"; 1290 for (size_t i = 0; i < ProcessedReal.size(); ++i) { 1291 if (!ProcessedReal[i]) 1292 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 1293 << *RealMuls[i].Multiplier << " multiplied by " 1294 << *RealMuls[i].Multiplicand << "\n"; 1295 } 1296 dbgs() << "Unprocessed products (Imag):\n"; 1297 for (size_t i = 0; i < ProcessedImag.size(); ++i) { 1298 if (!ProcessedImag[i]) 1299 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") 1300 << *ImagMuls[i].Multiplier << " multiplied by " 1301 << *ImagMuls[i].Multiplicand << "\n"; 1302 } 1303 }); 1304 return nullptr; 1305 } 1306 1307 return Result; 1308 } 1309 1310 ComplexDeinterleavingGraph::NodePtr 1311 ComplexDeinterleavingGraph::identifyAdditions( 1312 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, 1313 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { 1314 if (RealAddends.size() != ImagAddends.size()) 1315 return nullptr; 1316 1317 NodePtr Result; 1318 // If we have accumulator use it as first addend 1319 if (Accumulator) 1320 Result = Accumulator; 1321 // Otherwise find an element with both positive real and imaginary parts. 
1322 else 1323 Result = extractPositiveAddend(RealAddends, ImagAddends); 1324 1325 if (!Result) 1326 return nullptr; 1327 1328 while (!RealAddends.empty()) { 1329 auto ItR = RealAddends.begin(); 1330 auto [R, IsPositiveR] = *ItR; 1331 1332 bool FoundImag = false; 1333 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1334 auto [I, IsPositiveI] = *ItI; 1335 ComplexDeinterleavingRotation Rotation; 1336 if (IsPositiveR && IsPositiveI) 1337 Rotation = ComplexDeinterleavingRotation::Rotation_0; 1338 else if (!IsPositiveR && IsPositiveI) 1339 Rotation = ComplexDeinterleavingRotation::Rotation_90; 1340 else if (!IsPositiveR && !IsPositiveI) 1341 Rotation = ComplexDeinterleavingRotation::Rotation_180; 1342 else 1343 Rotation = ComplexDeinterleavingRotation::Rotation_270; 1344 1345 NodePtr AddNode; 1346 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 1347 Rotation == ComplexDeinterleavingRotation::Rotation_180) { 1348 AddNode = identifyNode(R, I); 1349 } else { 1350 AddNode = identifyNode(I, R); 1351 } 1352 if (AddNode) { 1353 LLVM_DEBUG({ 1354 dbgs() << "Identified addition:\n"; 1355 dbgs().indent(4) << "X: " << *R << "\n"; 1356 dbgs().indent(4) << "Y: " << *I << "\n"; 1357 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1358 }); 1359 1360 NodePtr TmpNode; 1361 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { 1362 TmpNode = prepareCompositeNode( 1363 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1364 if (Flags) { 1365 TmpNode->Opcode = Instruction::FAdd; 1366 TmpNode->Flags = *Flags; 1367 } else { 1368 TmpNode->Opcode = Instruction::Add; 1369 } 1370 } else if (Rotation == 1371 llvm::ComplexDeinterleavingRotation::Rotation_180) { 1372 TmpNode = prepareCompositeNode( 1373 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1374 if (Flags) { 1375 TmpNode->Opcode = Instruction::FSub; 1376 TmpNode->Flags = *Flags; 1377 } else { 1378 TmpNode->Opcode = Instruction::Sub; 1379 } 1380 } 
else { 1381 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, 1382 nullptr, nullptr); 1383 TmpNode->Rotation = Rotation; 1384 } 1385 1386 TmpNode->addOperand(Result); 1387 TmpNode->addOperand(AddNode); 1388 submitCompositeNode(TmpNode); 1389 Result = TmpNode; 1390 RealAddends.erase(ItR); 1391 ImagAddends.erase(ItI); 1392 FoundImag = true; 1393 break; 1394 } 1395 } 1396 if (!FoundImag) 1397 return nullptr; 1398 } 1399 return Result; 1400 } 1401 1402 ComplexDeinterleavingGraph::NodePtr 1403 ComplexDeinterleavingGraph::extractPositiveAddend( 1404 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 1405 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 1406 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1407 auto [R, IsPositiveR] = *ItR; 1408 auto [I, IsPositiveI] = *ItI; 1409 if (IsPositiveR && IsPositiveI) { 1410 auto Result = identifyNode(R, I); 1411 if (Result) { 1412 RealAddends.erase(ItR); 1413 ImagAddends.erase(ItI); 1414 return Result; 1415 } 1416 } 1417 } 1418 } 1419 return nullptr; 1420 } 1421 1422 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 1423 // This potential root instruction might already have been recognized as 1424 // reduction. Because RootToNode maps both Real and Imaginary parts to 1425 // CompositeNode we should choose only one either Real or Imag instruction to 1426 // use as an anchor for generating complex instruction. 1427 auto It = RootToNode.find(RootI); 1428 if (It != RootToNode.end()) { 1429 auto RootNode = It->second; 1430 assert(RootNode->Operation == 1431 ComplexDeinterleavingOperation::ReductionOperation); 1432 // Find out which part, Real or Imag, comes later, and only if we come to 1433 // the latest part, add it to OrderedRoots. 1434 auto *R = cast<Instruction>(RootNode->Real); 1435 auto *I = cast<Instruction>(RootNode->Imag); 1436 auto *ReplacementAnchor = R->comesBefore(I) ? 
I : R; 1437 if (ReplacementAnchor != RootI) 1438 return false; 1439 OrderedRoots.push_back(RootI); 1440 return true; 1441 } 1442 1443 auto RootNode = identifyRoot(RootI); 1444 if (!RootNode) 1445 return false; 1446 1447 LLVM_DEBUG({ 1448 Function *F = RootI->getFunction(); 1449 BasicBlock *B = RootI->getParent(); 1450 dbgs() << "Complex deinterleaving graph for " << F->getName() 1451 << "::" << B->getName() << ".\n"; 1452 dump(dbgs()); 1453 dbgs() << "\n"; 1454 }); 1455 RootToNode[RootI] = RootNode; 1456 OrderedRoots.push_back(RootI); 1457 return true; 1458 } 1459 1460 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 1461 bool FoundPotentialReduction = false; 1462 1463 auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 1464 if (!Br || Br->getNumSuccessors() != 2) 1465 return false; 1466 1467 // Identify simple one-block loop 1468 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 1469 return false; 1470 1471 SmallVector<PHINode *> PHIs; 1472 for (auto &PHI : B->phis()) { 1473 if (PHI.getNumIncomingValues() != 2) 1474 continue; 1475 1476 if (!PHI.getType()->isVectorTy()) 1477 continue; 1478 1479 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 1480 if (!ReductionOp) 1481 continue; 1482 1483 // Check if final instruction is reduced outside of current block 1484 Instruction *FinalReduction = nullptr; 1485 auto NumUsers = 0u; 1486 for (auto *U : ReductionOp->users()) { 1487 ++NumUsers; 1488 if (U == &PHI) 1489 continue; 1490 FinalReduction = dyn_cast<Instruction>(U); 1491 } 1492 1493 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || 1494 isa<PHINode>(FinalReduction)) 1495 continue; 1496 1497 ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 1498 BackEdge = B; 1499 auto BackEdgeIdx = PHI.getBasicBlockIndex(B); 1500 auto IncomingIdx = BackEdgeIdx == 0 ? 
1 : 0; 1501 Incoming = PHI.getIncomingBlock(IncomingIdx); 1502 FoundPotentialReduction = true; 1503 1504 // If the initial value of PHINode is an Instruction, consider it a leaf 1505 // value of a complex deinterleaving graph. 1506 if (auto *InitPHI = 1507 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 1508 FinalInstructions.insert(InitPHI); 1509 } 1510 return FoundPotentialReduction; 1511 } 1512 1513 void ComplexDeinterleavingGraph::identifyReductionNodes() { 1514 SmallVector<bool> Processed(ReductionInfo.size(), false); 1515 SmallVector<Instruction *> OperationInstruction; 1516 for (auto &P : ReductionInfo) 1517 OperationInstruction.push_back(P.first); 1518 1519 // Identify a complex computation by evaluating two reduction operations that 1520 // potentially could be involved 1521 for (size_t i = 0; i < OperationInstruction.size(); ++i) { 1522 if (Processed[i]) 1523 continue; 1524 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 1525 if (Processed[j]) 1526 continue; 1527 1528 auto *Real = OperationInstruction[i]; 1529 auto *Imag = OperationInstruction[j]; 1530 if (Real->getType() != Imag->getType()) 1531 continue; 1532 1533 RealPHI = ReductionInfo[Real].first; 1534 ImagPHI = ReductionInfo[Imag].first; 1535 PHIsFound = false; 1536 auto Node = identifyNode(Real, Imag); 1537 if (!Node) { 1538 std::swap(Real, Imag); 1539 std::swap(RealPHI, ImagPHI); 1540 Node = identifyNode(Real, Imag); 1541 } 1542 1543 // If a node is identified and reduction PHINode is used in the chain of 1544 // operations, mark its operation instructions as used to prevent 1545 // re-identification and attach the node to the real part 1546 if (Node && PHIsFound) { 1547 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 1548 << *Real << " / " << *Imag << "\n"); 1549 Processed[i] = true; 1550 Processed[j] = true; 1551 auto RootNode = prepareCompositeNode( 1552 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); 1553 
RootNode->addOperand(Node); 1554 RootToNode[Real] = RootNode; 1555 RootToNode[Imag] = RootNode; 1556 submitCompositeNode(RootNode); 1557 break; 1558 } 1559 } 1560 } 1561 1562 RealPHI = nullptr; 1563 ImagPHI = nullptr; 1564 } 1565 1566 bool ComplexDeinterleavingGraph::checkNodes() { 1567 // Collect all instructions from roots to leaves 1568 SmallPtrSet<Instruction *, 16> AllInstructions; 1569 SmallVector<Instruction *, 8> Worklist; 1570 for (auto &Pair : RootToNode) 1571 Worklist.push_back(Pair.first); 1572 1573 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 1574 // chains 1575 while (!Worklist.empty()) { 1576 auto *I = Worklist.back(); 1577 Worklist.pop_back(); 1578 1579 if (!AllInstructions.insert(I).second) 1580 continue; 1581 1582 for (Value *Op : I->operands()) { 1583 if (auto *OpI = dyn_cast<Instruction>(Op)) { 1584 if (!FinalInstructions.count(I)) 1585 Worklist.emplace_back(OpI); 1586 } 1587 } 1588 } 1589 1590 // Find instructions that have users outside of chain 1591 SmallVector<Instruction *, 2> OuterInstructions; 1592 for (auto *I : AllInstructions) { 1593 // Skip root nodes 1594 if (RootToNode.count(I)) 1595 continue; 1596 1597 for (User *U : I->users()) { 1598 if (AllInstructions.count(cast<Instruction>(U))) 1599 continue; 1600 1601 // Found an instruction that is not used by XCMLA/XCADD chain 1602 Worklist.emplace_back(I); 1603 break; 1604 } 1605 } 1606 1607 // If any instructions are found to be used outside, find and remove roots 1608 // that somehow connect to those instructions. 1609 SmallPtrSet<Instruction *, 16> Visited; 1610 while (!Worklist.empty()) { 1611 auto *I = Worklist.back(); 1612 Worklist.pop_back(); 1613 if (!Visited.insert(I).second) 1614 continue; 1615 1616 // Found an impacted root node. 
Removing it from the nodes to be 1617 // deinterleaved 1618 if (RootToNode.count(I)) { 1619 LLVM_DEBUG(dbgs() << "Instruction " << *I 1620 << " could be deinterleaved but its chain of complex " 1621 "operations have an outside user\n"); 1622 RootToNode.erase(I); 1623 } 1624 1625 if (!AllInstructions.count(I) || FinalInstructions.count(I)) 1626 continue; 1627 1628 for (User *U : I->users()) 1629 Worklist.emplace_back(cast<Instruction>(U)); 1630 1631 for (Value *Op : I->operands()) { 1632 if (auto *OpI = dyn_cast<Instruction>(Op)) 1633 Worklist.emplace_back(OpI); 1634 } 1635 } 1636 return !RootToNode.empty(); 1637 } 1638 1639 ComplexDeinterleavingGraph::NodePtr 1640 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 1641 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 1642 if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2) 1643 return nullptr; 1644 1645 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 1646 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 1647 if (!Real || !Imag) 1648 return nullptr; 1649 1650 return identifyNode(Real, Imag); 1651 } 1652 1653 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 1654 if (!SVI) 1655 return nullptr; 1656 1657 // Look for a shufflevector that takes separate vectors of the real and 1658 // imaginary components and recombines them into a single vector. 
1659 if (!isInterleavingMask(SVI->getShuffleMask())) 1660 return nullptr; 1661 1662 Instruction *Real; 1663 Instruction *Imag; 1664 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 1665 return nullptr; 1666 1667 return identifyNode(Real, Imag); 1668 } 1669 1670 ComplexDeinterleavingGraph::NodePtr 1671 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 1672 Instruction *Imag) { 1673 Instruction *I = nullptr; 1674 Value *FinalValue = nullptr; 1675 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 1676 match(Imag, m_ExtractValue<1>(m_Specific(I))) && 1677 match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>( 1678 m_Value(FinalValue)))) { 1679 NodePtr PlaceholderNode = prepareCompositeNode( 1680 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 1681 PlaceholderNode->ReplacementNode = FinalValue; 1682 FinalInstructions.insert(Real); 1683 FinalInstructions.insert(Imag); 1684 return submitCompositeNode(PlaceholderNode); 1685 } 1686 1687 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 1688 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 1689 if (!RealShuffle || !ImagShuffle) { 1690 if (RealShuffle || ImagShuffle) 1691 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); 1692 return nullptr; 1693 } 1694 1695 Value *RealOp1 = RealShuffle->getOperand(1); 1696 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 1697 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 1698 return nullptr; 1699 } 1700 Value *ImagOp1 = ImagShuffle->getOperand(1); 1701 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 1702 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 1703 return nullptr; 1704 } 1705 1706 Value *RealOp0 = RealShuffle->getOperand(0); 1707 Value *ImagOp0 = ImagShuffle->getOperand(0); 1708 1709 if (RealOp0 != ImagOp0) { 1710 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 1711 return nullptr; 1712 } 1713 1714 
ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 1715 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 1716 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 1717 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 1718 return nullptr; 1719 } 1720 1721 if (RealMask[0] != 0 || ImagMask[0] != 1) { 1722 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 1723 return nullptr; 1724 } 1725 1726 // Type checking, the shuffle type should be a vector type of the same 1727 // scalar type, but half the size 1728 auto CheckType = [&](ShuffleVectorInst *Shuffle) { 1729 Value *Op = Shuffle->getOperand(0); 1730 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 1731 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1732 1733 if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 1734 return false; 1735 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 1736 return false; 1737 1738 return true; 1739 }; 1740 1741 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 1742 if (!CheckType(Shuffle)) 1743 return false; 1744 1745 ArrayRef<int> Mask = Shuffle->getShuffleMask(); 1746 int Last = *Mask.rbegin(); 1747 1748 Value *Op = Shuffle->getOperand(0); 1749 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1750 int NumElements = OpTy->getNumElements(); 1751 1752 // Ensure that the deinterleaving shuffle only pulls from the first 1753 // shuffle operand. 
1754 return Last < NumElements; 1755 }; 1756 1757 if (RealShuffle->getType() != ImagShuffle->getType()) { 1758 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 1759 return nullptr; 1760 } 1761 if (!CheckDeinterleavingShuffle(RealShuffle)) { 1762 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 1763 return nullptr; 1764 } 1765 if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1766 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1767 return nullptr; 1768 } 1769 1770 NodePtr PlaceholderNode = 1771 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1772 RealShuffle, ImagShuffle); 1773 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1774 FinalInstructions.insert(RealShuffle); 1775 FinalInstructions.insert(ImagShuffle); 1776 return submitCompositeNode(PlaceholderNode); 1777 } 1778 1779 ComplexDeinterleavingGraph::NodePtr 1780 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { 1781 auto IsSplat = [](Value *V) -> bool { 1782 // Fixed-width vector with constants 1783 if (isa<ConstantDataVector>(V)) 1784 return true; 1785 1786 VectorType *VTy; 1787 ArrayRef<int> Mask; 1788 // Splats are represented differently depending on whether the repeated 1789 // value is a constant or an Instruction 1790 if (auto *Const = dyn_cast<ConstantExpr>(V)) { 1791 if (Const->getOpcode() != Instruction::ShuffleVector) 1792 return false; 1793 VTy = cast<VectorType>(Const->getType()); 1794 Mask = Const->getShuffleMask(); 1795 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) { 1796 VTy = Shuf->getType(); 1797 Mask = Shuf->getShuffleMask(); 1798 } else { 1799 return false; 1800 } 1801 1802 // When the data type is <1 x Type>, it's not possible to differentiate 1803 // between the ComplexDeinterleaving::Deinterleave and 1804 // ComplexDeinterleaving::Splat operations. 
1805 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) 1806 return false; 1807 1808 return all_equal(Mask) && Mask[0] == 0; 1809 }; 1810 1811 if (!IsSplat(R) || !IsSplat(I)) 1812 return nullptr; 1813 1814 auto *Real = dyn_cast<Instruction>(R); 1815 auto *Imag = dyn_cast<Instruction>(I); 1816 if ((!Real && Imag) || (Real && !Imag)) 1817 return nullptr; 1818 1819 if (Real && Imag) { 1820 // Non-constant splats should be in the same basic block 1821 if (Real->getParent() != Imag->getParent()) 1822 return nullptr; 1823 1824 FinalInstructions.insert(Real); 1825 FinalInstructions.insert(Imag); 1826 } 1827 NodePtr PlaceholderNode = 1828 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); 1829 return submitCompositeNode(PlaceholderNode); 1830 } 1831 1832 ComplexDeinterleavingGraph::NodePtr 1833 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, 1834 Instruction *Imag) { 1835 if (Real != RealPHI || Imag != ImagPHI) 1836 return nullptr; 1837 1838 PHIsFound = true; 1839 NodePtr PlaceholderNode = prepareCompositeNode( 1840 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); 1841 return submitCompositeNode(PlaceholderNode); 1842 } 1843 1844 ComplexDeinterleavingGraph::NodePtr 1845 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, 1846 Instruction *Imag) { 1847 auto *SelectReal = dyn_cast<SelectInst>(Real); 1848 auto *SelectImag = dyn_cast<SelectInst>(Imag); 1849 if (!SelectReal || !SelectImag) 1850 return nullptr; 1851 1852 Instruction *MaskA, *MaskB; 1853 Instruction *AR, *AI, *RA, *BI; 1854 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), 1855 m_Instruction(RA))) || 1856 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), 1857 m_Instruction(BI)))) 1858 return nullptr; 1859 1860 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) 1861 return nullptr; 1862 1863 if (!MaskA->getType()->isVectorTy()) 1864 return nullptr; 1865 1866 auto NodeA = identifyNode(AR, AI); 1867 if 
(!NodeA) 1868 return nullptr; 1869 1870 auto NodeB = identifyNode(RA, BI); 1871 if (!NodeB) 1872 return nullptr; 1873 1874 NodePtr PlaceholderNode = prepareCompositeNode( 1875 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); 1876 PlaceholderNode->addOperand(NodeA); 1877 PlaceholderNode->addOperand(NodeB); 1878 FinalInstructions.insert(MaskA); 1879 FinalInstructions.insert(MaskB); 1880 return submitCompositeNode(PlaceholderNode); 1881 } 1882 1883 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, 1884 std::optional<FastMathFlags> Flags, 1885 Value *InputA, Value *InputB) { 1886 Value *I; 1887 switch (Opcode) { 1888 case Instruction::FNeg: 1889 I = B.CreateFNeg(InputA); 1890 break; 1891 case Instruction::FAdd: 1892 I = B.CreateFAdd(InputA, InputB); 1893 break; 1894 case Instruction::Add: 1895 I = B.CreateAdd(InputA, InputB); 1896 break; 1897 case Instruction::FSub: 1898 I = B.CreateFSub(InputA, InputB); 1899 break; 1900 case Instruction::Sub: 1901 I = B.CreateSub(InputA, InputB); 1902 break; 1903 case Instruction::FMul: 1904 I = B.CreateFMul(InputA, InputB); 1905 break; 1906 case Instruction::Mul: 1907 I = B.CreateMul(InputA, InputB); 1908 break; 1909 default: 1910 llvm_unreachable("Incorrect symmetric opcode"); 1911 } 1912 if (Flags) 1913 cast<Instruction>(I)->setFastMathFlags(*Flags); 1914 return I; 1915 } 1916 1917 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 1918 RawNodePtr Node) { 1919 if (Node->ReplacementNode) 1920 return Node->ReplacementNode; 1921 1922 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { 1923 return Node->Operands.size() > Idx 1924 ? 
replaceNode(Builder, Node->Operands[Idx]) 1925 : nullptr; 1926 }; 1927 1928 Value *ReplacementNode; 1929 switch (Node->Operation) { 1930 case ComplexDeinterleavingOperation::CAdd: 1931 case ComplexDeinterleavingOperation::CMulPartial: 1932 case ComplexDeinterleavingOperation::Symmetric: { 1933 Value *Input0 = ReplaceOperandIfExist(Node, 0); 1934 Value *Input1 = ReplaceOperandIfExist(Node, 1); 1935 Value *Accumulator = ReplaceOperandIfExist(Node, 2); 1936 assert(!Input1 || (Input0->getType() == Input1->getType() && 1937 "Node inputs need to be of the same type")); 1938 assert(!Accumulator || 1939 (Input0->getType() == Accumulator->getType() && 1940 "Accumulator and input need to be of the same type")); 1941 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 1942 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, 1943 Input0, Input1); 1944 else 1945 ReplacementNode = TL->createComplexDeinterleavingIR( 1946 Builder, Node->Operation, Node->Rotation, Input0, Input1, 1947 Accumulator); 1948 break; 1949 } 1950 case ComplexDeinterleavingOperation::Deinterleave: 1951 llvm_unreachable("Deinterleave node should already have ReplacementNode"); 1952 break; 1953 case ComplexDeinterleavingOperation::Splat: { 1954 auto *NewTy = VectorType::getDoubleElementsVectorType( 1955 cast<VectorType>(Node->Real->getType())); 1956 auto *R = dyn_cast<Instruction>(Node->Real); 1957 auto *I = dyn_cast<Instruction>(Node->Imag); 1958 if (R && I) { 1959 // Splats that are not constant are interleaved where they are located 1960 Instruction *InsertPoint = (I->comesBefore(R) ? 
R : I)->getNextNode(); 1961 IRBuilder<> IRB(InsertPoint); 1962 ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2, 1963 NewTy, {Node->Real, Node->Imag}); 1964 } else { 1965 ReplacementNode = Builder.CreateIntrinsic( 1966 Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag}); 1967 } 1968 break; 1969 } 1970 case ComplexDeinterleavingOperation::ReductionPHI: { 1971 // If Operation is ReductionPHI, a new empty PHINode is created. 1972 // It is filled later when the ReductionOperation is processed. 1973 auto *VTy = cast<VectorType>(Node->Real->getType()); 1974 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1975 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt()); 1976 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; 1977 ReplacementNode = NewPHI; 1978 break; 1979 } 1980 case ComplexDeinterleavingOperation::ReductionOperation: 1981 ReplacementNode = replaceNode(Builder, Node->Operands[0]); 1982 processReductionOperation(ReplacementNode, Node); 1983 break; 1984 case ComplexDeinterleavingOperation::ReductionSelect: { 1985 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0); 1986 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0); 1987 auto *A = replaceNode(Builder, Node->Operands[0]); 1988 auto *B = replaceNode(Builder, Node->Operands[1]); 1989 auto *NewMaskTy = VectorType::getDoubleElementsVectorType( 1990 cast<VectorType>(MaskReal->getType())); 1991 auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, 1992 NewMaskTy, {MaskReal, MaskImag}); 1993 ReplacementNode = Builder.CreateSelect(NewMask, A, B); 1994 break; 1995 } 1996 } 1997 1998 assert(ReplacementNode && "Target failed to create Intrinsic call."); 1999 NumComplexTransformations += 1; 2000 Node->ReplacementNode = ReplacementNode; 2001 return ReplacementNode; 2002 } 2003 2004 void ComplexDeinterleavingGraph::processReductionOperation( 2005 Value *OperationReplacement, RawNodePtr Node) { 2006 auto *Real = 
cast<Instruction>(Node->Real); 2007 auto *Imag = cast<Instruction>(Node->Imag); 2008 auto *OldPHIReal = ReductionInfo[Real].first; 2009 auto *OldPHIImag = ReductionInfo[Imag].first; 2010 auto *NewPHI = OldToNewPHI[OldPHIReal]; 2011 2012 auto *VTy = cast<VectorType>(Real->getType()); 2013 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 2014 2015 // We have to interleave initial origin values coming from IncomingBlock 2016 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); 2017 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); 2018 2019 IRBuilder<> Builder(Incoming->getTerminator()); 2020 auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy, 2021 {InitReal, InitImag}); 2022 2023 NewPHI->addIncoming(NewInit, Incoming); 2024 NewPHI->addIncoming(OperationReplacement, BackEdge); 2025 2026 // Deinterleave complex vector outside of loop so that it can be finally 2027 // reduced 2028 auto *FinalReductionReal = ReductionInfo[Real].second; 2029 auto *FinalReductionImag = ReductionInfo[Imag].second; 2030 2031 Builder.SetInsertPoint( 2032 &*FinalReductionReal->getParent()->getFirstInsertionPt()); 2033 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2, 2034 OperationReplacement->getType(), 2035 OperationReplacement); 2036 2037 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); 2038 FinalReductionReal->replaceUsesOfWith(Real, NewReal); 2039 2040 Builder.SetInsertPoint(FinalReductionImag); 2041 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); 2042 FinalReductionImag->replaceUsesOfWith(Imag, NewImag); 2043 } 2044 2045 void ComplexDeinterleavingGraph::replaceNodes() { 2046 SmallVector<Instruction *, 16> DeadInstrRoots; 2047 for (auto *RootInstruction : OrderedRoots) { 2048 // Check if this potential root went through check process and we can 2049 // deinterleave it 2050 if (!RootToNode.count(RootInstruction)) 2051 continue; 2052 2053 IRBuilder<> 
Builder(RootInstruction); 2054 auto RootNode = RootToNode[RootInstruction]; 2055 Value *R = replaceNode(Builder, RootNode.get()); 2056 2057 if (RootNode->Operation == 2058 ComplexDeinterleavingOperation::ReductionOperation) { 2059 auto *RootReal = cast<Instruction>(RootNode->Real); 2060 auto *RootImag = cast<Instruction>(RootNode->Imag); 2061 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); 2062 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); 2063 DeadInstrRoots.push_back(cast<Instruction>(RootReal)); 2064 DeadInstrRoots.push_back(cast<Instruction>(RootImag)); 2065 } else { 2066 assert(R && "Unable to find replacement for RootInstruction"); 2067 DeadInstrRoots.push_back(RootInstruction); 2068 RootInstruction->replaceAllUsesWith(R); 2069 } 2070 } 2071 2072 for (auto *I : DeadInstrRoots) 2073 RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 2074 } 2075