1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on shufflevector operations, vector interleaving and
// deinterleaving can also be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can only be represented
// by these intrinsics, whereas fixed-width vectors are recognized in both the
// shufflevector form and the intrinsic form.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
53 // the graph. ReplacementValue is the transformed Value* that has been emitted
54 // to the IR.
55 //
56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue
58 // should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/MapVector.h"
64 #include "llvm/ADT/Statistic.h"
65 #include "llvm/Analysis/TargetLibraryInfo.h"
66 #include "llvm/Analysis/TargetTransformInfo.h"
67 #include "llvm/CodeGen/TargetLowering.h"
68 #include "llvm/CodeGen/TargetPassConfig.h"
69 #include "llvm/CodeGen/TargetSubtargetInfo.h"
70 #include "llvm/IR/IRBuilder.h"
71 #include "llvm/IR/PatternMatch.h"
72 #include "llvm/InitializePasses.h"
73 #include "llvm/Target/TargetMachine.h"
74 #include "llvm/Transforms/Utils/Local.h"
75 #include <algorithm>
76
77 using namespace llvm;
78 using namespace PatternMatch;
79
80 #define DEBUG_TYPE "complex-deinterleaving"
81
82 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83
84 static cl::opt<bool> ComplexDeinterleavingEnabled(
85 "enable-complex-deinterleaving",
86 cl::desc("Enable generation of complex instructions"), cl::init(true),
87 cl::Hidden);
88
89 /// Checks the given mask, and determines whether said mask is interleaving.
90 ///
91 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
92 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
94 static bool isInterleavingMask(ArrayRef<int> Mask);
95
96 /// Checks the given mask, and determines whether said mask is deinterleaving.
97 ///
98 /// To be deinterleaving, a mask must increment in steps of 2, and either start
99 /// with 0 or 1.
100 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101 /// <1, 3, 5, 7>).
102 static bool isDeinterleavingMask(ArrayRef<int> Mask);
103
104 /// Returns true if the operation is a negation of V, and it works for both
105 /// integers and floats.
106 static bool isNeg(Value *V);
107
108 /// Returns the operand for negation operation.
109 static Value *getNegOperand(Value *V);
110
111 namespace {
112
113 class ComplexDeinterleavingLegacyPass : public FunctionPass {
114 public:
115 static char ID;
116
ComplexDeinterleavingLegacyPass(const TargetMachine * TM=nullptr)117 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118 : FunctionPass(ID), TM(TM) {
119 initializeComplexDeinterleavingLegacyPassPass(
120 *PassRegistry::getPassRegistry());
121 }
122
getPassName() const123 StringRef getPassName() const override {
124 return "Complex Deinterleaving Pass";
125 }
126
127 bool runOnFunction(Function &F) override;
getAnalysisUsage(AnalysisUsage & AU) const128 void getAnalysisUsage(AnalysisUsage &AU) const override {
129 AU.addRequired<TargetLibraryInfoWrapperPass>();
130 AU.setPreservesCFG();
131 }
132
133 private:
134 const TargetMachine *TM;
135 };
136
137 class ComplexDeinterleavingGraph;
138 struct ComplexDeinterleavingCompositeNode {
139
ComplexDeinterleavingCompositeNode__anon87c1be940111::ComplexDeinterleavingCompositeNode140 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
141 Value *R, Value *I)
142 : Operation(Op), Real(R), Imag(I) {}
143
144 private:
145 friend class ComplexDeinterleavingGraph;
146 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148
149 public:
150 ComplexDeinterleavingOperation Operation;
151 Value *Real;
152 Value *Imag;
153
154 // This two members are required exclusively for generating
155 // ComplexDeinterleavingOperation::Symmetric operations.
156 unsigned Opcode;
157 std::optional<FastMathFlags> Flags;
158
159 ComplexDeinterleavingRotation Rotation =
160 ComplexDeinterleavingRotation::Rotation_0;
161 SmallVector<RawNodePtr> Operands;
162 Value *ReplacementNode = nullptr;
163
addOperand__anon87c1be940111::ComplexDeinterleavingCompositeNode164 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165
dump__anon87c1be940111::ComplexDeinterleavingCompositeNode166 void dump() { dump(dbgs()); }
dump__anon87c1be940111::ComplexDeinterleavingCompositeNode167 void dump(raw_ostream &OS) {
168 auto PrintValue = [&](Value *V) {
169 if (V) {
170 OS << "\"";
171 V->print(OS, true);
172 OS << "\"\n";
173 } else
174 OS << "nullptr\n";
175 };
176 auto PrintNodeRef = [&](RawNodePtr Ptr) {
177 if (Ptr)
178 OS << Ptr << "\n";
179 else
180 OS << "nullptr\n";
181 };
182
183 OS << "- CompositeNode: " << this << "\n";
184 OS << " Real: ";
185 PrintValue(Real);
186 OS << " Imag: ";
187 PrintValue(Imag);
188 OS << " ReplacementNode: ";
189 PrintValue(ReplacementNode);
190 OS << " Operation: " << (int)Operation << "\n";
191 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
192 OS << " Operands: \n";
193 for (const auto &Op : Operands) {
194 OS << " - ";
195 PrintNodeRef(Op);
196 }
197 }
198 };
199
200 class ComplexDeinterleavingGraph {
201 public:
202 struct Product {
203 Value *Multiplier;
204 Value *Multiplicand;
205 bool IsPositive;
206 };
207
208 using Addend = std::pair<Value *, bool>;
209 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
210 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
211
212 // Helper struct for holding info about potential partial multiplication
213 // candidates
214 struct PartialMulCandidate {
215 Value *Common;
216 NodePtr Node;
217 unsigned RealIdx;
218 unsigned ImagIdx;
219 bool IsNodeInverted;
220 };
221
ComplexDeinterleavingGraph(const TargetLowering * TL,const TargetLibraryInfo * TLI)222 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
223 const TargetLibraryInfo *TLI)
224 : TL(TL), TLI(TLI) {}
225
226 private:
227 const TargetLowering *TL = nullptr;
228 const TargetLibraryInfo *TLI = nullptr;
229 SmallVector<NodePtr> CompositeNodes;
230 DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
231
232 SmallPtrSet<Instruction *, 16> FinalInstructions;
233
234 /// Root instructions are instructions from which complex computation starts
235 std::map<Instruction *, NodePtr> RootToNode;
236
237 /// Topologically sorted root instructions
238 SmallVector<Instruction *, 1> OrderedRoots;
239
240 /// When examining a basic block for complex deinterleaving, if it is a simple
241 /// one-block loop, then the only incoming block is 'Incoming' and the
242 /// 'BackEdge' block is the block itself."
243 BasicBlock *BackEdge = nullptr;
244 BasicBlock *Incoming = nullptr;
245
246 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
247 /// %OutsideUser as it is shown in the IR:
248 ///
249 /// vector.body:
250 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
251 /// [ %ReductionOp, %vector.body ]
252 /// ...
253 /// %ReductionOp = fadd i64 ...
254 /// ...
255 /// br i1 %condition, label %vector.body, %middle.block
256 ///
257 /// middle.block:
258 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
259 ///
260 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
261 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
262 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
263
264 /// In the process of detecting a reduction, we consider a pair of
265 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
266 /// traverse the use-tree to detect complex operations. As this is a reduction
267 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
268 /// to the %ReductionOPs that we suspect to be complex.
269 /// RealPHI and ImagPHI are used by the identifyPHINode method.
270 PHINode *RealPHI = nullptr;
271 PHINode *ImagPHI = nullptr;
272
273 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
274 /// detection.
275 bool PHIsFound = false;
276
277 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
278 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
279 /// This mapping is populated during
280 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
281 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
282 /// replacement process.
283 std::map<PHINode *, PHINode *> OldToNewPHI;
284
prepareCompositeNode(ComplexDeinterleavingOperation Operation,Value * R,Value * I)285 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
286 Value *R, Value *I) {
287 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
288 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
289 (R && I)) &&
290 "Reduction related nodes must have Real and Imaginary parts");
291 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
292 I);
293 }
294
submitCompositeNode(NodePtr Node)295 NodePtr submitCompositeNode(NodePtr Node) {
296 CompositeNodes.push_back(Node);
297 if (Node->Real && Node->Imag)
298 CachedResult[{Node->Real, Node->Imag}] = Node;
299 return Node;
300 }
301
302 /// Identifies a complex partial multiply pattern and its rotation, based on
303 /// the following patterns
304 ///
305 /// 0: r: cr + ar * br
306 /// i: ci + ar * bi
307 /// 90: r: cr - ai * bi
308 /// i: ci + ai * br
309 /// 180: r: cr - ar * br
310 /// i: ci - ar * bi
311 /// 270: r: cr + ai * bi
312 /// i: ci - ai * br
313 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
314
315 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316 /// is partially known from identifyPartialMul, filling in the other half of
317 /// the complex pair.
318 NodePtr
319 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
320 std::pair<Value *, Value *> &CommonOperandI);
321
322 /// Identifies a complex add pattern and its rotation, based on the following
323 /// patterns.
324 ///
325 /// 90: r: ar - bi
326 /// i: ai + br
327 /// 270: r: ar + bi
328 /// i: ai - br
329 NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
330 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
331
332 NodePtr identifyNode(Value *R, Value *I);
333
334 /// Determine if a sum of complex numbers can be formed from \p RealAddends
335 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
336 /// Return nullptr if it is not possible to construct a complex number.
337 /// \p Flags are needed to generate symmetric Add and Sub operations.
338 NodePtr identifyAdditions(std::list<Addend> &RealAddends,
339 std::list<Addend> &ImagAddends,
340 std::optional<FastMathFlags> Flags,
341 NodePtr Accumulator);
342
343 /// Extract one addend that have both real and imaginary parts positive.
344 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
345 std::list<Addend> &ImagAddends);
346
347 /// Determine if sum of multiplications of complex numbers can be formed from
348 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
349 /// to it. Return nullptr if it is not possible to construct a complex number.
350 NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
351 std::vector<Product> &ImagMuls,
352 NodePtr Accumulator);
353
354 /// Go through pairs of multiplication (one Real and one Imag) and find all
355 /// possible candidates for partial multiplication and put them into \p
356 /// Candidates. Returns true if all Product has pair with common operand
357 bool collectPartialMuls(const std::vector<Product> &RealMuls,
358 const std::vector<Product> &ImagMuls,
359 std::vector<PartialMulCandidate> &Candidates);
360
361 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
362 /// the order of complex computation operations may be significantly altered,
363 /// and the real and imaginary parts may not be executed in parallel. This
364 /// function takes this into consideration and employs a more general approach
365 /// to identify complex computations. Initially, it gathers all the addends
366 /// and multiplicands and then constructs a complex expression from them.
367 NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
368
369 NodePtr identifyRoot(Instruction *I);
370
371 /// Identifies the Deinterleave operation applied to a vector containing
372 /// complex numbers. There are two ways to represent the Deinterleave
373 /// operation:
374 /// * Using two shufflevectors with even indices for /pReal instruction and
375 /// odd indices for /pImag instructions (only for fixed-width vectors)
376 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
377 /// intrinsic (for both fixed and scalable vectors)
378 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
379
380 /// identifying the operation that represents a complex number repeated in a
381 /// Splat vector. There are two possible types of splats: ConstantExpr with
382 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
383 /// initialization mask with all values set to zero.
384 NodePtr identifySplat(Value *Real, Value *Imag);
385
386 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
387
388 /// Identifies SelectInsts in a loop that has reduction with predication masks
389 /// and/or predicated tail folding
390 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
391
392 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
393
394 /// Complete IR modifications after producing new reduction operation:
395 /// * Populate the PHINode generated for
396 /// ComplexDeinterleavingOperation::ReductionPHI
397 /// * Deinterleave the final value outside of the loop and repurpose original
398 /// reduction users
399 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
400
401 public:
dump()402 void dump() { dump(dbgs()); }
dump(raw_ostream & OS)403 void dump(raw_ostream &OS) {
404 for (const auto &Node : CompositeNodes)
405 Node->dump(OS);
406 }
407
408 /// Returns false if the deinterleaving operation should be cancelled for the
409 /// current graph.
410 bool identifyNodes(Instruction *RootI);
411
412 /// In case \pB is one-block loop, this function seeks potential reductions
413 /// and populates ReductionInfo. Returns true if any reductions were
414 /// identified.
415 bool collectPotentialReductions(BasicBlock *B);
416
417 void identifyReductionNodes();
418
419 /// Check that every instruction, from the roots to the leaves, has internal
420 /// uses.
421 bool checkNodes();
422
423 /// Perform the actual replacement of the underlying instruction graph.
424 void replaceNodes();
425 };
426
427 class ComplexDeinterleaving {
428 public:
ComplexDeinterleaving(const TargetLowering * tl,const TargetLibraryInfo * tli)429 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430 : TL(tl), TLI(tli) {}
431 bool runOnFunction(Function &F);
432
433 private:
434 bool evaluateBasicBlock(BasicBlock *B);
435
436 const TargetLowering *TL = nullptr;
437 const TargetLibraryInfo *TLI = nullptr;
438 };
439
440 } // namespace
441
442 char ComplexDeinterleavingLegacyPass::ID = 0;
443
444 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445 "Complex Deinterleaving", false, false)
446 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
447 "Complex Deinterleaving", false, false)
448
run(Function & F,FunctionAnalysisManager & AM)449 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
450 FunctionAnalysisManager &AM) {
451 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454 return PreservedAnalyses::all();
455
456 PreservedAnalyses PA;
457 PA.preserve<FunctionAnalysisManagerModuleProxy>();
458 return PA;
459 }
460
createComplexDeinterleavingPass(const TargetMachine * TM)461 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
462 return new ComplexDeinterleavingLegacyPass(TM);
463 }
464
runOnFunction(Function & F)465 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469 }
470
runOnFunction(Function & F)471 bool ComplexDeinterleaving::runOnFunction(Function &F) {
472 if (!ComplexDeinterleavingEnabled) {
473 LLVM_DEBUG(
474 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475 return false;
476 }
477
478 if (!TL->isComplexDeinterleavingSupported()) {
479 LLVM_DEBUG(
480 dbgs() << "Complex deinterleaving has been disabled, target does "
481 "not support lowering of complex number operations.\n");
482 return false;
483 }
484
485 bool Changed = false;
486 for (auto &B : F)
487 Changed |= evaluateBasicBlock(&B);
488
489 return Changed;
490 }
491
isInterleavingMask(ArrayRef<int> Mask)492 static bool isInterleavingMask(ArrayRef<int> Mask) {
493 // If the size is not even, it's not an interleaving mask
494 if ((Mask.size() & 1))
495 return false;
496
497 int HalfNumElements = Mask.size() / 2;
498 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499 int MaskIdx = Idx * 2;
500 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501 return false;
502 }
503
504 return true;
505 }
506
isDeinterleavingMask(ArrayRef<int> Mask)507 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
508 int Offset = Mask[0];
509 int HalfNumElements = Mask.size() / 2;
510
511 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512 if (Mask[Idx] != (Idx * 2) + Offset)
513 return false;
514 }
515
516 return true;
517 }
518
isNeg(Value * V)519 bool isNeg(Value *V) {
520 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
521 }
522
getNegOperand(Value * V)523 Value *getNegOperand(Value *V) {
524 assert(isNeg(V));
525 auto *I = cast<Instruction>(V);
526 if (I->getOpcode() == Instruction::FNeg)
527 return I->getOperand(0);
528
529 return I->getOperand(1);
530 }
531
evaluateBasicBlock(BasicBlock * B)532 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
533 ComplexDeinterleavingGraph Graph(TL, TLI);
534 if (Graph.collectPotentialReductions(B))
535 Graph.identifyReductionNodes();
536
537 for (auto &I : *B)
538 Graph.identifyNodes(&I);
539
540 if (Graph.checkNodes()) {
541 Graph.replaceNodes();
542 return true;
543 }
544
545 return false;
546 }
547
548 ComplexDeinterleavingGraph::NodePtr
identifyNodeWithImplicitAdd(Instruction * Real,Instruction * Imag,std::pair<Value *,Value * > & PartialMatch)549 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550 Instruction *Real, Instruction *Imag,
551 std::pair<Value *, Value *> &PartialMatch) {
552 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553 << "\n");
554
555 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
557 return nullptr;
558 }
559
560 if ((Real->getOpcode() != Instruction::FMul &&
561 Real->getOpcode() != Instruction::Mul) ||
562 (Imag->getOpcode() != Instruction::FMul &&
563 Imag->getOpcode() != Instruction::Mul)) {
564 LLVM_DEBUG(
565 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
566 return nullptr;
567 }
568
569 Value *R0 = Real->getOperand(0);
570 Value *R1 = Real->getOperand(1);
571 Value *I0 = Imag->getOperand(0);
572 Value *I1 = Imag->getOperand(1);
573
574 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575 // rotations and use the operand.
576 unsigned Negs = 0;
577 Value *Op;
578 if (match(R0, m_Neg(m_Value(Op)))) {
579 Negs |= 1;
580 R0 = Op;
581 } else if (match(R1, m_Neg(m_Value(Op)))) {
582 Negs |= 1;
583 R1 = Op;
584 }
585
586 if (isNeg(I0)) {
587 Negs |= 2;
588 Negs ^= 1;
589 I0 = Op;
590 } else if (match(I1, m_Neg(m_Value(Op)))) {
591 Negs |= 2;
592 Negs ^= 1;
593 I1 = Op;
594 }
595
596 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
597
598 Value *CommonOperand;
599 Value *UncommonRealOp;
600 Value *UncommonImagOp;
601
602 if (R0 == I0 || R0 == I1) {
603 CommonOperand = R0;
604 UncommonRealOp = R1;
605 } else if (R1 == I0 || R1 == I1) {
606 CommonOperand = R1;
607 UncommonRealOp = R0;
608 } else {
609 LLVM_DEBUG(dbgs() << " - No equal operand\n");
610 return nullptr;
611 }
612
613 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615 Rotation == ComplexDeinterleavingRotation::Rotation_270)
616 std::swap(UncommonRealOp, UncommonImagOp);
617
618 // Between identifyPartialMul and here we need to have found a complete valid
619 // pair from the CommonOperand of each part.
620 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621 Rotation == ComplexDeinterleavingRotation::Rotation_180)
622 PartialMatch.first = CommonOperand;
623 else
624 PartialMatch.second = CommonOperand;
625
626 if (!PartialMatch.first || !PartialMatch.second) {
627 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
628 return nullptr;
629 }
630
631 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632 if (!CommonNode) {
633 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
634 return nullptr;
635 }
636
637 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638 if (!UncommonNode) {
639 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
640 return nullptr;
641 }
642
643 NodePtr Node = prepareCompositeNode(
644 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645 Node->Rotation = Rotation;
646 Node->addOperand(CommonNode);
647 Node->addOperand(UncommonNode);
648 return submitCompositeNode(Node);
649 }
650
651 ComplexDeinterleavingGraph::NodePtr
identifyPartialMul(Instruction * Real,Instruction * Imag)652 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
653 Instruction *Imag) {
654 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
655 << "\n");
656 // Determine rotation
657 auto IsAdd = [](unsigned Op) {
658 return Op == Instruction::FAdd || Op == Instruction::Add;
659 };
660 auto IsSub = [](unsigned Op) {
661 return Op == Instruction::FSub || Op == Instruction::Sub;
662 };
663 ComplexDeinterleavingRotation Rotation;
664 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
665 Rotation = ComplexDeinterleavingRotation::Rotation_0;
666 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
667 Rotation = ComplexDeinterleavingRotation::Rotation_90;
668 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
669 Rotation = ComplexDeinterleavingRotation::Rotation_180;
670 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
671 Rotation = ComplexDeinterleavingRotation::Rotation_270;
672 else {
673 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
674 return nullptr;
675 }
676
677 if (isa<FPMathOperator>(Real) &&
678 (!Real->getFastMathFlags().allowContract() ||
679 !Imag->getFastMathFlags().allowContract())) {
680 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
681 return nullptr;
682 }
683
684 Value *CR = Real->getOperand(0);
685 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
686 if (!RealMulI)
687 return nullptr;
688 Value *CI = Imag->getOperand(0);
689 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
690 if (!ImagMulI)
691 return nullptr;
692
693 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
694 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
695 return nullptr;
696 }
697
698 Value *R0 = RealMulI->getOperand(0);
699 Value *R1 = RealMulI->getOperand(1);
700 Value *I0 = ImagMulI->getOperand(0);
701 Value *I1 = ImagMulI->getOperand(1);
702
703 Value *CommonOperand;
704 Value *UncommonRealOp;
705 Value *UncommonImagOp;
706
707 if (R0 == I0 || R0 == I1) {
708 CommonOperand = R0;
709 UncommonRealOp = R1;
710 } else if (R1 == I0 || R1 == I1) {
711 CommonOperand = R1;
712 UncommonRealOp = R0;
713 } else {
714 LLVM_DEBUG(dbgs() << " - No equal operand\n");
715 return nullptr;
716 }
717
718 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
719 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
720 Rotation == ComplexDeinterleavingRotation::Rotation_270)
721 std::swap(UncommonRealOp, UncommonImagOp);
722
723 std::pair<Value *, Value *> PartialMatch(
724 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
725 Rotation == ComplexDeinterleavingRotation::Rotation_180)
726 ? CommonOperand
727 : nullptr,
728 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
729 Rotation == ComplexDeinterleavingRotation::Rotation_270)
730 ? CommonOperand
731 : nullptr);
732
733 auto *CRInst = dyn_cast<Instruction>(CR);
734 auto *CIInst = dyn_cast<Instruction>(CI);
735
736 if (!CRInst || !CIInst) {
737 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
738 return nullptr;
739 }
740
741 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
742 if (!CNode) {
743 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
744 return nullptr;
745 }
746
747 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
748 if (!UncommonRes) {
749 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
750 return nullptr;
751 }
752
753 assert(PartialMatch.first && PartialMatch.second);
754 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
755 if (!CommonRes) {
756 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
757 return nullptr;
758 }
759
760 NodePtr Node = prepareCompositeNode(
761 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
762 Node->Rotation = Rotation;
763 Node->addOperand(CommonRes);
764 Node->addOperand(UncommonRes);
765 Node->addOperand(CNode);
766 return submitCompositeNode(Node);
767 }
768
769 ComplexDeinterleavingGraph::NodePtr
identifyAdd(Instruction * Real,Instruction * Imag)770 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772
773 // Determine rotation
774 ComplexDeinterleavingRotation Rotation;
775 if ((Real->getOpcode() == Instruction::FSub &&
776 Imag->getOpcode() == Instruction::FAdd) ||
777 (Real->getOpcode() == Instruction::Sub &&
778 Imag->getOpcode() == Instruction::Add))
779 Rotation = ComplexDeinterleavingRotation::Rotation_90;
780 else if ((Real->getOpcode() == Instruction::FAdd &&
781 Imag->getOpcode() == Instruction::FSub) ||
782 (Real->getOpcode() == Instruction::Add &&
783 Imag->getOpcode() == Instruction::Sub))
784 Rotation = ComplexDeinterleavingRotation::Rotation_270;
785 else {
786 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787 return nullptr;
788 }
789
790 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794
795 if (!AR || !AI || !BR || !BI) {
796 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797 return nullptr;
798 }
799
800 NodePtr ResA = identifyNode(AR, AI);
801 if (!ResA) {
802 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803 return nullptr;
804 }
805 NodePtr ResB = identifyNode(BR, BI);
806 if (!ResB) {
807 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808 return nullptr;
809 }
810
811 NodePtr Node =
812 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813 Node->Rotation = Rotation;
814 Node->addOperand(ResA);
815 Node->addOperand(ResB);
816 return submitCompositeNode(Node);
817 }
818
isInstructionPairAdd(Instruction * A,Instruction * B)819 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
820 unsigned OpcA = A->getOpcode();
821 unsigned OpcB = B->getOpcode();
822
823 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827 }
828
isInstructionPairMul(Instruction * A,Instruction * B)829 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
830 auto Pattern =
831 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
832
833 return match(A, Pattern) && match(B, Pattern);
834 }
835
isInstructionPotentiallySymmetric(Instruction * I)836 static bool isInstructionPotentiallySymmetric(Instruction *I) {
837 switch (I->getOpcode()) {
838 case Instruction::FAdd:
839 case Instruction::FSub:
840 case Instruction::FMul:
841 case Instruction::FNeg:
842 case Instruction::Add:
843 case Instruction::Sub:
844 case Instruction::Mul:
845 return true;
846 default:
847 return false;
848 }
849 }
850
851 ComplexDeinterleavingGraph::NodePtr
identifySymmetricOperation(Instruction * Real,Instruction * Imag)852 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
853 Instruction *Imag) {
854 if (Real->getOpcode() != Imag->getOpcode())
855 return nullptr;
856
857 if (!isInstructionPotentiallySymmetric(Real) ||
858 !isInstructionPotentiallySymmetric(Imag))
859 return nullptr;
860
861 auto *R0 = Real->getOperand(0);
862 auto *I0 = Imag->getOperand(0);
863
864 NodePtr Op0 = identifyNode(R0, I0);
865 NodePtr Op1 = nullptr;
866 if (Op0 == nullptr)
867 return nullptr;
868
869 if (Real->isBinaryOp()) {
870 auto *R1 = Real->getOperand(1);
871 auto *I1 = Imag->getOperand(1);
872 Op1 = identifyNode(R1, I1);
873 if (Op1 == nullptr)
874 return nullptr;
875 }
876
877 if (isa<FPMathOperator>(Real) &&
878 Real->getFastMathFlags() != Imag->getFastMathFlags())
879 return nullptr;
880
881 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
882 Real, Imag);
883 Node->Opcode = Real->getOpcode();
884 if (isa<FPMathOperator>(Real))
885 Node->Flags = Real->getFastMathFlags();
886
887 Node->addOperand(Op0);
888 if (Real->isBinaryOp())
889 Node->addOperand(Op1);
890
891 return submitCompositeNode(Node);
892 }
893
894 ComplexDeinterleavingGraph::NodePtr
identifyNode(Value * R,Value * I)895 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
896 LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
897 assert(R->getType() == I->getType() &&
898 "Real and imaginary parts should not have different types");
899
900 auto It = CachedResult.find({R, I});
901 if (It != CachedResult.end()) {
902 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
903 return It->second;
904 }
905
906 if (NodePtr CN = identifySplat(R, I))
907 return CN;
908
909 auto *Real = dyn_cast<Instruction>(R);
910 auto *Imag = dyn_cast<Instruction>(I);
911 if (!Real || !Imag)
912 return nullptr;
913
914 if (NodePtr CN = identifyDeinterleave(Real, Imag))
915 return CN;
916
917 if (NodePtr CN = identifyPHINode(Real, Imag))
918 return CN;
919
920 if (NodePtr CN = identifySelectNode(Real, Imag))
921 return CN;
922
923 auto *VTy = cast<VectorType>(Real->getType());
924 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
925
926 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
927 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
928 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
929 ComplexDeinterleavingOperation::CAdd, NewVTy);
930
931 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
932 if (NodePtr CN = identifyPartialMul(Real, Imag))
933 return CN;
934 }
935
936 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
937 if (NodePtr CN = identifyAdd(Real, Imag))
938 return CN;
939 }
940
941 if (HasCMulSupport && HasCAddSupport) {
942 if (NodePtr CN = identifyReassocNodes(Real, Imag))
943 return CN;
944 }
945
946 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
947 return CN;
948
949 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
950 CachedResult[{R, I}] = nullptr;
951 return nullptr;
952 }
953
// Identifies a complex operation rooted at Real/Imag by flattening each root's
// chain of add/sub/neg/mul instructions into signed products and signed
// addends, then rebuilding the chain from partial complex multiplications and
// complex additions.
//
// Returns the final node (with its Real and Imag fields set to the given
// roots), or nullptr when the roots have unsupported opcodes, floating-point
// roots have differing fast-math flags or lack 'reassoc', the collection
// finds inconsistent flags, or the products/addends cannot be matched up.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                 Instruction *Imag) {
  // Only additive roots are handled here.
  auto IsOperationSupported = [](unsigned Opcode) -> bool {
    return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
           Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
           Opcode == Instruction::Sub;
  };

  if (!IsOperationSupported(Real->getOpcode()) ||
      !IsOperationSupported(Imag->getOpcode()))
    return nullptr;

  // Flags stays empty for integer roots; for floating-point roots it holds the
  // fast-math flags that both roots share.
  std::optional<FastMathFlags> Flags;
  if (isa<FPMathOperator>(Real)) {
    if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
      LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
                           "not identical\n");
      return nullptr;
    }

    Flags = Real->getFastMathFlags();
    if (!Flags->allowReassoc()) {
      LLVM_DEBUG(
          dbgs()
          << "the 'Reassoc' attribute is missing in the FastMath flags\n");
      return nullptr;
    }
  }

  // Collect multiplications and addend instructions from the given instruction
  // while traversing its operands. Additionally, verify that all instructions
  // have the same fast math flags.
  auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
                          std::list<Addend> &Addends) -> bool {
    // Each worklist entry pairs a value with the sign it contributes to the
    // root (true means positive).
    SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
    SmallPtrSet<Value *, 8> Visited;
    while (!Worklist.empty()) {
      auto [V, IsPositive] = Worklist.back();
      Worklist.pop_back();
      if (!Visited.insert(V).second)
        continue;

      // Non-instruction values cannot be traversed further; record them as
      // addends.
      Instruction *I = dyn_cast<Instruction>(V);
      if (!I) {
        Addends.emplace_back(V, IsPositive);
        continue;
      }

      // If an instruction has more than one user, it indicates that it either
      // has an external user, which will be later checked by the checkNodes
      // function, or it is a subexpression utilized by multiple expressions. In
      // the latter case, we will attempt to separately identify the complex
      // operation from here in order to create a shared
      // ComplexDeinterleavingCompositeNode.
      if (I != Insn && I->getNumUses() > 1) {
        LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
        Addends.emplace_back(I, IsPositive);
        continue;
      }
      switch (I->getOpcode()) {
      case Instruction::FAdd:
      case Instruction::Add:
        // Both operands keep the current sign.
        Worklist.emplace_back(I->getOperand(1), IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::FSub:
        // The subtrahend contributes with the opposite sign.
        Worklist.emplace_back(I->getOperand(1), !IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::Sub:
        // When isNeg(I) holds, only the operand returned by getNegOperand is
        // followed, with the sign flipped; otherwise this is handled like FSub.
        if (isNeg(I)) {
          Worklist.emplace_back(getNegOperand(I), !IsPositive);
        } else {
          Worklist.emplace_back(I->getOperand(1), !IsPositive);
          Worklist.emplace_back(I->getOperand(0), IsPositive);
        }
        break;
      case Instruction::FMul:
      case Instruction::Mul: {
        // Strip negations from the factors and fold them into the sign of the
        // recorded product.
        Value *A, *B;
        if (isNeg(I->getOperand(0))) {
          A = getNegOperand(I->getOperand(0));
          IsPositive = !IsPositive;
        } else {
          A = I->getOperand(0);
        }

        if (isNeg(I->getOperand(1))) {
          B = getNegOperand(I->getOperand(1));
          IsPositive = !IsPositive;
        } else {
          B = I->getOperand(1);
        }
        Muls.push_back(Product{A, B, IsPositive});
        break;
      }
      case Instruction::FNeg:
        Worklist.emplace_back(I->getOperand(0), !IsPositive);
        break;
      default:
        // Any other instruction is a leaf of the chain. The `continue` skips
        // the fast-math flag check below for it.
        Addends.emplace_back(I, IsPositive);
        continue;
      }

      // Only reached for the add/sub/neg/mul instructions handled above; with
      // floating-point roots each of them must carry the roots' flags.
      if (Flags && I->getFastMathFlags() != *Flags) {
        LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
                             "inconsistent with the root instructions' flags: "
                          << *I << "\n");
        return false;
      }
    }
    return true;
  };

  std::vector<Product> RealMuls, ImagMuls;
  std::list<Addend> RealAddends, ImagAddends;
  if (!Collect(Real, RealMuls, RealAddends) ||
      !Collect(Imag, ImagMuls, ImagAddends))
    return nullptr;

  // Addends are consumed in real/imaginary pairs, so the counts must match.
  if (RealAddends.size() != ImagAddends.size())
    return nullptr;

  NodePtr FinalNode;
  if (!RealMuls.empty() || !ImagMuls.empty()) {
    // If there are multiplicands, extract positive addend and use it as an
    // accumulator
    FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
    if (!FinalNode)
      return nullptr;
  }

  // Identify and process remaining additions
  if (!RealAddends.empty() || !ImagAddends.empty()) {
    FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
    if (!FinalNode)
      return nullptr;
  }
  assert(FinalNode && "FinalNode can not be nullptr here");
  // Set the Real and Imag fields of the final node and submit it
  FinalNode->Real = Real;
  FinalNode->Imag = Imag;
  submitCompositeNode(FinalNode);
  return FinalNode;
}
1101
collectPartialMuls(const std::vector<Product> & RealMuls,const std::vector<Product> & ImagMuls,std::vector<PartialMulCandidate> & PartialMulCandidates)1102 bool ComplexDeinterleavingGraph::collectPartialMuls(
1103 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1104 std::vector<PartialMulCandidate> &PartialMulCandidates) {
1105 // Helper function to extract a common operand from two products
1106 auto FindCommonInstruction = [](const Product &Real,
1107 const Product &Imag) -> Value * {
1108 if (Real.Multiplicand == Imag.Multiplicand ||
1109 Real.Multiplicand == Imag.Multiplier)
1110 return Real.Multiplicand;
1111
1112 if (Real.Multiplier == Imag.Multiplicand ||
1113 Real.Multiplier == Imag.Multiplier)
1114 return Real.Multiplier;
1115
1116 return nullptr;
1117 };
1118
1119 // Iterating over real and imaginary multiplications to find common operands
1120 // If a common operand is found, a partial multiplication candidate is created
1121 // and added to the candidates vector The function returns false if no common
1122 // operands are found for any product
1123 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1124 bool FoundCommon = false;
1125 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1126 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1127 if (!Common)
1128 continue;
1129
1130 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1131 : RealMuls[i].Multiplicand;
1132 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1133 : ImagMuls[j].Multiplicand;
1134
1135 auto Node = identifyNode(A, B);
1136 if (Node) {
1137 FoundCommon = true;
1138 PartialMulCandidates.push_back({Common, Node, i, j, false});
1139 }
1140
1141 Node = identifyNode(B, A);
1142 if (Node) {
1143 FoundCommon = true;
1144 PartialMulCandidates.push_back({Common, Node, i, j, true});
1145 }
1146 }
1147 if (!FoundCommon)
1148 return false;
1149 }
1150 return true;
1151 }
1152
1153 ComplexDeinterleavingGraph::NodePtr
identifyMultiplications(std::vector<Product> & RealMuls,std::vector<Product> & ImagMuls,NodePtr Accumulator=nullptr)1154 ComplexDeinterleavingGraph::identifyMultiplications(
1155 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1156 NodePtr Accumulator = nullptr) {
1157 if (RealMuls.size() != ImagMuls.size())
1158 return nullptr;
1159
1160 std::vector<PartialMulCandidate> Info;
1161 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1162 return nullptr;
1163
1164 // Map to store common instruction to node pointers
1165 std::map<Value *, NodePtr> CommonToNode;
1166 std::vector<bool> Processed(Info.size(), false);
1167 for (unsigned I = 0; I < Info.size(); ++I) {
1168 if (Processed[I])
1169 continue;
1170
1171 PartialMulCandidate &InfoA = Info[I];
1172 for (unsigned J = I + 1; J < Info.size(); ++J) {
1173 if (Processed[J])
1174 continue;
1175
1176 PartialMulCandidate &InfoB = Info[J];
1177 auto *InfoReal = &InfoA;
1178 auto *InfoImag = &InfoB;
1179
1180 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1181 if (!NodeFromCommon) {
1182 std::swap(InfoReal, InfoImag);
1183 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1184 }
1185 if (!NodeFromCommon)
1186 continue;
1187
1188 CommonToNode[InfoReal->Common] = NodeFromCommon;
1189 CommonToNode[InfoImag->Common] = NodeFromCommon;
1190 Processed[I] = true;
1191 Processed[J] = true;
1192 }
1193 }
1194
1195 std::vector<bool> ProcessedReal(RealMuls.size(), false);
1196 std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1197 NodePtr Result = Accumulator;
1198 for (auto &PMI : Info) {
1199 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1200 continue;
1201
1202 auto It = CommonToNode.find(PMI.Common);
1203 // TODO: Process independent complex multiplications. Cases like this:
1204 // A.real() * B where both A and B are complex numbers.
1205 if (It == CommonToNode.end()) {
1206 LLVM_DEBUG({
1207 dbgs() << "Unprocessed independent partial multiplication:\n";
1208 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1209 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1210 << " multiplied by " << *Mul->Multiplicand << "\n";
1211 });
1212 return nullptr;
1213 }
1214
1215 auto &RealMul = RealMuls[PMI.RealIdx];
1216 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1217
1218 auto NodeA = It->second;
1219 auto NodeB = PMI.Node;
1220 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1221 // The following table illustrates the relationship between multiplications
1222 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1223 // can see:
1224 //
1225 // Rotation | Real | Imag |
1226 // ---------+--------+--------+
1227 // 0 | x * u | x * v |
1228 // 90 | -y * v | y * u |
1229 // 180 | -x * u | -x * v |
1230 // 270 | y * v | -y * u |
1231 //
1232 // Check if the candidate can indeed be represented by partial
1233 // multiplication
1234 // TODO: Add support for multiplication by complex one
1235 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1236 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1237 continue;
1238
1239 // Determine the rotation based on the multiplications
1240 ComplexDeinterleavingRotation Rotation;
1241 if (IsMultiplicandReal) {
1242 // Detect 0 and 180 degrees rotation
1243 if (RealMul.IsPositive && ImagMul.IsPositive)
1244 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1245 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1246 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1247 else
1248 continue;
1249
1250 } else {
1251 // Detect 90 and 270 degrees rotation
1252 if (!RealMul.IsPositive && ImagMul.IsPositive)
1253 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1254 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1255 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1256 else
1257 continue;
1258 }
1259
1260 LLVM_DEBUG({
1261 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1262 dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1263 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1264 dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1265 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1266 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1267 });
1268
1269 NodePtr NodeMul = prepareCompositeNode(
1270 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1271 NodeMul->Rotation = Rotation;
1272 NodeMul->addOperand(NodeA);
1273 NodeMul->addOperand(NodeB);
1274 if (Result)
1275 NodeMul->addOperand(Result);
1276 submitCompositeNode(NodeMul);
1277 Result = NodeMul;
1278 ProcessedReal[PMI.RealIdx] = true;
1279 ProcessedImag[PMI.ImagIdx] = true;
1280 }
1281
1282 // Ensure all products have been processed, if not return nullptr.
1283 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1284 !all_of(ProcessedImag, [](bool V) { return V; })) {
1285
1286 // Dump debug information about which partial multiplications are not
1287 // processed.
1288 LLVM_DEBUG({
1289 dbgs() << "Unprocessed products (Real):\n";
1290 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1291 if (!ProcessedReal[i])
1292 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1293 << *RealMuls[i].Multiplier << " multiplied by "
1294 << *RealMuls[i].Multiplicand << "\n";
1295 }
1296 dbgs() << "Unprocessed products (Imag):\n";
1297 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1298 if (!ProcessedImag[i])
1299 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1300 << *ImagMuls[i].Multiplier << " multiplied by "
1301 << *ImagMuls[i].Multiplicand << "\n";
1302 }
1303 });
1304 return nullptr;
1305 }
1306
1307 return Result;
1308 }
1309
1310 ComplexDeinterleavingGraph::NodePtr
identifyAdditions(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends,std::optional<FastMathFlags> Flags,NodePtr Accumulator=nullptr)1311 ComplexDeinterleavingGraph::identifyAdditions(
1312 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1313 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1314 if (RealAddends.size() != ImagAddends.size())
1315 return nullptr;
1316
1317 NodePtr Result;
1318 // If we have accumulator use it as first addend
1319 if (Accumulator)
1320 Result = Accumulator;
1321 // Otherwise find an element with both positive real and imaginary parts.
1322 else
1323 Result = extractPositiveAddend(RealAddends, ImagAddends);
1324
1325 if (!Result)
1326 return nullptr;
1327
1328 while (!RealAddends.empty()) {
1329 auto ItR = RealAddends.begin();
1330 auto [R, IsPositiveR] = *ItR;
1331
1332 bool FoundImag = false;
1333 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1334 auto [I, IsPositiveI] = *ItI;
1335 ComplexDeinterleavingRotation Rotation;
1336 if (IsPositiveR && IsPositiveI)
1337 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1338 else if (!IsPositiveR && IsPositiveI)
1339 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1340 else if (!IsPositiveR && !IsPositiveI)
1341 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1342 else
1343 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1344
1345 NodePtr AddNode;
1346 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1347 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1348 AddNode = identifyNode(R, I);
1349 } else {
1350 AddNode = identifyNode(I, R);
1351 }
1352 if (AddNode) {
1353 LLVM_DEBUG({
1354 dbgs() << "Identified addition:\n";
1355 dbgs().indent(4) << "X: " << *R << "\n";
1356 dbgs().indent(4) << "Y: " << *I << "\n";
1357 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1358 });
1359
1360 NodePtr TmpNode;
1361 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1362 TmpNode = prepareCompositeNode(
1363 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1364 if (Flags) {
1365 TmpNode->Opcode = Instruction::FAdd;
1366 TmpNode->Flags = *Flags;
1367 } else {
1368 TmpNode->Opcode = Instruction::Add;
1369 }
1370 } else if (Rotation ==
1371 llvm::ComplexDeinterleavingRotation::Rotation_180) {
1372 TmpNode = prepareCompositeNode(
1373 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1374 if (Flags) {
1375 TmpNode->Opcode = Instruction::FSub;
1376 TmpNode->Flags = *Flags;
1377 } else {
1378 TmpNode->Opcode = Instruction::Sub;
1379 }
1380 } else {
1381 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1382 nullptr, nullptr);
1383 TmpNode->Rotation = Rotation;
1384 }
1385
1386 TmpNode->addOperand(Result);
1387 TmpNode->addOperand(AddNode);
1388 submitCompositeNode(TmpNode);
1389 Result = TmpNode;
1390 RealAddends.erase(ItR);
1391 ImagAddends.erase(ItI);
1392 FoundImag = true;
1393 break;
1394 }
1395 }
1396 if (!FoundImag)
1397 return nullptr;
1398 }
1399 return Result;
1400 }
1401
1402 ComplexDeinterleavingGraph::NodePtr
extractPositiveAddend(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends)1403 ComplexDeinterleavingGraph::extractPositiveAddend(
1404 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1405 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1406 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1407 auto [R, IsPositiveR] = *ItR;
1408 auto [I, IsPositiveI] = *ItI;
1409 if (IsPositiveR && IsPositiveI) {
1410 auto Result = identifyNode(R, I);
1411 if (Result) {
1412 RealAddends.erase(ItR);
1413 ImagAddends.erase(ItI);
1414 return Result;
1415 }
1416 }
1417 }
1418 }
1419 return nullptr;
1420 }
1421
identifyNodes(Instruction * RootI)1422 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1423 // This potential root instruction might already have been recognized as
1424 // reduction. Because RootToNode maps both Real and Imaginary parts to
1425 // CompositeNode we should choose only one either Real or Imag instruction to
1426 // use as an anchor for generating complex instruction.
1427 auto It = RootToNode.find(RootI);
1428 if (It != RootToNode.end()) {
1429 auto RootNode = It->second;
1430 assert(RootNode->Operation ==
1431 ComplexDeinterleavingOperation::ReductionOperation);
1432 // Find out which part, Real or Imag, comes later, and only if we come to
1433 // the latest part, add it to OrderedRoots.
1434 auto *R = cast<Instruction>(RootNode->Real);
1435 auto *I = cast<Instruction>(RootNode->Imag);
1436 auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1437 if (ReplacementAnchor != RootI)
1438 return false;
1439 OrderedRoots.push_back(RootI);
1440 return true;
1441 }
1442
1443 auto RootNode = identifyRoot(RootI);
1444 if (!RootNode)
1445 return false;
1446
1447 LLVM_DEBUG({
1448 Function *F = RootI->getFunction();
1449 BasicBlock *B = RootI->getParent();
1450 dbgs() << "Complex deinterleaving graph for " << F->getName()
1451 << "::" << B->getName() << ".\n";
1452 dump(dbgs());
1453 dbgs() << "\n";
1454 });
1455 RootToNode[RootI] = RootNode;
1456 OrderedRoots.push_back(RootI);
1457 return true;
1458 }
1459
collectPotentialReductions(BasicBlock * B)1460 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1461 bool FoundPotentialReduction = false;
1462
1463 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1464 if (!Br || Br->getNumSuccessors() != 2)
1465 return false;
1466
1467 // Identify simple one-block loop
1468 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1469 return false;
1470
1471 SmallVector<PHINode *> PHIs;
1472 for (auto &PHI : B->phis()) {
1473 if (PHI.getNumIncomingValues() != 2)
1474 continue;
1475
1476 if (!PHI.getType()->isVectorTy())
1477 continue;
1478
1479 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1480 if (!ReductionOp)
1481 continue;
1482
1483 // Check if final instruction is reduced outside of current block
1484 Instruction *FinalReduction = nullptr;
1485 auto NumUsers = 0u;
1486 for (auto *U : ReductionOp->users()) {
1487 ++NumUsers;
1488 if (U == &PHI)
1489 continue;
1490 FinalReduction = dyn_cast<Instruction>(U);
1491 }
1492
1493 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1494 isa<PHINode>(FinalReduction))
1495 continue;
1496
1497 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1498 BackEdge = B;
1499 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1500 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1501 Incoming = PHI.getIncomingBlock(IncomingIdx);
1502 FoundPotentialReduction = true;
1503
1504 // If the initial value of PHINode is an Instruction, consider it a leaf
1505 // value of a complex deinterleaving graph.
1506 if (auto *InitPHI =
1507 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1508 FinalInstructions.insert(InitPHI);
1509 }
1510 return FoundPotentialReduction;
1511 }
1512
identifyReductionNodes()1513 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1514 SmallVector<bool> Processed(ReductionInfo.size(), false);
1515 SmallVector<Instruction *> OperationInstruction;
1516 for (auto &P : ReductionInfo)
1517 OperationInstruction.push_back(P.first);
1518
1519 // Identify a complex computation by evaluating two reduction operations that
1520 // potentially could be involved
1521 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1522 if (Processed[i])
1523 continue;
1524 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1525 if (Processed[j])
1526 continue;
1527
1528 auto *Real = OperationInstruction[i];
1529 auto *Imag = OperationInstruction[j];
1530 if (Real->getType() != Imag->getType())
1531 continue;
1532
1533 RealPHI = ReductionInfo[Real].first;
1534 ImagPHI = ReductionInfo[Imag].first;
1535 PHIsFound = false;
1536 auto Node = identifyNode(Real, Imag);
1537 if (!Node) {
1538 std::swap(Real, Imag);
1539 std::swap(RealPHI, ImagPHI);
1540 Node = identifyNode(Real, Imag);
1541 }
1542
1543 // If a node is identified and reduction PHINode is used in the chain of
1544 // operations, mark its operation instructions as used to prevent
1545 // re-identification and attach the node to the real part
1546 if (Node && PHIsFound) {
1547 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1548 << *Real << " / " << *Imag << "\n");
1549 Processed[i] = true;
1550 Processed[j] = true;
1551 auto RootNode = prepareCompositeNode(
1552 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1553 RootNode->addOperand(Node);
1554 RootToNode[Real] = RootNode;
1555 RootToNode[Imag] = RootNode;
1556 submitCompositeNode(RootNode);
1557 break;
1558 }
1559 }
1560 }
1561
1562 RealPHI = nullptr;
1563 ImagPHI = nullptr;
1564 }
1565
checkNodes()1566 bool ComplexDeinterleavingGraph::checkNodes() {
1567 // Collect all instructions from roots to leaves
1568 SmallPtrSet<Instruction *, 16> AllInstructions;
1569 SmallVector<Instruction *, 8> Worklist;
1570 for (auto &Pair : RootToNode)
1571 Worklist.push_back(Pair.first);
1572
1573 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1574 // chains
1575 while (!Worklist.empty()) {
1576 auto *I = Worklist.back();
1577 Worklist.pop_back();
1578
1579 if (!AllInstructions.insert(I).second)
1580 continue;
1581
1582 for (Value *Op : I->operands()) {
1583 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1584 if (!FinalInstructions.count(I))
1585 Worklist.emplace_back(OpI);
1586 }
1587 }
1588 }
1589
1590 // Find instructions that have users outside of chain
1591 SmallVector<Instruction *, 2> OuterInstructions;
1592 for (auto *I : AllInstructions) {
1593 // Skip root nodes
1594 if (RootToNode.count(I))
1595 continue;
1596
1597 for (User *U : I->users()) {
1598 if (AllInstructions.count(cast<Instruction>(U)))
1599 continue;
1600
1601 // Found an instruction that is not used by XCMLA/XCADD chain
1602 Worklist.emplace_back(I);
1603 break;
1604 }
1605 }
1606
1607 // If any instructions are found to be used outside, find and remove roots
1608 // that somehow connect to those instructions.
1609 SmallPtrSet<Instruction *, 16> Visited;
1610 while (!Worklist.empty()) {
1611 auto *I = Worklist.back();
1612 Worklist.pop_back();
1613 if (!Visited.insert(I).second)
1614 continue;
1615
1616 // Found an impacted root node. Removing it from the nodes to be
1617 // deinterleaved
1618 if (RootToNode.count(I)) {
1619 LLVM_DEBUG(dbgs() << "Instruction " << *I
1620 << " could be deinterleaved but its chain of complex "
1621 "operations have an outside user\n");
1622 RootToNode.erase(I);
1623 }
1624
1625 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1626 continue;
1627
1628 for (User *U : I->users())
1629 Worklist.emplace_back(cast<Instruction>(U));
1630
1631 for (Value *Op : I->operands()) {
1632 if (auto *OpI = dyn_cast<Instruction>(Op))
1633 Worklist.emplace_back(OpI);
1634 }
1635 }
1636 return !RootToNode.empty();
1637 }
1638
1639 ComplexDeinterleavingGraph::NodePtr
identifyRoot(Instruction * RootI)1640 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1641 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642 if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
1643 return nullptr;
1644
1645 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1646 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1647 if (!Real || !Imag)
1648 return nullptr;
1649
1650 return identifyNode(Real, Imag);
1651 }
1652
1653 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1654 if (!SVI)
1655 return nullptr;
1656
1657 // Look for a shufflevector that takes separate vectors of the real and
1658 // imaginary components and recombines them into a single vector.
1659 if (!isInterleavingMask(SVI->getShuffleMask()))
1660 return nullptr;
1661
1662 Instruction *Real;
1663 Instruction *Imag;
1664 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1665 return nullptr;
1666
1667 return identifyNode(Real, Imag);
1668 }
1669
1670 ComplexDeinterleavingGraph::NodePtr
identifyDeinterleave(Instruction * Real,Instruction * Imag)1671 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1672 Instruction *Imag) {
1673 Instruction *I = nullptr;
1674 Value *FinalValue = nullptr;
1675 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1676 match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1677 match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
1678 m_Value(FinalValue)))) {
1679 NodePtr PlaceholderNode = prepareCompositeNode(
1680 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1681 PlaceholderNode->ReplacementNode = FinalValue;
1682 FinalInstructions.insert(Real);
1683 FinalInstructions.insert(Imag);
1684 return submitCompositeNode(PlaceholderNode);
1685 }
1686
1687 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1688 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1689 if (!RealShuffle || !ImagShuffle) {
1690 if (RealShuffle || ImagShuffle)
1691 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1692 return nullptr;
1693 }
1694
1695 Value *RealOp1 = RealShuffle->getOperand(1);
1696 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1697 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1698 return nullptr;
1699 }
1700 Value *ImagOp1 = ImagShuffle->getOperand(1);
1701 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1702 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1703 return nullptr;
1704 }
1705
1706 Value *RealOp0 = RealShuffle->getOperand(0);
1707 Value *ImagOp0 = ImagShuffle->getOperand(0);
1708
1709 if (RealOp0 != ImagOp0) {
1710 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1711 return nullptr;
1712 }
1713
1714 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1715 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1716 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1717 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1718 return nullptr;
1719 }
1720
1721 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1722 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1723 return nullptr;
1724 }
1725
1726 // Type checking, the shuffle type should be a vector type of the same
1727 // scalar type, but half the size
1728 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1729 Value *Op = Shuffle->getOperand(0);
1730 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1731 auto *OpTy = cast<FixedVectorType>(Op->getType());
1732
1733 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1734 return false;
1735 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1736 return false;
1737
1738 return true;
1739 };
1740
1741 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1742 if (!CheckType(Shuffle))
1743 return false;
1744
1745 ArrayRef<int> Mask = Shuffle->getShuffleMask();
1746 int Last = *Mask.rbegin();
1747
1748 Value *Op = Shuffle->getOperand(0);
1749 auto *OpTy = cast<FixedVectorType>(Op->getType());
1750 int NumElements = OpTy->getNumElements();
1751
1752 // Ensure that the deinterleaving shuffle only pulls from the first
1753 // shuffle operand.
1754 return Last < NumElements;
1755 };
1756
1757 if (RealShuffle->getType() != ImagShuffle->getType()) {
1758 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1759 return nullptr;
1760 }
1761 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1762 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1763 return nullptr;
1764 }
1765 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1766 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1767 return nullptr;
1768 }
1769
1770 NodePtr PlaceholderNode =
1771 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1772 RealShuffle, ImagShuffle);
1773 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1774 FinalInstructions.insert(RealShuffle);
1775 FinalInstructions.insert(ImagShuffle);
1776 return submitCompositeNode(PlaceholderNode);
1777 }
1778
1779 ComplexDeinterleavingGraph::NodePtr
identifySplat(Value * R,Value * I)1780 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1781 auto IsSplat = [](Value *V) -> bool {
1782 // Fixed-width vector with constants
1783 if (isa<ConstantDataVector>(V))
1784 return true;
1785
1786 VectorType *VTy;
1787 ArrayRef<int> Mask;
1788 // Splats are represented differently depending on whether the repeated
1789 // value is a constant or an Instruction
1790 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1791 if (Const->getOpcode() != Instruction::ShuffleVector)
1792 return false;
1793 VTy = cast<VectorType>(Const->getType());
1794 Mask = Const->getShuffleMask();
1795 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1796 VTy = Shuf->getType();
1797 Mask = Shuf->getShuffleMask();
1798 } else {
1799 return false;
1800 }
1801
1802 // When the data type is <1 x Type>, it's not possible to differentiate
1803 // between the ComplexDeinterleaving::Deinterleave and
1804 // ComplexDeinterleaving::Splat operations.
1805 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1806 return false;
1807
1808 return all_equal(Mask) && Mask[0] == 0;
1809 };
1810
1811 if (!IsSplat(R) || !IsSplat(I))
1812 return nullptr;
1813
1814 auto *Real = dyn_cast<Instruction>(R);
1815 auto *Imag = dyn_cast<Instruction>(I);
1816 if ((!Real && Imag) || (Real && !Imag))
1817 return nullptr;
1818
1819 if (Real && Imag) {
1820 // Non-constant splats should be in the same basic block
1821 if (Real->getParent() != Imag->getParent())
1822 return nullptr;
1823
1824 FinalInstructions.insert(Real);
1825 FinalInstructions.insert(Imag);
1826 }
1827 NodePtr PlaceholderNode =
1828 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1829 return submitCompositeNode(PlaceholderNode);
1830 }
1831
1832 ComplexDeinterleavingGraph::NodePtr
identifyPHINode(Instruction * Real,Instruction * Imag)1833 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1834 Instruction *Imag) {
1835 if (Real != RealPHI || Imag != ImagPHI)
1836 return nullptr;
1837
1838 PHIsFound = true;
1839 NodePtr PlaceholderNode = prepareCompositeNode(
1840 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1841 return submitCompositeNode(PlaceholderNode);
1842 }
1843
1844 ComplexDeinterleavingGraph::NodePtr
identifySelectNode(Instruction * Real,Instruction * Imag)1845 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1846 Instruction *Imag) {
1847 auto *SelectReal = dyn_cast<SelectInst>(Real);
1848 auto *SelectImag = dyn_cast<SelectInst>(Imag);
1849 if (!SelectReal || !SelectImag)
1850 return nullptr;
1851
1852 Instruction *MaskA, *MaskB;
1853 Instruction *AR, *AI, *RA, *BI;
1854 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1855 m_Instruction(RA))) ||
1856 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1857 m_Instruction(BI))))
1858 return nullptr;
1859
1860 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1861 return nullptr;
1862
1863 if (!MaskA->getType()->isVectorTy())
1864 return nullptr;
1865
1866 auto NodeA = identifyNode(AR, AI);
1867 if (!NodeA)
1868 return nullptr;
1869
1870 auto NodeB = identifyNode(RA, BI);
1871 if (!NodeB)
1872 return nullptr;
1873
1874 NodePtr PlaceholderNode = prepareCompositeNode(
1875 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1876 PlaceholderNode->addOperand(NodeA);
1877 PlaceholderNode->addOperand(NodeB);
1878 FinalInstructions.insert(MaskA);
1879 FinalInstructions.insert(MaskB);
1880 return submitCompositeNode(PlaceholderNode);
1881 }
1882
replaceSymmetricNode(IRBuilderBase & B,unsigned Opcode,std::optional<FastMathFlags> Flags,Value * InputA,Value * InputB)1883 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1884 std::optional<FastMathFlags> Flags,
1885 Value *InputA, Value *InputB) {
1886 Value *I;
1887 switch (Opcode) {
1888 case Instruction::FNeg:
1889 I = B.CreateFNeg(InputA);
1890 break;
1891 case Instruction::FAdd:
1892 I = B.CreateFAdd(InputA, InputB);
1893 break;
1894 case Instruction::Add:
1895 I = B.CreateAdd(InputA, InputB);
1896 break;
1897 case Instruction::FSub:
1898 I = B.CreateFSub(InputA, InputB);
1899 break;
1900 case Instruction::Sub:
1901 I = B.CreateSub(InputA, InputB);
1902 break;
1903 case Instruction::FMul:
1904 I = B.CreateFMul(InputA, InputB);
1905 break;
1906 case Instruction::Mul:
1907 I = B.CreateMul(InputA, InputB);
1908 break;
1909 default:
1910 llvm_unreachable("Incorrect symmetric opcode");
1911 }
1912 if (Flags)
1913 cast<Instruction>(I)->setFastMathFlags(*Flags);
1914 return I;
1915 }
1916
replaceNode(IRBuilderBase & Builder,RawNodePtr Node)1917 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1918 RawNodePtr Node) {
1919 if (Node->ReplacementNode)
1920 return Node->ReplacementNode;
1921
1922 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1923 return Node->Operands.size() > Idx
1924 ? replaceNode(Builder, Node->Operands[Idx])
1925 : nullptr;
1926 };
1927
1928 Value *ReplacementNode;
1929 switch (Node->Operation) {
1930 case ComplexDeinterleavingOperation::CAdd:
1931 case ComplexDeinterleavingOperation::CMulPartial:
1932 case ComplexDeinterleavingOperation::Symmetric: {
1933 Value *Input0 = ReplaceOperandIfExist(Node, 0);
1934 Value *Input1 = ReplaceOperandIfExist(Node, 1);
1935 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1936 assert(!Input1 || (Input0->getType() == Input1->getType() &&
1937 "Node inputs need to be of the same type"));
1938 assert(!Accumulator ||
1939 (Input0->getType() == Accumulator->getType() &&
1940 "Accumulator and input need to be of the same type"));
1941 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1942 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1943 Input0, Input1);
1944 else
1945 ReplacementNode = TL->createComplexDeinterleavingIR(
1946 Builder, Node->Operation, Node->Rotation, Input0, Input1,
1947 Accumulator);
1948 break;
1949 }
1950 case ComplexDeinterleavingOperation::Deinterleave:
1951 llvm_unreachable("Deinterleave node should already have ReplacementNode");
1952 break;
1953 case ComplexDeinterleavingOperation::Splat: {
1954 auto *NewTy = VectorType::getDoubleElementsVectorType(
1955 cast<VectorType>(Node->Real->getType()));
1956 auto *R = dyn_cast<Instruction>(Node->Real);
1957 auto *I = dyn_cast<Instruction>(Node->Imag);
1958 if (R && I) {
1959 // Splats that are not constant are interleaved where they are located
1960 Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1961 IRBuilder<> IRB(InsertPoint);
1962 ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
1963 NewTy, {Node->Real, Node->Imag});
1964 } else {
1965 ReplacementNode = Builder.CreateIntrinsic(
1966 Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
1967 }
1968 break;
1969 }
1970 case ComplexDeinterleavingOperation::ReductionPHI: {
1971 // If Operation is ReductionPHI, a new empty PHINode is created.
1972 // It is filled later when the ReductionOperation is processed.
1973 auto *VTy = cast<VectorType>(Node->Real->getType());
1974 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1975 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
1976 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1977 ReplacementNode = NewPHI;
1978 break;
1979 }
1980 case ComplexDeinterleavingOperation::ReductionOperation:
1981 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1982 processReductionOperation(ReplacementNode, Node);
1983 break;
1984 case ComplexDeinterleavingOperation::ReductionSelect: {
1985 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1986 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1987 auto *A = replaceNode(Builder, Node->Operands[0]);
1988 auto *B = replaceNode(Builder, Node->Operands[1]);
1989 auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1990 cast<VectorType>(MaskReal->getType()));
1991 auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
1992 NewMaskTy, {MaskReal, MaskImag});
1993 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1994 break;
1995 }
1996 }
1997
1998 assert(ReplacementNode && "Target failed to create Intrinsic call.");
1999 NumComplexTransformations += 1;
2000 Node->ReplacementNode = ReplacementNode;
2001 return ReplacementNode;
2002 }
2003
processReductionOperation(Value * OperationReplacement,RawNodePtr Node)2004 void ComplexDeinterleavingGraph::processReductionOperation(
2005 Value *OperationReplacement, RawNodePtr Node) {
2006 auto *Real = cast<Instruction>(Node->Real);
2007 auto *Imag = cast<Instruction>(Node->Imag);
2008 auto *OldPHIReal = ReductionInfo[Real].first;
2009 auto *OldPHIImag = ReductionInfo[Imag].first;
2010 auto *NewPHI = OldToNewPHI[OldPHIReal];
2011
2012 auto *VTy = cast<VectorType>(Real->getType());
2013 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2014
2015 // We have to interleave initial origin values coming from IncomingBlock
2016 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2017 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2018
2019 IRBuilder<> Builder(Incoming->getTerminator());
2020 auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2021 {InitReal, InitImag});
2022
2023 NewPHI->addIncoming(NewInit, Incoming);
2024 NewPHI->addIncoming(OperationReplacement, BackEdge);
2025
2026 // Deinterleave complex vector outside of loop so that it can be finally
2027 // reduced
2028 auto *FinalReductionReal = ReductionInfo[Real].second;
2029 auto *FinalReductionImag = ReductionInfo[Imag].second;
2030
2031 Builder.SetInsertPoint(
2032 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2033 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2034 OperationReplacement->getType(),
2035 OperationReplacement);
2036
2037 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2038 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2039
2040 Builder.SetInsertPoint(FinalReductionImag);
2041 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2042 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2043 }
2044
replaceNodes()2045 void ComplexDeinterleavingGraph::replaceNodes() {
2046 SmallVector<Instruction *, 16> DeadInstrRoots;
2047 for (auto *RootInstruction : OrderedRoots) {
2048 // Check if this potential root went through check process and we can
2049 // deinterleave it
2050 if (!RootToNode.count(RootInstruction))
2051 continue;
2052
2053 IRBuilder<> Builder(RootInstruction);
2054 auto RootNode = RootToNode[RootInstruction];
2055 Value *R = replaceNode(Builder, RootNode.get());
2056
2057 if (RootNode->Operation ==
2058 ComplexDeinterleavingOperation::ReductionOperation) {
2059 auto *RootReal = cast<Instruction>(RootNode->Real);
2060 auto *RootImag = cast<Instruction>(RootNode->Imag);
2061 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2062 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2063 DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2064 DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2065 } else {
2066 assert(R && "Unable to find replacement for RootInstruction");
2067 DeadInstrRoots.push_back(RootInstruction);
2068 RootInstruction->replaceAllUsesWith(R);
2069 }
2070 }
2071
2072 for (auto *I : DeadInstrRoots)
2073 RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2074 }
2075