LLVM 19.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
63#include "llvm/ADT/MapVector.h"
64#include "llvm/ADT/Statistic.h"
70#include "llvm/IR/IRBuilder.h"
75#include <algorithm>
76
77using namespace llvm;
78using namespace PatternMatch;
79
80#define DEBUG_TYPE "complex-deinterleaving"
81
82STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83
85 "enable-complex-deinterleaving",
86 cl::desc("Enable generation of complex instructions"), cl::init(true),
88
89/// Checks the given mask, and determines whether said mask is interleaving.
90///
91/// To be interleaving, a mask must alternate between `i` and `i + (Length /
92/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93/// 4x vector interleaving mask would be <0, 2, 1, 3>).
94static bool isInterleavingMask(ArrayRef<int> Mask);
95
96/// Checks the given mask, and determines whether said mask is deinterleaving.
97///
98/// To be deinterleaving, a mask must increment in steps of 2, and either start
99/// with 0 or 1.
100/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101/// <1, 3, 5, 7>).
102static bool isDeinterleavingMask(ArrayRef<int> Mask);
103
104/// Returns true if the operation is a negation of V, and it works for both
105/// integers and floats.
106static bool isNeg(Value *V);
107
108/// Returns the operand for negation operation.
109static Value *getNegOperand(Value *V);
110
111namespace {
112
// Legacy pass-manager wrapper around the complex deinterleaving transform: a
// FunctionPass that remembers the TargetMachine it was constructed with and
// declares runOnFunction / getAnalysisUsage overrides.
113class ComplexDeinterleavingLegacyPass : public FunctionPass {
114public:
// Pass identification member; handed to the FunctionPass base constructor.
115 static char ID;
116
// TM may be left as nullptr by the default argument.
// NOTE(review): listing lines 119-120 are absent from this copy, so the
// constructor body shown here is incomplete -- confirm against upstream.
117 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118 : FunctionPass(ID), TM(TM) {
121 }
122
// Human-readable pass name.
123 StringRef getPassName() const override {
124 return "Complex Deinterleaving Pass";
125 }
126
127 bool runOnFunction(Function &F) override;
// Declares that this pass preserves the CFG.
// NOTE(review): listing line 129 is absent from this copy. runOnFunction
// queries TargetLibraryInfoWrapperPass, so a matching requirement is
// presumably declared on the missing line -- confirm against upstream.
128 void getAnalysisUsage(AnalysisUsage &AU) const override {
130 AU.setPreservesCFG();
131 }
132
133private:
// Target description supplied at construction; may be nullptr.
134 const TargetMachine *TM;
135};
136
137class ComplexDeinterleavingGraph;
// One node of the deinterleaving graph: a (Real, Imag) value pair together
// with the operation it represents, edges to operand nodes, and the value
// emitted to replace it.
// NOTE(review): listing lines 150, 159 and 161 are absent from this copy. The
// declarations of the `Operation`, `Rotation` and `Operands` members used
// below are therefore not visible here -- confirm against upstream.
138struct ComplexDeinterleavingCompositeNode {
139
// Stores the operation kind and the real/imaginary values; all other members
// keep their defaults.
140 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
141 Value *R, Value *I)
142 : Operation(Op), Real(R), Imag(I) {}
143
144private:
145 friend class ComplexDeinterleavingGraph;
// NodePtr is the shared-ownership handle; RawNodePtr is a plain pointer.
146 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148
149public:
// Real and imaginary halves of the complex value this node represents.
151 Value *Real;
152 Value *Imag;
153
154 // These two members are required exclusively for generating
155 // ComplexDeinterleavingOperation::Symmetric operations.
// Opcode has no initializer here; identifySymmetricOperation assigns it when
// it builds a Symmetric node.
156 unsigned Opcode;
157 std::optional<FastMathFlags> Flags;
158
// NOTE(review): the next line is the tail of a declaration whose first line
// (listing line 159) is missing; it is presumably the default value of the
// `Rotation` member.
160 ComplexDeinterleavingRotation::Rotation_0;
// Value emitted into the IR in place of this node; nullptr until set.
162 Value *ReplacementNode = nullptr;
163
// Records Node as an operand edge. Only the raw pointer is stored, so this
// node does not share ownership of its operands.
164 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165
// Debug printers: the no-argument overload writes to dbgs().
166 void dump() { dump(dbgs()); }
167 void dump(raw_ostream &OS) {
// Prints a quoted Value, or "nullptr" when V is null.
168 auto PrintValue = [&](Value *V) {
169 if (V) {
170 OS << "\"";
171 V->print(OS, true);
172 OS << "\"\n";
173 } else
174 OS << "nullptr\n";
175 };
// Prints the address of an operand node, or "nullptr".
176 auto PrintNodeRef = [&](RawNodePtr Ptr) {
177 if (Ptr)
178 OS << Ptr << "\n";
179 else
180 OS << "nullptr\n";
181 };
182
183 OS << "- CompositeNode: " << this << "\n";
184 OS << " Real: ";
185 PrintValue(Real);
186 OS << " Imag: ";
187 PrintValue(Imag);
188 OS << " ReplacementNode: ";
189 PrintValue(ReplacementNode);
// Operation is printed as its integer enum value; Rotation is printed as the
// enum value multiplied by 90.
190 OS << " Operation: " << (int)Operation << "\n";
191 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
192 OS << " Operands: \n";
193 for (const auto &Op : Operands) {
194 OS << " - ";
195 PrintNodeRef(Op);
196 }
197 }
198};
199
// Owns the CompositeNodes built while identifying complex patterns in one
// basic block, and declares the identification, validation and replacement
// steps that operate on them.
// NOTE(review): listing lines 238 and 262 are absent from this copy, so two
// member declarations are missing (flagged inline below).
200class ComplexDeinterleavingGraph {
201public:
// One multiplication term: two factors and a flag for the term's sign.
202 struct Product {
203 Value *Multiplier;
204 Value *Multiplicand;
205 bool IsPositive;
206 };
207
// An addend is a value paired with a bool (named IsPositive where addends are
// collected in identifyReassocNodes).
208 using Addend = std::pair<Value *, bool>;
209 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
210 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
211
212 // Helper struct for holding info about potential partial multiplication
213 // candidates
214 struct PartialMulCandidate {
215 Value *Common;
216 NodePtr Node;
217 unsigned RealIdx;
218 unsigned ImagIdx;
219 bool IsNodeInverted;
220 };
221
// Stores the target lowering and library info pointers; neither is owned.
222 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
223 const TargetLibraryInfo *TLI)
224 : TL(TL), TLI(TLI) {}
225
226private:
227 const TargetLowering *TL = nullptr;
228 const TargetLibraryInfo *TLI = nullptr;
// Every node submitted through submitCompositeNode.
229 SmallVector<NodePtr> CompositeNodes;
// Memoized identifyNode results keyed by (Real, Imag); identifyNode also
// stores nullptr here for pairs it could not recognise.
230 DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
231
232 SmallPtrSet<Instruction *, 16> FinalInstructions;
233
234 /// Root instructions are instructions from which complex computation starts
235 std::map<Instruction *, NodePtr> RootToNode;
236
237 /// Topologically sorted root instructions
// NOTE(review): the member this comment documents (listing line 238) is
// missing from this copy.
239
240 /// When examining a basic block for complex deinterleaving, if it is a simple
241 /// one-block loop, then the only incoming block is 'Incoming' and the
242 /// 'BackEdge' block is the block itself.
243 BasicBlock *BackEdge = nullptr;
244 BasicBlock *Incoming = nullptr;
245
246 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
247 /// %OutsideUser as it is shown in the IR:
248 ///
249 /// vector.body:
250 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
251 /// [ %ReductionOp, %vector.body ]
252 /// ...
253 /// %ReductionOp = fadd i64 ...
254 /// ...
255 /// br i1 %condition, label %vector.body, %middle.block
256 ///
257 /// middle.block:
258 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
259 ///
260 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
261 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
// NOTE(review): the ReductionInfo declaration itself (listing line 262) is
// missing from this copy.
263
264 /// In the process of detecting a reduction, we consider a pair of
265 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
266 /// traverse the use-tree to detect complex operations. As this is a reduction
267 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
268 /// to the %ReductionOPs that we suspect to be complex.
269 /// RealPHI and ImagPHI are used by the identifyPHINode method.
270 PHINode *RealPHI = nullptr;
271 PHINode *ImagPHI = nullptr;
272
273 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
274 /// detection.
275 bool PHIsFound = false;
276
277 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
278 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
279 /// This mapping is populated during
280 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
281 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
282 /// replacement process.
283 std::map<PHINode *, PHINode *> OldToNewPHI;
284
// Allocates a node for (R, I) without registering it in the graph. Asserts
// that ReductionPHI / ReductionOperation nodes have both R and I non-null.
285 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
286 Value *R, Value *I) {
287 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
288 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
289 (R && I)) &&
290 "Reduction related nodes must have Real and Imaginary parts");
291 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
292 I);
293 }
294
// Registers Node in CompositeNodes and, when both halves are set, caches it
// under its (Real, Imag) key. Returns Node unchanged.
295 NodePtr submitCompositeNode(NodePtr Node) {
296 CompositeNodes.push_back(Node);
297 if (Node->Real && Node->Imag)
298 CachedResult[{Node->Real, Node->Imag}] = Node;
299 return Node;
300 }
301
302 /// Identifies a complex partial multiply pattern and its rotation, based on
303 /// the following patterns
304 ///
305 /// 0: r: cr + ar * br
306 /// i: ci + ar * bi
307 /// 90: r: cr - ai * bi
308 /// i: ci + ai * br
309 /// 180: r: cr - ar * br
310 /// i: ci - ar * bi
311 /// 270: r: cr + ai * bi
312 /// i: ci - ai * br
313 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
314
315 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316 /// is partially known from identifyPartialMul, filling in the other half of
317 /// the complex pair.
318 NodePtr
319 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
320 std::pair<Value *, Value *> &CommonOperandI);
321
322 /// Identifies a complex add pattern and its rotation, based on the following
323 /// patterns.
324 ///
325 /// 90: r: ar - bi
326 /// i: ai + br
327 /// 270: r: ar + bi
328 /// i: ai - br
329 NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
330 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
331
332 NodePtr identifyNode(Value *R, Value *I);
333
334 /// Determine if a sum of complex numbers can be formed from \p RealAddends
335 /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
336 /// Return nullptr if it is not possible to construct a complex number.
337 /// \p Flags are needed to generate symmetric Add and Sub operations.
338 NodePtr identifyAdditions(std::list<Addend> &RealAddends,
339 std::list<Addend> &ImagAddends,
340 std::optional<FastMathFlags> Flags,
341 NodePtr Accumulator);
342
343 /// Extract one addend that has both real and imaginary parts positive.
344 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
345 std::list<Addend> &ImagAddends);
346
347 /// Determine if sum of multiplications of complex numbers can be formed from
348 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
349 /// to it. Return nullptr if it is not possible to construct a complex number.
350 NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
351 std::vector<Product> &ImagMuls,
352 NodePtr Accumulator);
353
354 /// Go through pairs of multiplication (one Real and one Imag) and find all
355 /// possible candidates for partial multiplication and put them into \p
356 /// Candidates. Returns true if every Product has a pair with a common operand.
357 bool collectPartialMuls(const std::vector<Product> &RealMuls,
358 const std::vector<Product> &ImagMuls,
359 std::vector<PartialMulCandidate> &Candidates);
360
361 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
362 /// the order of complex computation operations may be significantly altered,
363 /// and the real and imaginary parts may not be executed in parallel. This
364 /// function takes this into consideration and employs a more general approach
365 /// to identify complex computations. Initially, it gathers all the addends
366 /// and multiplicands and then constructs a complex expression from them.
367 NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
368
369 NodePtr identifyRoot(Instruction *I);
370
371 /// Identifies the Deinterleave operation applied to a vector containing
372 /// complex numbers. There are two ways to represent the Deinterleave
373 /// operation:
374 /// * Using two shufflevectors with even indices for \p Real instruction and
375 /// odd indices for \p Imag instructions (only for fixed-width vectors)
376 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
377 /// intrinsic (for both fixed and scalable vectors)
378 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
379
380 /// Identifies the operation that represents a complex number repeated in a
381 /// Splat vector. There are two possible types of splats: ConstantExpr with
382 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
383 /// initialization mask with all values set to zero.
384 NodePtr identifySplat(Value *Real, Value *Imag);
385
386 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
387
388 /// Identifies SelectInsts in a loop that has reduction with predication masks
389 /// and/or predicated tail folding
390 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
391
392 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
393
394 /// Complete IR modifications after producing new reduction operation:
395 /// * Populate the PHINode generated for
396 /// ComplexDeinterleavingOperation::ReductionPHI
397 /// * Deinterleave the final value outside of the loop and repurpose original
398 /// reduction users
399 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
400
401public:
// Debug printers: dump every node in CompositeNodes, in submission order.
402 void dump() { dump(dbgs()); }
403 void dump(raw_ostream &OS) {
404 for (const auto &Node : CompositeNodes)
405 Node->dump(OS);
406 }
407
408 /// Returns false if the deinterleaving operation should be cancelled for the
409 /// current graph.
410 bool identifyNodes(Instruction *RootI);
411
412 /// In case \p B is one-block loop, this function seeks potential reductions
413 /// and populates ReductionInfo. Returns true if any reductions were
414 /// identified.
415 bool collectPotentialReductions(BasicBlock *B);
416
417 void identifyReductionNodes();
418
419 /// Check that every instruction, from the roots to the leaves, has internal
420 /// uses.
421 bool checkNodes();
422
423 /// Perform the actual replacement of the underlying instruction graph.
424 void replaceNodes();
425};
426
427class ComplexDeinterleaving {
428public:
429 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430 : TL(tl), TLI(tli) {}
431 bool runOnFunction(Function &F);
432
433private:
434 bool evaluateBasicBlock(BasicBlock *B);
435
436 const TargetLowering *TL = nullptr;
437 const TargetLibraryInfo *TLI = nullptr;
438};
439
440} // namespace
441
442char ComplexDeinterleavingLegacyPass::ID = 0;
443
444INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445 "Complex Deinterleaving", false, false)
446INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
448
451 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454 return PreservedAnalyses::all();
455
458 return PA;
459}
460
462 return new ComplexDeinterleavingLegacyPass(TM);
463}
464
465bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469}
470
// Runs the transform over every basic block of F and reports whether any
// block changed. Two early exits return false before any block is visited.
// NOTE(review): listing lines 472-473 and 478-479 are absent from this copy,
// so the conditions guarding both early exits (and the opening of each debug
// statement) are not visible. Judging only by the messages, the first exit is
// a user-facing disable switch and the second a target capability check --
// confirm against upstream.
471bool ComplexDeinterleaving::runOnFunction(Function &F) {
474 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475 return false;
476 }
477
480 dbgs() << "Complex deinterleaving has been disabled, target does "
481 "not support lowering of complex number operations.\n");
482 return false;
483 }
484
// OR the per-block results together so a single modified block is enough to
// report a change.
485 bool Changed = false;
486 for (auto &B : F)
487 Changed |= evaluateBasicBlock(&B);
488
489 return Changed;
490}
491
493 // If the size is not even, it's not an interleaving mask
494 if ((Mask.size() & 1))
495 return false;
496
497 int HalfNumElements = Mask.size() / 2;
498 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499 int MaskIdx = Idx * 2;
500 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501 return false;
502 }
503
504 return true;
505}
506
508 int Offset = Mask[0];
509 int HalfNumElements = Mask.size() / 2;
510
511 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512 if (Mask[Idx] != (Idx * 2) + Offset)
513 return false;
514 }
515
516 return true;
517}
518
519bool isNeg(Value *V) {
520 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
521}
522
524 assert(isNeg(V));
525 auto *I = cast<Instruction>(V);
526 if (I->getOpcode() == Instruction::FNeg)
527 return I->getOperand(0);
528
529 return I->getOperand(1);
530}
531
532bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
533 ComplexDeinterleavingGraph Graph(TL, TLI);
534 if (Graph.collectPotentialReductions(B))
535 Graph.identifyReductionNodes();
536
537 for (auto &I : *B)
538 Graph.identifyNodes(&I);
539
540 if (Graph.checkNodes()) {
541 Graph.replaceNodes();
542 return true;
543 }
544
545 return false;
546}
547
548ComplexDeinterleavingGraph::NodePtr
549ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550 Instruction *Real, Instruction *Imag,
551 std::pair<Value *, Value *> &PartialMatch) {
552 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553 << "\n");
554
555 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
557 return nullptr;
558 }
559
560 if ((Real->getOpcode() != Instruction::FMul &&
561 Real->getOpcode() != Instruction::Mul) ||
562 (Imag->getOpcode() != Instruction::FMul &&
563 Imag->getOpcode() != Instruction::Mul)) {
565 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
566 return nullptr;
567 }
568
569 Value *R0 = Real->getOperand(0);
570 Value *R1 = Real->getOperand(1);
571 Value *I0 = Imag->getOperand(0);
572 Value *I1 = Imag->getOperand(1);
573
574 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575 // rotations and use the operand.
576 unsigned Negs = 0;
577 Value *Op;
578 if (match(R0, m_Neg(m_Value(Op)))) {
579 Negs |= 1;
580 R0 = Op;
581 } else if (match(R1, m_Neg(m_Value(Op)))) {
582 Negs |= 1;
583 R1 = Op;
584 }
585
586 if (isNeg(I0)) {
587 Negs |= 2;
588 Negs ^= 1;
589 I0 = Op;
590 } else if (match(I1, m_Neg(m_Value(Op)))) {
591 Negs |= 2;
592 Negs ^= 1;
593 I1 = Op;
594 }
595
597
598 Value *CommonOperand;
599 Value *UncommonRealOp;
600 Value *UncommonImagOp;
601
602 if (R0 == I0 || R0 == I1) {
603 CommonOperand = R0;
604 UncommonRealOp = R1;
605 } else if (R1 == I0 || R1 == I1) {
606 CommonOperand = R1;
607 UncommonRealOp = R0;
608 } else {
609 LLVM_DEBUG(dbgs() << " - No equal operand\n");
610 return nullptr;
611 }
612
613 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615 Rotation == ComplexDeinterleavingRotation::Rotation_270)
616 std::swap(UncommonRealOp, UncommonImagOp);
617
618 // Between identifyPartialMul and here we need to have found a complete valid
619 // pair from the CommonOperand of each part.
620 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621 Rotation == ComplexDeinterleavingRotation::Rotation_180)
622 PartialMatch.first = CommonOperand;
623 else
624 PartialMatch.second = CommonOperand;
625
626 if (!PartialMatch.first || !PartialMatch.second) {
627 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
628 return nullptr;
629 }
630
631 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632 if (!CommonNode) {
633 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
634 return nullptr;
635 }
636
637 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638 if (!UncommonNode) {
639 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
640 return nullptr;
641 }
642
643 NodePtr Node = prepareCompositeNode(
644 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645 Node->Rotation = Rotation;
646 Node->addOperand(CommonNode);
647 Node->addOperand(UncommonNode);
648 return submitCompositeNode(Node);
649}
650
651ComplexDeinterleavingGraph::NodePtr
652ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
653 Instruction *Imag) {
654 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
655 << "\n");
656 // Determine rotation
657 auto IsAdd = [](unsigned Op) {
658 return Op == Instruction::FAdd || Op == Instruction::Add;
659 };
660 auto IsSub = [](unsigned Op) {
661 return Op == Instruction::FSub || Op == Instruction::Sub;
662 };
664 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
665 Rotation = ComplexDeinterleavingRotation::Rotation_0;
666 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
667 Rotation = ComplexDeinterleavingRotation::Rotation_90;
668 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
669 Rotation = ComplexDeinterleavingRotation::Rotation_180;
670 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
671 Rotation = ComplexDeinterleavingRotation::Rotation_270;
672 else {
673 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
674 return nullptr;
675 }
676
677 if (isa<FPMathOperator>(Real) &&
678 (!Real->getFastMathFlags().allowContract() ||
679 !Imag->getFastMathFlags().allowContract())) {
680 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
681 return nullptr;
682 }
683
684 Value *CR = Real->getOperand(0);
685 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
686 if (!RealMulI)
687 return nullptr;
688 Value *CI = Imag->getOperand(0);
689 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
690 if (!ImagMulI)
691 return nullptr;
692
693 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
694 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
695 return nullptr;
696 }
697
698 Value *R0 = RealMulI->getOperand(0);
699 Value *R1 = RealMulI->getOperand(1);
700 Value *I0 = ImagMulI->getOperand(0);
701 Value *I1 = ImagMulI->getOperand(1);
702
703 Value *CommonOperand;
704 Value *UncommonRealOp;
705 Value *UncommonImagOp;
706
707 if (R0 == I0 || R0 == I1) {
708 CommonOperand = R0;
709 UncommonRealOp = R1;
710 } else if (R1 == I0 || R1 == I1) {
711 CommonOperand = R1;
712 UncommonRealOp = R0;
713 } else {
714 LLVM_DEBUG(dbgs() << " - No equal operand\n");
715 return nullptr;
716 }
717
718 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
719 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
720 Rotation == ComplexDeinterleavingRotation::Rotation_270)
721 std::swap(UncommonRealOp, UncommonImagOp);
722
723 std::pair<Value *, Value *> PartialMatch(
724 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
725 Rotation == ComplexDeinterleavingRotation::Rotation_180)
726 ? CommonOperand
727 : nullptr,
728 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
729 Rotation == ComplexDeinterleavingRotation::Rotation_270)
730 ? CommonOperand
731 : nullptr);
732
733 auto *CRInst = dyn_cast<Instruction>(CR);
734 auto *CIInst = dyn_cast<Instruction>(CI);
735
736 if (!CRInst || !CIInst) {
737 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
738 return nullptr;
739 }
740
741 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
742 if (!CNode) {
743 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
744 return nullptr;
745 }
746
747 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
748 if (!UncommonRes) {
749 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
750 return nullptr;
751 }
752
753 assert(PartialMatch.first && PartialMatch.second);
754 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
755 if (!CommonRes) {
756 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
757 return nullptr;
758 }
759
760 NodePtr Node = prepareCompositeNode(
761 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
762 Node->Rotation = Rotation;
763 Node->addOperand(CommonRes);
764 Node->addOperand(UncommonRes);
765 Node->addOperand(CNode);
766 return submitCompositeNode(Node);
767}
768
769ComplexDeinterleavingGraph::NodePtr
770ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772
773 // Determine rotation
775 if ((Real->getOpcode() == Instruction::FSub &&
776 Imag->getOpcode() == Instruction::FAdd) ||
777 (Real->getOpcode() == Instruction::Sub &&
778 Imag->getOpcode() == Instruction::Add))
779 Rotation = ComplexDeinterleavingRotation::Rotation_90;
780 else if ((Real->getOpcode() == Instruction::FAdd &&
781 Imag->getOpcode() == Instruction::FSub) ||
782 (Real->getOpcode() == Instruction::Add &&
783 Imag->getOpcode() == Instruction::Sub))
784 Rotation = ComplexDeinterleavingRotation::Rotation_270;
785 else {
786 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787 return nullptr;
788 }
789
790 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794
795 if (!AR || !AI || !BR || !BI) {
796 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797 return nullptr;
798 }
799
800 NodePtr ResA = identifyNode(AR, AI);
801 if (!ResA) {
802 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803 return nullptr;
804 }
805 NodePtr ResB = identifyNode(BR, BI);
806 if (!ResB) {
807 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808 return nullptr;
809 }
810
811 NodePtr Node =
812 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813 Node->Rotation = Rotation;
814 Node->addOperand(ResA);
815 Node->addOperand(ResB);
816 return submitCompositeNode(Node);
817}
818
820 unsigned OpcA = A->getOpcode();
821 unsigned OpcB = B->getOpcode();
822
823 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827}
828
// Body of the pair predicate that identifyNode calls as
// isInstructionPairMul(Real, Imag): returns true only when both A and B match
// the same Pattern.
// NOTE(review): the signature (listing line 829) and the Pattern initializer
// (listing line 831) are absent from this copy, so what Pattern matches cannot
// be stated from here -- confirm against upstream.
830 auto Pattern =
832
833 return match(A, Pattern) && match(B, Pattern);
834}
835
837 switch (I->getOpcode()) {
838 case Instruction::FAdd:
839 case Instruction::FSub:
840 case Instruction::FMul:
841 case Instruction::FNeg:
842 case Instruction::Add:
843 case Instruction::Sub:
844 case Instruction::Mul:
845 return true;
846 default:
847 return false;
848 }
849}
850
// Identifies an operation applied identically to the real and imaginary
// parts: both instructions must have the same opcode (and, for floating-point
// operations, the same fast-math flags), and their operand pairs must
// themselves identify as composite nodes. Returns the submitted Symmetric
// node, or nullptr when the pattern does not match.
851ComplexDeinterleavingGraph::NodePtr
852ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
853 Instruction *Imag) {
854 if (Real->getOpcode() != Imag->getOpcode())
855 return nullptr;
856
// NOTE(review): listing lines 857-858 are absent from this copy. The `return`
// below is the body of a guard whose condition is missing, so as shown it
// looks unconditional -- confirm the condition against upstream.
859 return nullptr;
860
861 auto *R0 = Real->getOperand(0);
862 auto *I0 = Imag->getOperand(0);
863
// Operand 0 is always paired up; operand 1 only for binary operations.
864 NodePtr Op0 = identifyNode(R0, I0);
865 NodePtr Op1 = nullptr;
866 if (Op0 == nullptr)
867 return nullptr;
868
869 if (Real->isBinaryOp()) {
870 auto *R1 = Real->getOperand(1);
871 auto *I1 = Imag->getOperand(1);
872 Op1 = identifyNode(R1, I1);
873 if (Op1 == nullptr)
874 return nullptr;
875 }
876
// Mismatched fast-math flags between the two halves reject the pair.
877 if (isa<FPMathOperator>(Real) &&
878 Real->getFastMathFlags() != Imag->getFastMathFlags())
879 return nullptr;
880
881 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
882 Real, Imag);
// Record the opcode (and flags, for floating-point operations) on the node;
// the node's Opcode/Flags members exist only for Symmetric operations.
883 Node->Opcode = Real->getOpcode();
884 if (isa<FPMathOperator>(Real))
885 Node->Flags = Real->getFastMathFlags();
886
887 Node->addOperand(Op0);
888 if (Real->isBinaryOp())
889 Node->addOperand(Op1);
890
891 return submitCompositeNode(Node);
892}
893
894ComplexDeinterleavingGraph::NodePtr
895ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
896 LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
897 assert(R->getType() == I->getType() &&
898 "Real and imaginary parts should not have different types");
899
900 auto It = CachedResult.find({R, I});
901 if (It != CachedResult.end()) {
902 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
903 return It->second;
904 }
905
906 if (NodePtr CN = identifySplat(R, I))
907 return CN;
908
909 auto *Real = dyn_cast<Instruction>(R);
910 auto *Imag = dyn_cast<Instruction>(I);
911 if (!Real || !Imag)
912 return nullptr;
913
914 if (NodePtr CN = identifyDeinterleave(Real, Imag))
915 return CN;
916
917 if (NodePtr CN = identifyPHINode(Real, Imag))
918 return CN;
919
920 if (NodePtr CN = identifySelectNode(Real, Imag))
921 return CN;
922
923 auto *VTy = cast<VectorType>(Real->getType());
924 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
925
926 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
927 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
928 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
929 ComplexDeinterleavingOperation::CAdd, NewVTy);
930
931 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
932 if (NodePtr CN = identifyPartialMul(Real, Imag))
933 return CN;
934 }
935
936 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
937 if (NodePtr CN = identifyAdd(Real, Imag))
938 return CN;
939 }
940
941 if (HasCMulSupport && HasCAddSupport) {
942 if (NodePtr CN = identifyReassocNodes(Real, Imag))
943 return CN;
944 }
945
946 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
947 return CN;
948
949 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
950 CachedResult[{R, I}] = nullptr;
951 return nullptr;
952}
953
954ComplexDeinterleavingGraph::NodePtr
955ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
956 Instruction *Imag) {
957 auto IsOperationSupported = [](unsigned Opcode) -> bool {
958 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
959 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
960 Opcode == Instruction::Sub;
961 };
962
963 if (!IsOperationSupported(Real->getOpcode()) ||
964 !IsOperationSupported(Imag->getOpcode()))
965 return nullptr;
966
967 std::optional<FastMathFlags> Flags;
968 if (isa<FPMathOperator>(Real)) {
969 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
970 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
971 "not identical\n");
972 return nullptr;
973 }
974
975 Flags = Real->getFastMathFlags();
976 if (!Flags->allowReassoc()) {
978 dbgs()
979 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
980 return nullptr;
981 }
982 }
983
984 // Collect multiplications and addend instructions from the given instruction
985 // while traversing it operands. Additionally, verify that all instructions
986 // have the same fast math flags.
987 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
988 std::list<Addend> &Addends) -> bool {
991 while (!Worklist.empty()) {
992 auto [V, IsPositive] = Worklist.back();
993 Worklist.pop_back();
994 if (!Visited.insert(V).second)
995 continue;
996
997 Instruction *I = dyn_cast<Instruction>(V);
998 if (!I) {
999 Addends.emplace_back(V, IsPositive);
1000 continue;
1001 }
1002
1003 // If an instruction has more than one user, it indicates that it either
1004 // has an external user, which will be later checked by the checkNodes
1005 // function, or it is a subexpression utilized by multiple expressions. In
1006 // the latter case, we will attempt to separately identify the complex
1007 // operation from here in order to create a shared
1008 // ComplexDeinterleavingCompositeNode.
1009 if (I != Insn && I->getNumUses() > 1) {
1010 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1011 Addends.emplace_back(I, IsPositive);
1012 continue;
1013 }
1014 switch (I->getOpcode()) {
1015 case Instruction::FAdd:
1016 case Instruction::Add:
1017 Worklist.emplace_back(I->getOperand(1), IsPositive);
1018 Worklist.emplace_back(I->getOperand(0), IsPositive);
1019 break;
1020 case Instruction::FSub:
1021 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1022 Worklist.emplace_back(I->getOperand(0), IsPositive);
1023 break;
1024 case Instruction::Sub:
1025 if (isNeg(I)) {
1026 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1027 } else {
1028 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1029 Worklist.emplace_back(I->getOperand(0), IsPositive);
1030 }
1031 break;
1032 case Instruction::FMul:
1033 case Instruction::Mul: {
1034 Value *A, *B;
1035 if (isNeg(I->getOperand(0))) {
1036 A = getNegOperand(I->getOperand(0));
1037 IsPositive = !IsPositive;
1038 } else {
1039 A = I->getOperand(0);
1040 }
1041
1042 if (isNeg(I->getOperand(1))) {
1043 B = getNegOperand(I->getOperand(1));
1044 IsPositive = !IsPositive;
1045 } else {
1046 B = I->getOperand(1);
1047 }
1048 Muls.push_back(Product{A, B, IsPositive});
1049 break;
1050 }
1051 case Instruction::FNeg:
1052 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1053 break;
1054 default:
1055 Addends.emplace_back(I, IsPositive);
1056 continue;
1057 }
1058
1059 if (Flags && I->getFastMathFlags() != *Flags) {
1060 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1061 "inconsistent with the root instructions' flags: "
1062 << *I << "\n");
1063 return false;
1064 }
1065 }
1066 return true;
1067 };
1068
1069 std::vector<Product> RealMuls, ImagMuls;
1070 std::list<Addend> RealAddends, ImagAddends;
1071 if (!Collect(Real, RealMuls, RealAddends) ||
1072 !Collect(Imag, ImagMuls, ImagAddends))
1073 return nullptr;
1074
1075 if (RealAddends.size() != ImagAddends.size())
1076 return nullptr;
1077
1078 NodePtr FinalNode;
1079 if (!RealMuls.empty() || !ImagMuls.empty()) {
1080 // If there are multiplicands, extract positive addend and use it as an
1081 // accumulator
1082 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1083 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1084 if (!FinalNode)
1085 return nullptr;
1086 }
1087
1088 // Identify and process remaining additions
1089 if (!RealAddends.empty() || !ImagAddends.empty()) {
1090 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1091 if (!FinalNode)
1092 return nullptr;
1093 }
1094 assert(FinalNode && "FinalNode can not be nullptr here");
1095 // Set the Real and Imag fields of the final node and submit it
1096 FinalNode->Real = Real;
1097 FinalNode->Imag = Imag;
1098 submitCompositeNode(FinalNode);
1099 return FinalNode;
1100}
1101
1102bool ComplexDeinterleavingGraph::collectPartialMuls(
1103 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1104 std::vector<PartialMulCandidate> &PartialMulCandidates) {
1105 // Helper function to extract a common operand from two products
1106 auto FindCommonInstruction = [](const Product &Real,
1107 const Product &Imag) -> Value * {
1108 if (Real.Multiplicand == Imag.Multiplicand ||
1109 Real.Multiplicand == Imag.Multiplier)
1110 return Real.Multiplicand;
1111
1112 if (Real.Multiplier == Imag.Multiplicand ||
1113 Real.Multiplier == Imag.Multiplier)
1114 return Real.Multiplier;
1115
1116 return nullptr;
1117 };
1118
1119 // Iterating over real and imaginary multiplications to find common operands
1120 // If a common operand is found, a partial multiplication candidate is created
1121 // and added to the candidates vector The function returns false if no common
1122 // operands are found for any product
1123 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1124 bool FoundCommon = false;
1125 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1126 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1127 if (!Common)
1128 continue;
1129
1130 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1131 : RealMuls[i].Multiplicand;
1132 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1133 : ImagMuls[j].Multiplicand;
1134
1135 auto Node = identifyNode(A, B);
1136 if (Node) {
1137 FoundCommon = true;
1138 PartialMulCandidates.push_back({Common, Node, i, j, false});
1139 }
1140
1141 Node = identifyNode(B, A);
1142 if (Node) {
1143 FoundCommon = true;
1144 PartialMulCandidates.push_back({Common, Node, i, j, true});
1145 }
1146 }
1147 if (!FoundCommon)
1148 return false;
1149 }
1150 return true;
1151}
1152
1153ComplexDeinterleavingGraph::NodePtr
1154ComplexDeinterleavingGraph::identifyMultiplications(
1155 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1156 NodePtr Accumulator = nullptr) {
1157 if (RealMuls.size() != ImagMuls.size())
1158 return nullptr;
1159
1160 std::vector<PartialMulCandidate> Info;
1161 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1162 return nullptr;
1163
1164 // Map to store common instruction to node pointers
1165 std::map<Value *, NodePtr> CommonToNode;
1166 std::vector<bool> Processed(Info.size(), false);
1167 for (unsigned I = 0; I < Info.size(); ++I) {
1168 if (Processed[I])
1169 continue;
1170
1171 PartialMulCandidate &InfoA = Info[I];
1172 for (unsigned J = I + 1; J < Info.size(); ++J) {
1173 if (Processed[J])
1174 continue;
1175
1176 PartialMulCandidate &InfoB = Info[J];
1177 auto *InfoReal = &InfoA;
1178 auto *InfoImag = &InfoB;
1179
1180 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1181 if (!NodeFromCommon) {
1182 std::swap(InfoReal, InfoImag);
1183 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1184 }
1185 if (!NodeFromCommon)
1186 continue;
1187
1188 CommonToNode[InfoReal->Common] = NodeFromCommon;
1189 CommonToNode[InfoImag->Common] = NodeFromCommon;
1190 Processed[I] = true;
1191 Processed[J] = true;
1192 }
1193 }
1194
1195 std::vector<bool> ProcessedReal(RealMuls.size(), false);
1196 std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1197 NodePtr Result = Accumulator;
1198 for (auto &PMI : Info) {
1199 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1200 continue;
1201
1202 auto It = CommonToNode.find(PMI.Common);
1203 // TODO: Process independent complex multiplications. Cases like this:
1204 // A.real() * B where both A and B are complex numbers.
1205 if (It == CommonToNode.end()) {
1206 LLVM_DEBUG({
1207 dbgs() << "Unprocessed independent partial multiplication:\n";
1208 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1209 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1210 << " multiplied by " << *Mul->Multiplicand << "\n";
1211 });
1212 return nullptr;
1213 }
1214
1215 auto &RealMul = RealMuls[PMI.RealIdx];
1216 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1217
1218 auto NodeA = It->second;
1219 auto NodeB = PMI.Node;
1220 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1221 // The following table illustrates the relationship between multiplications
1222 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1223 // can see:
1224 //
1225 // Rotation | Real | Imag |
1226 // ---------+--------+--------+
1227 // 0 | x * u | x * v |
1228 // 90 | -y * v | y * u |
1229 // 180 | -x * u | -x * v |
1230 // 270 | y * v | -y * u |
1231 //
1232 // Check if the candidate can indeed be represented by partial
1233 // multiplication
1234 // TODO: Add support for multiplication by complex one
1235 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1236 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1237 continue;
1238
1239 // Determine the rotation based on the multiplications
1241 if (IsMultiplicandReal) {
1242 // Detect 0 and 180 degrees rotation
1243 if (RealMul.IsPositive && ImagMul.IsPositive)
1245 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1247 else
1248 continue;
1249
1250 } else {
1251 // Detect 90 and 270 degrees rotation
1252 if (!RealMul.IsPositive && ImagMul.IsPositive)
1254 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1256 else
1257 continue;
1258 }
1259
1260 LLVM_DEBUG({
1261 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1262 dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1263 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1264 dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1265 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1266 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1267 });
1268
1269 NodePtr NodeMul = prepareCompositeNode(
1270 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1271 NodeMul->Rotation = Rotation;
1272 NodeMul->addOperand(NodeA);
1273 NodeMul->addOperand(NodeB);
1274 if (Result)
1275 NodeMul->addOperand(Result);
1276 submitCompositeNode(NodeMul);
1277 Result = NodeMul;
1278 ProcessedReal[PMI.RealIdx] = true;
1279 ProcessedImag[PMI.ImagIdx] = true;
1280 }
1281
1282 // Ensure all products have been processed, if not return nullptr.
1283 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1284 !all_of(ProcessedImag, [](bool V) { return V; })) {
1285
1286 // Dump debug information about which partial multiplications are not
1287 // processed.
1288 LLVM_DEBUG({
1289 dbgs() << "Unprocessed products (Real):\n";
1290 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1291 if (!ProcessedReal[i])
1292 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1293 << *RealMuls[i].Multiplier << " multiplied by "
1294 << *RealMuls[i].Multiplicand << "\n";
1295 }
1296 dbgs() << "Unprocessed products (Imag):\n";
1297 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1298 if (!ProcessedImag[i])
1299 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1300 << *ImagMuls[i].Multiplier << " multiplied by "
1301 << *ImagMuls[i].Multiplicand << "\n";
1302 }
1303 });
1304 return nullptr;
1305 }
1306
1307 return Result;
1308}
1309
1310ComplexDeinterleavingGraph::NodePtr
1311ComplexDeinterleavingGraph::identifyAdditions(
1312 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1313 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1314 if (RealAddends.size() != ImagAddends.size())
1315 return nullptr;
1316
1317 NodePtr Result;
1318 // If we have accumulator use it as first addend
1319 if (Accumulator)
1321 // Otherwise find an element with both positive real and imaginary parts.
1322 else
1323 Result = extractPositiveAddend(RealAddends, ImagAddends);
1324
1325 if (!Result)
1326 return nullptr;
1327
1328 while (!RealAddends.empty()) {
1329 auto ItR = RealAddends.begin();
1330 auto [R, IsPositiveR] = *ItR;
1331
1332 bool FoundImag = false;
1333 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1334 auto [I, IsPositiveI] = *ItI;
1336 if (IsPositiveR && IsPositiveI)
1337 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1338 else if (!IsPositiveR && IsPositiveI)
1339 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1340 else if (!IsPositiveR && !IsPositiveI)
1341 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1342 else
1343 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1344
1345 NodePtr AddNode;
1346 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1347 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1348 AddNode = identifyNode(R, I);
1349 } else {
1350 AddNode = identifyNode(I, R);
1351 }
1352 if (AddNode) {
1353 LLVM_DEBUG({
1354 dbgs() << "Identified addition:\n";
1355 dbgs().indent(4) << "X: " << *R << "\n";
1356 dbgs().indent(4) << "Y: " << *I << "\n";
1357 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1358 });
1359
1360 NodePtr TmpNode;
1362 TmpNode = prepareCompositeNode(
1363 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1364 if (Flags) {
1365 TmpNode->Opcode = Instruction::FAdd;
1366 TmpNode->Flags = *Flags;
1367 } else {
1368 TmpNode->Opcode = Instruction::Add;
1369 }
1370 } else if (Rotation ==
1372 TmpNode = prepareCompositeNode(
1373 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1374 if (Flags) {
1375 TmpNode->Opcode = Instruction::FSub;
1376 TmpNode->Flags = *Flags;
1377 } else {
1378 TmpNode->Opcode = Instruction::Sub;
1379 }
1380 } else {
1381 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1382 nullptr, nullptr);
1383 TmpNode->Rotation = Rotation;
1384 }
1385
1386 TmpNode->addOperand(Result);
1387 TmpNode->addOperand(AddNode);
1388 submitCompositeNode(TmpNode);
1389 Result = TmpNode;
1390 RealAddends.erase(ItR);
1391 ImagAddends.erase(ItI);
1392 FoundImag = true;
1393 break;
1394 }
1395 }
1396 if (!FoundImag)
1397 return nullptr;
1398 }
1399 return Result;
1400}
1401
1402ComplexDeinterleavingGraph::NodePtr
1403ComplexDeinterleavingGraph::extractPositiveAddend(
1404 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1405 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1406 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1407 auto [R, IsPositiveR] = *ItR;
1408 auto [I, IsPositiveI] = *ItI;
1409 if (IsPositiveR && IsPositiveI) {
1410 auto Result = identifyNode(R, I);
1411 if (Result) {
1412 RealAddends.erase(ItR);
1413 ImagAddends.erase(ItI);
1414 return Result;
1415 }
1416 }
1417 }
1418 }
1419 return nullptr;
1420}
1421
1422bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1423 // This potential root instruction might already have been recognized as
1424 // reduction. Because RootToNode maps both Real and Imaginary parts to
1425 // CompositeNode we should choose only one either Real or Imag instruction to
1426 // use as an anchor for generating complex instruction.
1427 auto It = RootToNode.find(RootI);
1428 if (It != RootToNode.end()) {
1429 auto RootNode = It->second;
1430 assert(RootNode->Operation ==
1431 ComplexDeinterleavingOperation::ReductionOperation);
1432 // Find out which part, Real or Imag, comes later, and only if we come to
1433 // the latest part, add it to OrderedRoots.
1434 auto *R = cast<Instruction>(RootNode->Real);
1435 auto *I = cast<Instruction>(RootNode->Imag);
1436 auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1437 if (ReplacementAnchor != RootI)
1438 return false;
1439 OrderedRoots.push_back(RootI);
1440 return true;
1441 }
1442
1443 auto RootNode = identifyRoot(RootI);
1444 if (!RootNode)
1445 return false;
1446
1447 LLVM_DEBUG({
1448 Function *F = RootI->getFunction();
1449 BasicBlock *B = RootI->getParent();
1450 dbgs() << "Complex deinterleaving graph for " << F->getName()
1451 << "::" << B->getName() << ".\n";
1452 dump(dbgs());
1453 dbgs() << "\n";
1454 });
1455 RootToNode[RootI] = RootNode;
1456 OrderedRoots.push_back(RootI);
1457 return true;
1458}
1459
1460bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1461 bool FoundPotentialReduction = false;
1462
1463 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1464 if (!Br || Br->getNumSuccessors() != 2)
1465 return false;
1466
1467 // Identify simple one-block loop
1468 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1469 return false;
1470
1472 for (auto &PHI : B->phis()) {
1473 if (PHI.getNumIncomingValues() != 2)
1474 continue;
1475
1476 if (!PHI.getType()->isVectorTy())
1477 continue;
1478
1479 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1480 if (!ReductionOp)
1481 continue;
1482
1483 // Check if final instruction is reduced outside of current block
1484 Instruction *FinalReduction = nullptr;
1485 auto NumUsers = 0u;
1486 for (auto *U : ReductionOp->users()) {
1487 ++NumUsers;
1488 if (U == &PHI)
1489 continue;
1490 FinalReduction = dyn_cast<Instruction>(U);
1491 }
1492
1493 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1494 isa<PHINode>(FinalReduction))
1495 continue;
1496
1497 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1498 BackEdge = B;
1499 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1500 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1501 Incoming = PHI.getIncomingBlock(IncomingIdx);
1502 FoundPotentialReduction = true;
1503
1504 // If the initial value of PHINode is an Instruction, consider it a leaf
1505 // value of a complex deinterleaving graph.
1506 if (auto *InitPHI =
1507 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1508 FinalInstructions.insert(InitPHI);
1509 }
1510 return FoundPotentialReduction;
1511}
1512
1513void ComplexDeinterleavingGraph::identifyReductionNodes() {
1514 SmallVector<bool> Processed(ReductionInfo.size(), false);
1515 SmallVector<Instruction *> OperationInstruction;
1516 for (auto &P : ReductionInfo)
1517 OperationInstruction.push_back(P.first);
1518
1519 // Identify a complex computation by evaluating two reduction operations that
1520 // potentially could be involved
1521 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1522 if (Processed[i])
1523 continue;
1524 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1525 if (Processed[j])
1526 continue;
1527
1528 auto *Real = OperationInstruction[i];
1529 auto *Imag = OperationInstruction[j];
1530 if (Real->getType() != Imag->getType())
1531 continue;
1532
1533 RealPHI = ReductionInfo[Real].first;
1534 ImagPHI = ReductionInfo[Imag].first;
1535 PHIsFound = false;
1536 auto Node = identifyNode(Real, Imag);
1537 if (!Node) {
1538 std::swap(Real, Imag);
1539 std::swap(RealPHI, ImagPHI);
1540 Node = identifyNode(Real, Imag);
1541 }
1542
1543 // If a node is identified and reduction PHINode is used in the chain of
1544 // operations, mark its operation instructions as used to prevent
1545 // re-identification and attach the node to the real part
1546 if (Node && PHIsFound) {
1547 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1548 << *Real << " / " << *Imag << "\n");
1549 Processed[i] = true;
1550 Processed[j] = true;
1551 auto RootNode = prepareCompositeNode(
1552 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1553 RootNode->addOperand(Node);
1554 RootToNode[Real] = RootNode;
1555 RootToNode[Imag] = RootNode;
1556 submitCompositeNode(RootNode);
1557 break;
1558 }
1559 }
1560 }
1561
1562 RealPHI = nullptr;
1563 ImagPHI = nullptr;
1564}
1565
1566bool ComplexDeinterleavingGraph::checkNodes() {
1567 // Collect all instructions from roots to leaves
1568 SmallPtrSet<Instruction *, 16> AllInstructions;
1570 for (auto &Pair : RootToNode)
1571 Worklist.push_back(Pair.first);
1572
1573 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1574 // chains
1575 while (!Worklist.empty()) {
1576 auto *I = Worklist.back();
1577 Worklist.pop_back();
1578
1579 if (!AllInstructions.insert(I).second)
1580 continue;
1581
1582 for (Value *Op : I->operands()) {
1583 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1584 if (!FinalInstructions.count(I))
1585 Worklist.emplace_back(OpI);
1586 }
1587 }
1588 }
1589
1590 // Find instructions that have users outside of chain
1591 SmallVector<Instruction *, 2> OuterInstructions;
1592 for (auto *I : AllInstructions) {
1593 // Skip root nodes
1594 if (RootToNode.count(I))
1595 continue;
1596
1597 for (User *U : I->users()) {
1598 if (AllInstructions.count(cast<Instruction>(U)))
1599 continue;
1600
1601 // Found an instruction that is not used by XCMLA/XCADD chain
1602 Worklist.emplace_back(I);
1603 break;
1604 }
1605 }
1606
1607 // If any instructions are found to be used outside, find and remove roots
1608 // that somehow connect to those instructions.
1610 while (!Worklist.empty()) {
1611 auto *I = Worklist.back();
1612 Worklist.pop_back();
1613 if (!Visited.insert(I).second)
1614 continue;
1615
1616 // Found an impacted root node. Removing it from the nodes to be
1617 // deinterleaved
1618 if (RootToNode.count(I)) {
1619 LLVM_DEBUG(dbgs() << "Instruction " << *I
1620 << " could be deinterleaved but its chain of complex "
1621 "operations have an outside user\n");
1622 RootToNode.erase(I);
1623 }
1624
1625 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1626 continue;
1627
1628 for (User *U : I->users())
1629 Worklist.emplace_back(cast<Instruction>(U));
1630
1631 for (Value *Op : I->operands()) {
1632 if (auto *OpI = dyn_cast<Instruction>(Op))
1633 Worklist.emplace_back(OpI);
1634 }
1635 }
1636 return !RootToNode.empty();
1637}
1638
1639ComplexDeinterleavingGraph::NodePtr
1640ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1641 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642 if (Intrinsic->getIntrinsicID() !=
1643 Intrinsic::experimental_vector_interleave2)
1644 return nullptr;
1645
1646 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1647 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1648 if (!Real || !Imag)
1649 return nullptr;
1650
1651 return identifyNode(Real, Imag);
1652 }
1653
1654 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1655 if (!SVI)
1656 return nullptr;
1657
1658 // Look for a shufflevector that takes separate vectors of the real and
1659 // imaginary components and recombines them into a single vector.
1660 if (!isInterleavingMask(SVI->getShuffleMask()))
1661 return nullptr;
1662
1663 Instruction *Real;
1664 Instruction *Imag;
1665 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1666 return nullptr;
1667
1668 return identifyNode(Real, Imag);
1669}
1670
1671ComplexDeinterleavingGraph::NodePtr
1672ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1673 Instruction *Imag) {
1674 Instruction *I = nullptr;
1675 Value *FinalValue = nullptr;
1676 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1677 match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1678 match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1679 m_Value(FinalValue)))) {
1680 NodePtr PlaceholderNode = prepareCompositeNode(
1682 PlaceholderNode->ReplacementNode = FinalValue;
1683 FinalInstructions.insert(Real);
1684 FinalInstructions.insert(Imag);
1685 return submitCompositeNode(PlaceholderNode);
1686 }
1687
1688 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1689 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1690 if (!RealShuffle || !ImagShuffle) {
1691 if (RealShuffle || ImagShuffle)
1692 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1693 return nullptr;
1694 }
1695
1696 Value *RealOp1 = RealShuffle->getOperand(1);
1697 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1698 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1699 return nullptr;
1700 }
1701 Value *ImagOp1 = ImagShuffle->getOperand(1);
1702 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1703 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1704 return nullptr;
1705 }
1706
1707 Value *RealOp0 = RealShuffle->getOperand(0);
1708 Value *ImagOp0 = ImagShuffle->getOperand(0);
1709
1710 if (RealOp0 != ImagOp0) {
1711 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1712 return nullptr;
1713 }
1714
1715 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1716 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1717 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1718 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1719 return nullptr;
1720 }
1721
1722 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1723 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1724 return nullptr;
1725 }
1726
1727 // Type checking, the shuffle type should be a vector type of the same
1728 // scalar type, but half the size
1729 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1730 Value *Op = Shuffle->getOperand(0);
1731 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1732 auto *OpTy = cast<FixedVectorType>(Op->getType());
1733
1734 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1735 return false;
1736 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1737 return false;
1738
1739 return true;
1740 };
1741
1742 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1743 if (!CheckType(Shuffle))
1744 return false;
1745
1746 ArrayRef<int> Mask = Shuffle->getShuffleMask();
1747 int Last = *Mask.rbegin();
1748
1749 Value *Op = Shuffle->getOperand(0);
1750 auto *OpTy = cast<FixedVectorType>(Op->getType());
1751 int NumElements = OpTy->getNumElements();
1752
1753 // Ensure that the deinterleaving shuffle only pulls from the first
1754 // shuffle operand.
1755 return Last < NumElements;
1756 };
1757
1758 if (RealShuffle->getType() != ImagShuffle->getType()) {
1759 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1760 return nullptr;
1761 }
1762 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1763 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1764 return nullptr;
1765 }
1766 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1767 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1768 return nullptr;
1769 }
1770
1771 NodePtr PlaceholderNode =
1773 RealShuffle, ImagShuffle);
1774 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1775 FinalInstructions.insert(RealShuffle);
1776 FinalInstructions.insert(ImagShuffle);
1777 return submitCompositeNode(PlaceholderNode);
1778}
1779
1780ComplexDeinterleavingGraph::NodePtr
1781ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1782 auto IsSplat = [](Value *V) -> bool {
1783 // Fixed-width vector with constants
1784 if (isa<ConstantDataVector>(V))
1785 return true;
1786
1787 VectorType *VTy;
1789 // Splats are represented differently depending on whether the repeated
1790 // value is a constant or an Instruction
1791 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1792 if (Const->getOpcode() != Instruction::ShuffleVector)
1793 return false;
1794 VTy = cast<VectorType>(Const->getType());
1795 Mask = Const->getShuffleMask();
1796 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1797 VTy = Shuf->getType();
1798 Mask = Shuf->getShuffleMask();
1799 } else {
1800 return false;
1801 }
1802
1803 // When the data type is <1 x Type>, it's not possible to differentiate
1804 // between the ComplexDeinterleaving::Deinterleave and
1805 // ComplexDeinterleaving::Splat operations.
1806 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1807 return false;
1808
1809 return all_equal(Mask) && Mask[0] == 0;
1810 };
1811
1812 if (!IsSplat(R) || !IsSplat(I))
1813 return nullptr;
1814
1815 auto *Real = dyn_cast<Instruction>(R);
1816 auto *Imag = dyn_cast<Instruction>(I);
1817 if ((!Real && Imag) || (Real && !Imag))
1818 return nullptr;
1819
1820 if (Real && Imag) {
1821 // Non-constant splats should be in the same basic block
1822 if (Real->getParent() != Imag->getParent())
1823 return nullptr;
1824
1825 FinalInstructions.insert(Real);
1826 FinalInstructions.insert(Imag);
1827 }
1828 NodePtr PlaceholderNode =
1829 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1830 return submitCompositeNode(PlaceholderNode);
1831}
1832
1833ComplexDeinterleavingGraph::NodePtr
1834ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1835 Instruction *Imag) {
1836 if (Real != RealPHI || Imag != ImagPHI)
1837 return nullptr;
1838
1839 PHIsFound = true;
1840 NodePtr PlaceholderNode = prepareCompositeNode(
1841 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1842 return submitCompositeNode(PlaceholderNode);
1843}
1844
1845ComplexDeinterleavingGraph::NodePtr
1846ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1847 Instruction *Imag) {
1848 auto *SelectReal = dyn_cast<SelectInst>(Real);
1849 auto *SelectImag = dyn_cast<SelectInst>(Imag);
1850 if (!SelectReal || !SelectImag)
1851 return nullptr;
1852
1853 Instruction *MaskA, *MaskB;
1854 Instruction *AR, *AI, *RA, *BI;
1855 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1856 m_Instruction(RA))) ||
1857 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1858 m_Instruction(BI))))
1859 return nullptr;
1860
1861 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1862 return nullptr;
1863
1864 if (!MaskA->getType()->isVectorTy())
1865 return nullptr;
1866
1867 auto NodeA = identifyNode(AR, AI);
1868 if (!NodeA)
1869 return nullptr;
1870
1871 auto NodeB = identifyNode(RA, BI);
1872 if (!NodeB)
1873 return nullptr;
1874
1875 NodePtr PlaceholderNode = prepareCompositeNode(
1876 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1877 PlaceholderNode->addOperand(NodeA);
1878 PlaceholderNode->addOperand(NodeB);
1879 FinalInstructions.insert(MaskA);
1880 FinalInstructions.insert(MaskB);
1881 return submitCompositeNode(PlaceholderNode);
1882}
1883
1884static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1885 std::optional<FastMathFlags> Flags,
1886 Value *InputA, Value *InputB) {
1887 Value *I;
1888 switch (Opcode) {
1889 case Instruction::FNeg:
1890 I = B.CreateFNeg(InputA);
1891 break;
1892 case Instruction::FAdd:
1893 I = B.CreateFAdd(InputA, InputB);
1894 break;
1895 case Instruction::Add:
1896 I = B.CreateAdd(InputA, InputB);
1897 break;
1898 case Instruction::FSub:
1899 I = B.CreateFSub(InputA, InputB);
1900 break;
1901 case Instruction::Sub:
1902 I = B.CreateSub(InputA, InputB);
1903 break;
1904 case Instruction::FMul:
1905 I = B.CreateFMul(InputA, InputB);
1906 break;
1907 case Instruction::Mul:
1908 I = B.CreateMul(InputA, InputB);
1909 break;
1910 default:
1911 llvm_unreachable("Incorrect symmetric opcode");
1912 }
1913 if (Flags)
1914 cast<Instruction>(I)->setFastMathFlags(*Flags);
1915 return I;
1916}
1917
1918Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1919 RawNodePtr Node) {
1920 if (Node->ReplacementNode)
1921 return Node->ReplacementNode;
1922
1923 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1924 return Node->Operands.size() > Idx
1925 ? replaceNode(Builder, Node->Operands[Idx])
1926 : nullptr;
1927 };
1928
1929 Value *ReplacementNode;
1930 switch (Node->Operation) {
1931 case ComplexDeinterleavingOperation::CAdd:
1932 case ComplexDeinterleavingOperation::CMulPartial:
1933 case ComplexDeinterleavingOperation::Symmetric: {
1934 Value *Input0 = ReplaceOperandIfExist(Node, 0);
1935 Value *Input1 = ReplaceOperandIfExist(Node, 1);
1936 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1937 assert(!Input1 || (Input0->getType() == Input1->getType() &&
1938 "Node inputs need to be of the same type"));
1940 (Input0->getType() == Accumulator->getType() &&
1941 "Accumulator and input need to be of the same type"));
1942 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1943 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1944 Input0, Input1);
1945 else
1946 ReplacementNode = TL->createComplexDeinterleavingIR(
1947 Builder, Node->Operation, Node->Rotation, Input0, Input1,
1948 Accumulator);
1949 break;
1950 }
1951 case ComplexDeinterleavingOperation::Deinterleave:
1952 llvm_unreachable("Deinterleave node should already have ReplacementNode");
1953 break;
1954 case ComplexDeinterleavingOperation::Splat: {
1955 auto *NewTy = VectorType::getDoubleElementsVectorType(
1956 cast<VectorType>(Node->Real->getType()));
1957 auto *R = dyn_cast<Instruction>(Node->Real);
1958 auto *I = dyn_cast<Instruction>(Node->Imag);
1959 if (R && I) {
1960 // Splats that are not constant are interleaved where they are located
1961 Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1962 IRBuilder<> IRB(InsertPoint);
1963 ReplacementNode =
1964 IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
1965 {Node->Real, Node->Imag});
1966 } else {
1967 ReplacementNode =
1968 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1969 NewTy, {Node->Real, Node->Imag});
1970 }
1971 break;
1972 }
1973 case ComplexDeinterleavingOperation::ReductionPHI: {
1974 // If Operation is ReductionPHI, a new empty PHINode is created.
1975 // It is filled later when the ReductionOperation is processed.
1976 auto *VTy = cast<VectorType>(Node->Real->getType());
1977 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1978 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
1979 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1980 ReplacementNode = NewPHI;
1981 break;
1982 }
1983 case ComplexDeinterleavingOperation::ReductionOperation:
1984 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1985 processReductionOperation(ReplacementNode, Node);
1986 break;
1987 case ComplexDeinterleavingOperation::ReductionSelect: {
1988 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1989 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1990 auto *A = replaceNode(Builder, Node->Operands[0]);
1991 auto *B = replaceNode(Builder, Node->Operands[1]);
1992 auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1993 cast<VectorType>(MaskReal->getType()));
1994 auto *NewMask =
1995 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1996 NewMaskTy, {MaskReal, MaskImag});
1997 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1998 break;
1999 }
2000 }
2001
2002 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2003 NumComplexTransformations += 1;
2004 Node->ReplacementNode = ReplacementNode;
2005 return ReplacementNode;
2006}
2007
2008void ComplexDeinterleavingGraph::processReductionOperation(
2009 Value *OperationReplacement, RawNodePtr Node) {
2010 auto *Real = cast<Instruction>(Node->Real);
2011 auto *Imag = cast<Instruction>(Node->Imag);
2012 auto *OldPHIReal = ReductionInfo[Real].first;
2013 auto *OldPHIImag = ReductionInfo[Imag].first;
2014 auto *NewPHI = OldToNewPHI[OldPHIReal];
2015
2016 auto *VTy = cast<VectorType>(Real->getType());
2017 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2018
2019 // We have to interleave initial origin values coming from IncomingBlock
2020 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2021 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2022
2023 IRBuilder<> Builder(Incoming->getTerminator());
2024 auto *NewInit = Builder.CreateIntrinsic(
2025 Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
2026
2027 NewPHI->addIncoming(NewInit, Incoming);
2028 NewPHI->addIncoming(OperationReplacement, BackEdge);
2029
2030 // Deinterleave complex vector outside of loop so that it can be finally
2031 // reduced
2032 auto *FinalReductionReal = ReductionInfo[Real].second;
2033 auto *FinalReductionImag = ReductionInfo[Imag].second;
2034
2035 Builder.SetInsertPoint(
2036 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2037 auto *Deinterleave = Builder.CreateIntrinsic(
2038 Intrinsic::experimental_vector_deinterleave2,
2039 OperationReplacement->getType(), OperationReplacement);
2040
2041 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2042 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2043
2044 Builder.SetInsertPoint(FinalReductionImag);
2045 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2046 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2047}
2048
2049void ComplexDeinterleavingGraph::replaceNodes() {
2050 SmallVector<Instruction *, 16> DeadInstrRoots;
2051 for (auto *RootInstruction : OrderedRoots) {
2052 // Check if this potential root went through check process and we can
2053 // deinterleave it
2054 if (!RootToNode.count(RootInstruction))
2055 continue;
2056
2057 IRBuilder<> Builder(RootInstruction);
2058 auto RootNode = RootToNode[RootInstruction];
2059 Value *R = replaceNode(Builder, RootNode.get());
2060
2061 if (RootNode->Operation ==
2062 ComplexDeinterleavingOperation::ReductionOperation) {
2063 auto *RootReal = cast<Instruction>(RootNode->Real);
2064 auto *RootImag = cast<Instruction>(RootNode->Imag);
2065 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2066 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2067 DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2068 DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2069 } else {
2070 assert(R && "Unable to find replacement for RootInstruction");
2071 DeadInstrRoots.push_back(RootInstruction);
2072 RootInstruction->replaceAllUsesWith(R);
2073 }
2074 }
2075
2076 for (auto *I : DeadInstrRoots)
2078}
SmallVector< AArch64_IMM::ImmInsnModel, 4 > Insn
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
VarLocInsertPt getNextNode(const DbgRecord *DVR)
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
Complex Deinterleaving
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
static bool isInstructionPairMul(Instruction *A, Instruction *B)
#define DEBUG_TYPE
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
#define LLVM_DEBUG(X)
Definition: Debug.h:101
static bool runOnFunction(Function &F, bool PostInlining)
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
This file implements a map that provides insertion order iteration.
#define P(N)
PowerPC Reduce CR logical Operation
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
raw_pwrite_stream & OS
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file describes how to lower LLVM code to machine code.
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
BinaryOperator * Mul
DEMANGLE_DUMP_METHOD void dump() const
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:348
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:269
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
InstListType::const_iterator getFirstNonPHIIt() const
Iterator returning form of getFirstNonPHI.
Definition: BasicBlock.cpp:356
This class represents an Operation in the Expression.
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:155
iterator end()
Definition: DenseMap.h:84
bool allowContract() const
Definition: FMF.h:70
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
Definition: IRBuilder.h:94
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition: IRBuilder.h:2494
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:932
Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
Definition: IRBuilder.cpp:1110
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:180
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2644
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:658
bool isBinaryOp() const
Definition: Instruction.h:257
const BasicBlock * getParent() const
Definition: Instruction.h:152
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:84
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:252
bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:36
size_type size() const
Definition: MapVector.h:60
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr, BasicBlock::iterator InsertBefore)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:109
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:115
void preserve()
Mark an analysis as preserved.
Definition: Analysis.h:129
This instruction constructs a fixed permutation of two input vectors.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:360
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:342
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:427
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:76
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:265
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:121
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ BR
Control flow instructions. These all have token chains.
Definition: ISDOpcodes.h:1047
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:100
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
Definition: PatternMatch.h:765
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:821
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< cst_pred_ty< is_zero_int >, ValTy, Instruction::Sub > m_Neg(const ValTy &V)
Matches a 'Neg' as 'sub 0, V'.
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:450
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition: DWP.cpp:456
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1731
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:533
void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
DWARFExpression::Operation Op
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list is empty.
Definition: STLExtras.h:2048
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...