//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Validation of each node is expected to be done upon creation, and any
// validation errors should halt traversal and prevent further graph
// construction.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
54 // 55 //===----------------------------------------------------------------------===// 56 57 #include "llvm/CodeGen/ComplexDeinterleavingPass.h" 58 #include "llvm/ADT/Statistic.h" 59 #include "llvm/Analysis/TargetLibraryInfo.h" 60 #include "llvm/Analysis/TargetTransformInfo.h" 61 #include "llvm/CodeGen/TargetLowering.h" 62 #include "llvm/CodeGen/TargetPassConfig.h" 63 #include "llvm/CodeGen/TargetSubtargetInfo.h" 64 #include "llvm/IR/IRBuilder.h" 65 #include "llvm/InitializePasses.h" 66 #include "llvm/Target/TargetMachine.h" 67 #include "llvm/Transforms/Utils/Local.h" 68 #include <algorithm> 69 70 using namespace llvm; 71 using namespace PatternMatch; 72 73 #define DEBUG_TYPE "complex-deinterleaving" 74 75 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); 76 77 static cl::opt<bool> ComplexDeinterleavingEnabled( 78 "enable-complex-deinterleaving", 79 cl::desc("Enable generation of complex instructions"), cl::init(true), 80 cl::Hidden); 81 82 /// Checks the given mask, and determines whether said mask is interleaving. 83 /// 84 /// To be interleaving, a mask must alternate between `i` and `i + (Length / 85 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a 86 /// 4x vector interleaving mask would be <0, 2, 1, 3>). 87 static bool isInterleavingMask(ArrayRef<int> Mask); 88 89 /// Checks the given mask, and determines whether said mask is deinterleaving. 90 /// 91 /// To be deinterleaving, a mask must increment in steps of 2, and either start 92 /// with 0 or 1. 93 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or 94 /// <1, 3, 5, 7>). 
95 static bool isDeinterleavingMask(ArrayRef<int> Mask); 96 97 namespace { 98 99 class ComplexDeinterleavingLegacyPass : public FunctionPass { 100 public: 101 static char ID; 102 103 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 104 : FunctionPass(ID), TM(TM) { 105 initializeComplexDeinterleavingLegacyPassPass( 106 *PassRegistry::getPassRegistry()); 107 } 108 109 StringRef getPassName() const override { 110 return "Complex Deinterleaving Pass"; 111 } 112 113 bool runOnFunction(Function &F) override; 114 void getAnalysisUsage(AnalysisUsage &AU) const override { 115 AU.addRequired<TargetLibraryInfoWrapperPass>(); 116 AU.setPreservesCFG(); 117 } 118 119 private: 120 const TargetMachine *TM; 121 }; 122 123 class ComplexDeinterleavingGraph; 124 struct ComplexDeinterleavingCompositeNode { 125 126 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 127 Instruction *R, Instruction *I) 128 : Operation(Op), Real(R), Imag(I) {} 129 130 private: 131 friend class ComplexDeinterleavingGraph; 132 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 133 using RawNodePtr = ComplexDeinterleavingCompositeNode *; 134 135 public: 136 ComplexDeinterleavingOperation Operation; 137 Instruction *Real; 138 Instruction *Imag; 139 140 // Instructions that should only exist within this node, there should be no 141 // users of these instructions outside the node. An example of these would be 142 // the multiply instructions of a partial multiply operation. 
143 SmallVector<Instruction *> InternalInstructions; 144 ComplexDeinterleavingRotation Rotation; 145 SmallVector<RawNodePtr> Operands; 146 Value *ReplacementNode = nullptr; 147 148 void addInstruction(Instruction *I) { InternalInstructions.push_back(I); } 149 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } 150 151 bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions); 152 153 void dump() { dump(dbgs()); } 154 void dump(raw_ostream &OS) { 155 auto PrintValue = [&](Value *V) { 156 if (V) { 157 OS << "\""; 158 V->print(OS, true); 159 OS << "\"\n"; 160 } else 161 OS << "nullptr\n"; 162 }; 163 auto PrintNodeRef = [&](RawNodePtr Ptr) { 164 if (Ptr) 165 OS << Ptr << "\n"; 166 else 167 OS << "nullptr\n"; 168 }; 169 170 OS << "- CompositeNode: " << this << "\n"; 171 OS << " Real: "; 172 PrintValue(Real); 173 OS << " Imag: "; 174 PrintValue(Imag); 175 OS << " ReplacementNode: "; 176 PrintValue(ReplacementNode); 177 OS << " Operation: " << (int)Operation << "\n"; 178 OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 179 OS << " Operands: \n"; 180 for (const auto &Op : Operands) { 181 OS << " - "; 182 PrintNodeRef(Op); 183 } 184 OS << " InternalInstructions:\n"; 185 for (const auto &I : InternalInstructions) { 186 OS << " - \""; 187 I->print(OS, true); 188 OS << "\"\n"; 189 } 190 } 191 }; 192 193 class ComplexDeinterleavingGraph { 194 public: 195 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 196 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 197 explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {} 198 199 private: 200 const TargetLowering *TL; 201 Instruction *RootValue; 202 NodePtr RootNode; 203 SmallVector<NodePtr> CompositeNodes; 204 SmallPtrSet<Instruction *, 16> AllInstructions; 205 206 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 207 Instruction *R, Instruction *I) { 208 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 209 I); 
210 } 211 212 NodePtr submitCompositeNode(NodePtr Node) { 213 CompositeNodes.push_back(Node); 214 AllInstructions.insert(Node->Real); 215 AllInstructions.insert(Node->Imag); 216 for (auto *I : Node->InternalInstructions) 217 AllInstructions.insert(I); 218 return Node; 219 } 220 221 NodePtr getContainingComposite(Value *R, Value *I) { 222 for (const auto &CN : CompositeNodes) { 223 if (CN->Real == R && CN->Imag == I) 224 return CN; 225 } 226 return nullptr; 227 } 228 229 /// Identifies a complex partial multiply pattern and its rotation, based on 230 /// the following patterns 231 /// 232 /// 0: r: cr + ar * br 233 /// i: ci + ar * bi 234 /// 90: r: cr - ai * bi 235 /// i: ci + ai * br 236 /// 180: r: cr - ar * br 237 /// i: ci - ar * bi 238 /// 270: r: cr + ai * bi 239 /// i: ci - ai * br 240 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 241 242 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that 243 /// is partially known from identifyPartialMul, filling in the other half of 244 /// the complex pair. 245 NodePtr identifyNodeWithImplicitAdd( 246 Instruction *I, Instruction *J, 247 std::pair<Instruction *, Instruction *> &CommonOperandI); 248 249 /// Identifies a complex add pattern and its rotation, based on the following 250 /// patterns. 251 /// 252 /// 90: r: ar - bi 253 /// i: ai + br 254 /// 270: r: ar + bi 255 /// i: ai - br 256 NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 257 258 NodePtr identifyNode(Instruction *I, Instruction *J); 259 260 Value *replaceNode(RawNodePtr Node); 261 262 public: 263 void dump() { dump(dbgs()); } 264 void dump(raw_ostream &OS) { 265 for (const auto &Node : CompositeNodes) 266 Node->dump(OS); 267 } 268 269 /// Returns false if the deinterleaving operation should be cancelled for the 270 /// current graph. 271 bool identifyNodes(Instruction *RootI); 272 273 /// Perform the actual replacement of the underlying instruction graph. 
274 /// Returns false if the deinterleaving operation should be cancelled for the 275 /// current graph. 276 void replaceNodes(); 277 }; 278 279 class ComplexDeinterleaving { 280 public: 281 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 282 : TL(tl), TLI(tli) {} 283 bool runOnFunction(Function &F); 284 285 private: 286 bool evaluateBasicBlock(BasicBlock *B); 287 288 const TargetLowering *TL = nullptr; 289 const TargetLibraryInfo *TLI = nullptr; 290 }; 291 292 } // namespace 293 294 char ComplexDeinterleavingLegacyPass::ID = 0; 295 296 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 297 "Complex Deinterleaving", false, false) 298 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 299 "Complex Deinterleaving", false, false) 300 301 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 302 FunctionAnalysisManager &AM) { 303 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 304 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 305 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 306 return PreservedAnalyses::all(); 307 308 PreservedAnalyses PA; 309 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 310 return PA; 311 } 312 313 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 314 return new ComplexDeinterleavingLegacyPass(TM); 315 } 316 317 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 318 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 319 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 320 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 321 } 322 323 bool ComplexDeinterleaving::runOnFunction(Function &F) { 324 if (!ComplexDeinterleavingEnabled) { 325 LLVM_DEBUG( 326 dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 327 return false; 328 } 329 330 if (!TL->isComplexDeinterleavingSupported()) { 331 LLVM_DEBUG( 332 dbgs() << "Complex deinterleaving has 
been disabled, target does " 333 "not support lowering of complex number operations.\n"); 334 return false; 335 } 336 337 bool Changed = false; 338 for (auto &B : F) 339 Changed |= evaluateBasicBlock(&B); 340 341 return Changed; 342 } 343 344 static bool isInterleavingMask(ArrayRef<int> Mask) { 345 // If the size is not even, it's not an interleaving mask 346 if ((Mask.size() & 1)) 347 return false; 348 349 int HalfNumElements = Mask.size() / 2; 350 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 351 int MaskIdx = Idx * 2; 352 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 353 return false; 354 } 355 356 return true; 357 } 358 359 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 360 int Offset = Mask[0]; 361 int HalfNumElements = Mask.size() / 2; 362 363 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 364 if (Mask[Idx] != (Idx * 2) + Offset) 365 return false; 366 } 367 368 return true; 369 } 370 371 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 372 bool Changed = false; 373 374 SmallVector<Instruction *> DeadInstrRoots; 375 376 for (auto &I : *B) { 377 auto *SVI = dyn_cast<ShuffleVectorInst>(&I); 378 if (!SVI) 379 continue; 380 381 // Look for a shufflevector that takes separate vectors of the real and 382 // imaginary components and recombines them into a single vector. 
383 if (!isInterleavingMask(SVI->getShuffleMask())) 384 continue; 385 386 ComplexDeinterleavingGraph Graph(TL); 387 if (!Graph.identifyNodes(SVI)) 388 continue; 389 390 Graph.replaceNodes(); 391 DeadInstrRoots.push_back(SVI); 392 Changed = true; 393 } 394 395 for (const auto &I : DeadInstrRoots) { 396 if (!I || I->getParent() == nullptr) 397 continue; 398 llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 399 } 400 401 return Changed; 402 } 403 404 ComplexDeinterleavingGraph::NodePtr 405 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 406 Instruction *Real, Instruction *Imag, 407 std::pair<Instruction *, Instruction *> &PartialMatch) { 408 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 409 << "\n"); 410 411 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 412 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 413 return nullptr; 414 } 415 416 if (Real->getOpcode() != Instruction::FMul || 417 Imag->getOpcode() != Instruction::FMul) { 418 LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n"); 419 return nullptr; 420 } 421 422 Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0)); 423 Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1)); 424 Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0)); 425 Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1)); 426 if (!R0 || !R1 || !I0 || !I1) { 427 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); 428 return nullptr; 429 } 430 431 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 432 // rotations and use the operand. 
433 unsigned Negs = 0; 434 SmallVector<Instruction *> FNegs; 435 if (R0->getOpcode() == Instruction::FNeg || 436 R1->getOpcode() == Instruction::FNeg) { 437 Negs |= 1; 438 if (R0->getOpcode() == Instruction::FNeg) { 439 FNegs.push_back(R0); 440 R0 = dyn_cast<Instruction>(R0->getOperand(0)); 441 } else { 442 FNegs.push_back(R1); 443 R1 = dyn_cast<Instruction>(R1->getOperand(0)); 444 } 445 if (!R0 || !R1) 446 return nullptr; 447 } 448 if (I0->getOpcode() == Instruction::FNeg || 449 I1->getOpcode() == Instruction::FNeg) { 450 Negs |= 2; 451 Negs ^= 1; 452 if (I0->getOpcode() == Instruction::FNeg) { 453 FNegs.push_back(I0); 454 I0 = dyn_cast<Instruction>(I0->getOperand(0)); 455 } else { 456 FNegs.push_back(I1); 457 I1 = dyn_cast<Instruction>(I1->getOperand(0)); 458 } 459 if (!I0 || !I1) 460 return nullptr; 461 } 462 463 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 464 465 Instruction *CommonOperand; 466 Instruction *UncommonRealOp; 467 Instruction *UncommonImagOp; 468 469 if (R0 == I0 || R0 == I1) { 470 CommonOperand = R0; 471 UncommonRealOp = R1; 472 } else if (R1 == I0 || R1 == I1) { 473 CommonOperand = R1; 474 UncommonRealOp = R0; 475 } else { 476 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 477 return nullptr; 478 } 479 480 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 481 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 482 Rotation == ComplexDeinterleavingRotation::Rotation_270) 483 std::swap(UncommonRealOp, UncommonImagOp); 484 485 // Between identifyPartialMul and here we need to have found a complete valid 486 // pair from the CommonOperand of each part. 
487 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 488 Rotation == ComplexDeinterleavingRotation::Rotation_180) 489 PartialMatch.first = CommonOperand; 490 else 491 PartialMatch.second = CommonOperand; 492 493 if (!PartialMatch.first || !PartialMatch.second) { 494 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 495 return nullptr; 496 } 497 498 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 499 if (!CommonNode) { 500 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 501 return nullptr; 502 } 503 504 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 505 if (!UncommonNode) { 506 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 507 return nullptr; 508 } 509 510 NodePtr Node = prepareCompositeNode( 511 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 512 Node->Rotation = Rotation; 513 Node->addOperand(CommonNode); 514 Node->addOperand(UncommonNode); 515 Node->InternalInstructions.append(FNegs); 516 return submitCompositeNode(Node); 517 } 518 519 ComplexDeinterleavingGraph::NodePtr 520 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 521 Instruction *Imag) { 522 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 523 << "\n"); 524 // Determine rotation 525 ComplexDeinterleavingRotation Rotation; 526 if (Real->getOpcode() == Instruction::FAdd && 527 Imag->getOpcode() == Instruction::FAdd) 528 Rotation = ComplexDeinterleavingRotation::Rotation_0; 529 else if (Real->getOpcode() == Instruction::FSub && 530 Imag->getOpcode() == Instruction::FAdd) 531 Rotation = ComplexDeinterleavingRotation::Rotation_90; 532 else if (Real->getOpcode() == Instruction::FSub && 533 Imag->getOpcode() == Instruction::FSub) 534 Rotation = ComplexDeinterleavingRotation::Rotation_180; 535 else if (Real->getOpcode() == Instruction::FAdd && 536 Imag->getOpcode() == Instruction::FSub) 537 Rotation = ComplexDeinterleavingRotation::Rotation_270; 538 else { 539 LLVM_DEBUG(dbgs() << 
" - Unhandled rotation.\n"); 540 return nullptr; 541 } 542 543 if (!Real->getFastMathFlags().allowContract() || 544 !Imag->getFastMathFlags().allowContract()) { 545 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 546 return nullptr; 547 } 548 549 Value *CR = Real->getOperand(0); 550 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 551 if (!RealMulI) 552 return nullptr; 553 Value *CI = Imag->getOperand(0); 554 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 555 if (!ImagMulI) 556 return nullptr; 557 558 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 559 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 560 return nullptr; 561 } 562 563 Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0)); 564 Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1)); 565 Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0)); 566 Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1)); 567 if (!R0 || !R1 || !I0 || !I1) { 568 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); 569 return nullptr; 570 } 571 572 Instruction *CommonOperand; 573 Instruction *UncommonRealOp; 574 Instruction *UncommonImagOp; 575 576 if (R0 == I0 || R0 == I1) { 577 CommonOperand = R0; 578 UncommonRealOp = R1; 579 } else if (R1 == I0 || R1 == I1) { 580 CommonOperand = R1; 581 UncommonRealOp = R0; 582 } else { 583 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 584 return nullptr; 585 } 586 587 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 588 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 589 Rotation == ComplexDeinterleavingRotation::Rotation_270) 590 std::swap(UncommonRealOp, UncommonImagOp); 591 592 std::pair<Instruction *, Instruction *> PartialMatch( 593 (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 594 Rotation == ComplexDeinterleavingRotation::Rotation_180) 595 ? 
CommonOperand 596 : nullptr, 597 (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 598 Rotation == ComplexDeinterleavingRotation::Rotation_270) 599 ? CommonOperand 600 : nullptr); 601 NodePtr CNode = identifyNodeWithImplicitAdd( 602 cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch); 603 if (!CNode) { 604 LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 605 return nullptr; 606 } 607 608 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 609 if (!UncommonRes) { 610 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 611 return nullptr; 612 } 613 614 assert(PartialMatch.first && PartialMatch.second); 615 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 616 if (!CommonRes) { 617 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 618 return nullptr; 619 } 620 621 NodePtr Node = prepareCompositeNode( 622 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 623 Node->addInstruction(RealMulI); 624 Node->addInstruction(ImagMulI); 625 Node->Rotation = Rotation; 626 Node->addOperand(CommonRes); 627 Node->addOperand(UncommonRes); 628 Node->addOperand(CNode); 629 return submitCompositeNode(Node); 630 } 631 632 ComplexDeinterleavingGraph::NodePtr 633 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 634 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 635 636 // Determine rotation 637 ComplexDeinterleavingRotation Rotation; 638 if ((Real->getOpcode() == Instruction::FSub && 639 Imag->getOpcode() == Instruction::FAdd) || 640 (Real->getOpcode() == Instruction::Sub && 641 Imag->getOpcode() == Instruction::Add)) 642 Rotation = ComplexDeinterleavingRotation::Rotation_90; 643 else if ((Real->getOpcode() == Instruction::FAdd && 644 Imag->getOpcode() == Instruction::FSub) || 645 (Real->getOpcode() == Instruction::Add && 646 Imag->getOpcode() == Instruction::Sub)) 647 Rotation = ComplexDeinterleavingRotation::Rotation_270; 648 else { 649 LLVM_DEBUG(dbgs() << " 
- Unhandled case, rotation is not assigned.\n"); 650 return nullptr; 651 } 652 653 auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 654 auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 655 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 656 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 657 658 if (!AR || !AI || !BR || !BI) { 659 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 660 return nullptr; 661 } 662 663 NodePtr ResA = identifyNode(AR, AI); 664 if (!ResA) { 665 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 666 return nullptr; 667 } 668 NodePtr ResB = identifyNode(BR, BI); 669 if (!ResB) { 670 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 671 return nullptr; 672 } 673 674 NodePtr Node = 675 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 676 Node->Rotation = Rotation; 677 Node->addOperand(ResA); 678 Node->addOperand(ResB); 679 return submitCompositeNode(Node); 680 } 681 682 static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 683 unsigned OpcA = A->getOpcode(); 684 unsigned OpcB = B->getOpcode(); 685 686 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 687 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 688 (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 689 (OpcA == Instruction::Add && OpcB == Instruction::Sub); 690 } 691 692 static bool isInstructionPairMul(Instruction *A, Instruction *B) { 693 auto Pattern = 694 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 695 696 return match(A, Pattern) && match(B, Pattern); 697 } 698 699 ComplexDeinterleavingGraph::NodePtr 700 ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { 701 LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n"); 702 if (NodePtr CN = getContainingComposite(Real, Imag)) { 703 LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 704 return CN; 705 
} 706 707 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 708 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 709 if (RealShuffle && ImagShuffle) { 710 Value *RealOp1 = RealShuffle->getOperand(1); 711 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 712 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 713 return nullptr; 714 } 715 Value *ImagOp1 = ImagShuffle->getOperand(1); 716 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 717 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 718 return nullptr; 719 } 720 721 Value *RealOp0 = RealShuffle->getOperand(0); 722 Value *ImagOp0 = ImagShuffle->getOperand(0); 723 724 if (RealOp0 != ImagOp0) { 725 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 726 return nullptr; 727 } 728 729 ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 730 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 731 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 732 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 733 return nullptr; 734 } 735 736 if (RealMask[0] != 0 || ImagMask[0] != 1) { 737 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 738 return nullptr; 739 } 740 741 // Type checking, the shuffle type should be a vector type of the same 742 // scalar type, but half the size 743 auto CheckType = [&](ShuffleVectorInst *Shuffle) { 744 Value *Op = Shuffle->getOperand(0); 745 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 746 auto *OpTy = cast<FixedVectorType>(Op->getType()); 747 748 if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 749 return false; 750 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 751 return false; 752 753 return true; 754 }; 755 756 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 757 if (!CheckType(Shuffle)) 758 return false; 759 760 ArrayRef<int> Mask = Shuffle->getShuffleMask(); 761 int 
Last = *Mask.rbegin(); 762 763 Value *Op = Shuffle->getOperand(0); 764 auto *OpTy = cast<FixedVectorType>(Op->getType()); 765 int NumElements = OpTy->getNumElements(); 766 767 // Ensure that the deinterleaving shuffle only pulls from the first 768 // shuffle operand. 769 return Last < NumElements; 770 }; 771 772 if (RealShuffle->getType() != ImagShuffle->getType()) { 773 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 774 return nullptr; 775 } 776 if (!CheckDeinterleavingShuffle(RealShuffle)) { 777 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 778 return nullptr; 779 } 780 if (!CheckDeinterleavingShuffle(ImagShuffle)) { 781 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 782 return nullptr; 783 } 784 785 NodePtr PlaceholderNode = 786 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle, 787 RealShuffle, ImagShuffle); 788 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 789 return submitCompositeNode(PlaceholderNode); 790 } 791 if (RealShuffle || ImagShuffle) 792 return nullptr; 793 794 auto *VTy = cast<FixedVectorType>(Real->getType()); 795 auto *NewVTy = 796 FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2); 797 798 if (TL->isComplexDeinterleavingOperationSupported( 799 ComplexDeinterleavingOperation::CMulPartial, NewVTy) && 800 isInstructionPairMul(Real, Imag)) { 801 return identifyPartialMul(Real, Imag); 802 } 803 804 if (TL->isComplexDeinterleavingOperationSupported( 805 ComplexDeinterleavingOperation::CAdd, NewVTy) && 806 isInstructionPairAdd(Real, Imag)) { 807 return identifyAdd(Real, Imag); 808 } 809 810 return nullptr; 811 } 812 813 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 814 Instruction *Real; 815 Instruction *Imag; 816 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 817 return false; 818 819 RootValue = RootI; 820 AllInstructions.insert(RootI); 821 RootNode = identifyNode(Real, Imag); 822 823 LLVM_DEBUG({ 824 Function *F = 
RootI->getFunction(); 825 BasicBlock *B = RootI->getParent(); 826 dbgs() << "Complex deinterleaving graph for " << F->getName() 827 << "::" << B->getName() << ".\n"; 828 dump(dbgs()); 829 dbgs() << "\n"; 830 }); 831 832 // Check all instructions have internal uses 833 for (const auto &Node : CompositeNodes) { 834 if (!Node->hasAllInternalUses(AllInstructions)) { 835 LLVM_DEBUG(dbgs() << " - Invalid internal uses\n"); 836 return false; 837 } 838 } 839 return RootNode != nullptr; 840 } 841 842 Value *ComplexDeinterleavingGraph::replaceNode( 843 ComplexDeinterleavingGraph::RawNodePtr Node) { 844 if (Node->ReplacementNode) 845 return Node->ReplacementNode; 846 847 Value *Input0 = replaceNode(Node->Operands[0]); 848 Value *Input1 = replaceNode(Node->Operands[1]); 849 Value *Accumulator = 850 Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr; 851 852 assert(Input0->getType() == Input1->getType() && 853 "Node inputs need to be of the same type"); 854 855 Node->ReplacementNode = TL->createComplexDeinterleavingIR( 856 Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); 857 858 assert(Node->ReplacementNode && "Target failed to create Intrinsic call."); 859 NumComplexTransformations += 1; 860 return Node->ReplacementNode; 861 } 862 863 void ComplexDeinterleavingGraph::replaceNodes() { 864 Value *R = replaceNode(RootNode.get()); 865 assert(R && "Unable to find replacement for RootValue"); 866 RootValue->replaceAllUsesWith(R); 867 } 868 869 bool ComplexDeinterleavingCompositeNode::hasAllInternalUses( 870 SmallPtrSet<Instruction *, 16> &AllInstructions) { 871 if (Operation == ComplexDeinterleavingOperation::Shuffle) 872 return true; 873 874 for (auto *User : Real->users()) { 875 if (!AllInstructions.contains(cast<Instruction>(User))) 876 return false; 877 } 878 for (auto *User : Imag->users()) { 879 if (!AllInstructions.contains(cast<Instruction>(User))) 880 return false; 881 } 882 for (auto *I : InternalInstructions) { 883 for (auto 
*User : I->users()) { 884 if (!AllInstructions.contains(cast<Instruction>(User))) 885 return false; 886 } 887 } 888 return true; 889 } 890