LLVM 22.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validation of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/Statistic.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/Intrinsics.h"
77#include <algorithm>
78
79using namespace llvm;
80using namespace PatternMatch;
81
82#define DEBUG_TYPE "complex-deinterleaving"
83
84STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
85
87 "enable-complex-deinterleaving",
88 cl::desc("Enable generation of complex instructions"), cl::init(true),
90
91/// Checks the given mask, and determines whether said mask is interleaving.
92///
93/// To be interleaving, a mask must alternate between `i` and `i + (Length /
94/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
95/// 4x vector interleaving mask would be <0, 2, 1, 3>).
96static bool isInterleavingMask(ArrayRef<int> Mask);
97
98/// Checks the given mask, and determines whether said mask is deinterleaving.
99///
100/// To be deinterleaving, a mask must increment in steps of 2, and either start
101/// with 0 or 1.
102/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
103/// <1, 3, 5, 7>).
104static bool isDeinterleavingMask(ArrayRef<int> Mask);
105
106/// Returns true if the operation is a negation of V, and it works for both
107/// integers and floats.
108static bool isNeg(Value *V);
109
110/// Returns the operand for negation operation.
111static Value *getNegOperand(Value *V);
112
113namespace {
114struct ComplexValue {
115 Value *Real = nullptr;
116 Value *Imag = nullptr;
117
118 bool operator==(const ComplexValue &Other) const {
119 return Real == Other.Real && Imag == Other.Imag;
120 }
121};
122hash_code hash_value(const ComplexValue &Arg) {
125}
126} // end namespace
128
129template <> struct llvm::DenseMapInfo<ComplexValue> {
130 static inline ComplexValue getEmptyKey() {
133 }
134 static inline ComplexValue getTombstoneKey() {
137 }
138 static unsigned getHashValue(const ComplexValue &Val) {
141 }
142 static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
143 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
144 }
145};
146
147namespace {
148template <typename T, typename IterT>
149std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
150 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
151 if (Common != A.end())
152 return std::make_optional(*Common);
153 return std::nullopt;
154}
155
156class ComplexDeinterleavingLegacyPass : public FunctionPass {
157public:
158 static char ID;
159
160 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
161 : FunctionPass(ID), TM(TM) {
164 }
165
166 StringRef getPassName() const override {
167 return "Complex Deinterleaving Pass";
168 }
169
170 bool runOnFunction(Function &F) override;
171 void getAnalysisUsage(AnalysisUsage &AU) const override {
172 AU.addRequired<TargetLibraryInfoWrapperPass>();
173 AU.setPreservesCFG();
174 }
175
176private:
177 const TargetMachine *TM;
178};
179
180class ComplexDeinterleavingGraph;
181struct ComplexDeinterleavingCompositeNode {
182
183 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
184 Value *R, Value *I)
185 : Operation(Op) {
186 Vals.push_back({R, I});
187 }
188
189 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
191 : Operation(Op), Vals(Other) {}
192
193private:
194 friend class ComplexDeinterleavingGraph;
195 using CompositeNode = ComplexDeinterleavingCompositeNode;
196 bool OperandsValid = true;
197
198public:
200 ComplexValues Vals;
201
202 // These two members are required exclusively for generating
203 // ComplexDeinterleavingOperation::Symmetric operations.
204 unsigned Opcode;
205 std::optional<FastMathFlags> Flags;
206
208 ComplexDeinterleavingRotation::Rotation_0;
210 Value *ReplacementNode = nullptr;
211
212 void addOperand(CompositeNode *Node) {
213 if (!Node)
214 OperandsValid = false;
215 Operands.push_back(Node);
216 }
217
218 void dump() { dump(dbgs()); }
219 void dump(raw_ostream &OS) {
220 auto PrintValue = [&](Value *V) {
221 if (V) {
222 OS << "\"";
223 V->print(OS, true);
224 OS << "\"\n";
225 } else
226 OS << "nullptr\n";
227 };
228 auto PrintNodeRef = [&](CompositeNode *Ptr) {
229 if (Ptr)
230 OS << Ptr << "\n";
231 else
232 OS << "nullptr\n";
233 };
234
235 OS << "- CompositeNode: " << this << "\n";
236 for (unsigned I = 0; I < Vals.size(); I++) {
237 OS << " Real(" << I << ") : ";
238 PrintValue(Vals[I].Real);
239 OS << " Imag(" << I << ") : ";
240 PrintValue(Vals[I].Imag);
241 }
242 OS << " ReplacementNode: ";
243 PrintValue(ReplacementNode);
244 OS << " Operation: " << (int)Operation << "\n";
245 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
246 OS << " Operands: \n";
247 for (const auto &Op : Operands) {
248 OS << " - ";
249 PrintNodeRef(Op);
250 }
251 }
252
253 bool areOperandsValid() { return OperandsValid; }
254};
255
256class ComplexDeinterleavingGraph {
257public:
258 struct Product {
259 Value *Multiplier;
260 Value *Multiplicand;
261 bool IsPositive;
262 };
263
264 using Addend = std::pair<Value *, bool>;
265 using AddendList = BumpPtrList<Addend>;
266 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
267
268 // Helper struct for holding info about potential partial multiplication
269 // candidates
270 struct PartialMulCandidate {
271 Value *Common;
272 CompositeNode *Node;
273 unsigned RealIdx;
274 unsigned ImagIdx;
275 bool IsNodeInverted;
276 };
277
278 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
279 const TargetLibraryInfo *TLI,
280 unsigned Factor)
281 : TL(TL), TLI(TLI), Factor(Factor) {}
282
283private:
284 const TargetLowering *TL = nullptr;
285 const TargetLibraryInfo *TLI = nullptr;
286 unsigned Factor;
287 SmallVector<CompositeNode *> CompositeNodes;
288 DenseMap<ComplexValues, CompositeNode *> CachedResult;
289 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
290
291 SmallPtrSet<Instruction *, 16> FinalInstructions;
292
293 /// Root instructions are instructions from which complex computation starts
294 DenseMap<Instruction *, CompositeNode *> RootToNode;
295
296 /// Topologically sorted root instructions
298
299 /// When examining a basic block for complex deinterleaving, if it is a simple
300 /// one-block loop, then the only incoming block is 'Incoming' and the
301 /// 'BackEdge' block is the block itself.
302 BasicBlock *BackEdge = nullptr;
303 BasicBlock *Incoming = nullptr;
304
305 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
306 /// %OutsideUser as it is shown in the IR:
307 ///
308 /// vector.body:
309 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
310 /// [ %ReductionOp, %vector.body ]
311 /// ...
312 /// %ReductionOp = fadd i64 ...
313 /// ...
314 /// br i1 %condition, label %vector.body, %middle.block
315 ///
316 /// middle.block:
317 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
318 ///
319 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
320 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
321 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
322
323 /// In the process of detecting a reduction, we consider a pair of
324 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
325 /// traverse the use-tree to detect complex operations. As this is a reduction
326 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
327 /// to the %ReductionOPs that we suspect to be complex.
328 /// RealPHI and ImagPHI are used by the identifyPHINode method.
329 PHINode *RealPHI = nullptr;
330 PHINode *ImagPHI = nullptr;
331
332 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
333 /// detection.
334 bool PHIsFound = false;
335
336 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
337 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
338 /// This mapping is populated during
339 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
340 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
341 /// replacement process.
342 DenseMap<PHINode *, PHINode *> OldToNewPHI;
343
344 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
345 Value *R, Value *I) {
346 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
347 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
348 (R && I)) &&
349 "Reduction related nodes must have Real and Imaginary parts");
350 return new (Allocator.Allocate())
351 ComplexDeinterleavingCompositeNode(Operation, R, I);
352 }
353
354 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
355 ComplexValues &Vals) {
356#ifndef NDEBUG
357 for (auto &V : Vals) {
358 assert(
359 ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
360 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
361 (V.Real && V.Imag)) &&
362 "Reduction related nodes must have Real and Imaginary parts");
363 }
364#endif
365 return new (Allocator.Allocate())
366 ComplexDeinterleavingCompositeNode(Operation, Vals);
367 }
368
369 CompositeNode *submitCompositeNode(CompositeNode *Node) {
370 CompositeNodes.push_back(Node);
371 if (Node->Vals[0].Real)
372 CachedResult[Node->Vals] = Node;
373 return Node;
374 }
375
376 /// Identifies a complex partial multiply pattern and its rotation, based on
377 /// the following patterns
378 ///
379 /// 0: r: cr + ar * br
380 /// i: ci + ar * bi
381 /// 90: r: cr - ai * bi
382 /// i: ci + ai * br
383 /// 180: r: cr - ar * br
384 /// i: ci - ar * bi
385 /// 270: r: cr + ai * bi
386 /// i: ci - ai * br
387 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
388
389 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
390 /// is partially known from identifyPartialMul, filling in the other half of
391 /// the complex pair.
392 CompositeNode *
393 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
394 std::pair<Value *, Value *> &CommonOperandI);
395
396 /// Identifies a complex add pattern and its rotation, based on the following
397 /// patterns.
398 ///
399 /// 90: r: ar - bi
400 /// i: ai + br
401 /// 270: r: ar + bi
402 /// i: ai - br
403 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
404 CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
405 CompositeNode *identifyPartialReduction(Value *R, Value *I);
406 CompositeNode *identifyDotProduct(Value *Inst);
407
408 CompositeNode *identifyNode(ComplexValues &Vals);
409
410 CompositeNode *identifyNode(Value *R, Value *I) {
411 ComplexValues Vals;
412 Vals.push_back({R, I});
413 return identifyNode(Vals);
414 }
415
416 /// Determine if a sum of complex numbers can be formed from \p RealAddends
417 /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
418 /// Return nullptr if it is not possible to construct a complex number.
419 /// \p Flags are needed to generate symmetric Add and Sub operations.
420 CompositeNode *identifyAdditions(AddendList &RealAddends,
421 AddendList &ImagAddends,
422 std::optional<FastMathFlags> Flags,
423 CompositeNode *Accumulator);
424
425 /// Extract one addend that has both real and imaginary parts positive.
426 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
427 AddendList &ImagAddends);
428
429 /// Determine if sum of multiplications of complex numbers can be formed from
430 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
431 /// to it. Return nullptr if it is not possible to construct a complex number.
432 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
433 SmallVectorImpl<Product> &ImagMuls,
434 CompositeNode *Accumulator);
435
436 /// Go through pairs of multiplication (one Real and one Imag) and find all
437 /// possible candidates for partial multiplication and put them into \p
438 /// Candidates. Returns true if every Product has a pair with a common operand.
439 bool collectPartialMuls(ArrayRef<Product> RealMuls,
440 ArrayRef<Product> ImagMuls,
441 SmallVectorImpl<PartialMulCandidate> &Candidates);
442
443 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
444 /// the order of complex computation operations may be significantly altered,
445 /// and the real and imaginary parts may not be executed in parallel. This
446 /// function takes this into consideration and employs a more general approach
447 /// to identify complex computations. Initially, it gathers all the addends
448 /// and multiplicands and then constructs a complex expression from them.
449 CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);
450
451 CompositeNode *identifyRoot(Instruction *I);
452
453 /// Identifies the Deinterleave operation applied to a vector containing
454 /// complex numbers. There are two ways to represent the Deinterleave
455 /// operation:
456 /// * Using two shufflevectors with even indices for the Real instruction and
457 /// odd indices for the Imag instruction (only for fixed-width vectors)
458 /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
459 /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
460 /// 2.
461 CompositeNode *identifyDeinterleave(ComplexValues &Vals);
462
463 /// Identifies the operation that represents a complex number repeated in a
464 /// Splat vector. There are two possible types of splats: ConstantExpr with
465 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
466 /// initialization mask with all values set to zero.
467 CompositeNode *identifySplat(ComplexValues &Vals);
468
469 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
470
471 /// Identifies SelectInsts in a loop that has reduction with predication masks
472 /// and/or predicated tail folding
473 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
474
475 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
476
477 /// Complete IR modifications after producing new reduction operation:
478 /// * Populate the PHINode generated for
479 /// ComplexDeinterleavingOperation::ReductionPHI
480 /// * Deinterleave the final value outside of the loop and repurpose original
481 /// reduction users
482 void processReductionOperation(Value *OperationReplacement,
483 CompositeNode *Node);
484 void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);
485
486public:
487 void dump() { dump(dbgs()); }
488 void dump(raw_ostream &OS) {
489 for (const auto &Node : CompositeNodes)
490 Node->dump(OS);
491 }
492
493 /// Returns false if the deinterleaving operation should be cancelled for the
494 /// current graph.
495 bool identifyNodes(Instruction *RootI);
496
497 /// In case \p B is a one-block loop, this function seeks potential reductions
498 /// and populates ReductionInfo. Returns true if any reductions were
499 /// identified.
500 bool collectPotentialReductions(BasicBlock *B);
501
502 void identifyReductionNodes();
503
504 /// Check that every instruction, from the roots to the leaves, has internal
505 /// uses.
506 bool checkNodes();
507
508 /// Perform the actual replacement of the underlying instruction graph.
509 void replaceNodes();
510};
511
512class ComplexDeinterleaving {
513public:
514 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
515 : TL(tl), TLI(tli) {}
516 bool runOnFunction(Function &F);
517
518private:
519 bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
520
521 const TargetLowering *TL = nullptr;
522 const TargetLibraryInfo *TLI = nullptr;
523};
524
525} // namespace
526
// Out-of-class definition of the static ID member declared in
// ComplexDeinterleavingLegacyPass; the constructor passes it to FunctionPass.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Registers the legacy pass under the DEBUG_TYPE name
// ("complex-deinterleaving").
// NOTE(review): getAnalysisUsage() requires TargetLibraryInfoWrapperPass, but
// no dependency is declared between BEGIN and END -- confirm whether an
// INITIALIZE_PASS_DEPENDENCY entry is needed here.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)
533
536 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
537 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
538 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
539 return PreservedAnalyses::all();
540
543 return PA;
544}
545
547 return new ComplexDeinterleavingLegacyPass(TM);
548}
549
550bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
551 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
552 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
553 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
554}
555
556bool ComplexDeinterleaving::runOnFunction(Function &F) {
559 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
560 return false;
561 }
562
565 dbgs() << "Complex deinterleaving has been disabled, target does "
566 "not support lowering of complex number operations.\n");
567 return false;
568 }
569
570 bool Changed = false;
571 for (auto &B : F)
572 Changed |= evaluateBasicBlock(&B, 2);
573
574 // TODO: Permit changes for both interleave factors in the same function.
575 if (!Changed) {
576 for (auto &B : F)
577 Changed |= evaluateBasicBlock(&B, 4);
578 }
579
580 // TODO: We can also support interleave factors of 6 and 8 if needed.
581
582 return Changed;
583}
584
586 // If the size is not even, it's not an interleaving mask
587 if ((Mask.size() & 1))
588 return false;
589
590 int HalfNumElements = Mask.size() / 2;
591 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
592 int MaskIdx = Idx * 2;
593 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
594 return false;
595 }
596
597 return true;
598}
599
601 int Offset = Mask[0];
602 int HalfNumElements = Mask.size() / 2;
603
604 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
605 if (Mask[Idx] != (Idx * 2) + Offset)
606 return false;
607 }
608
609 return true;
610}
611
612bool isNeg(Value *V) {
613 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
614}
615
617 assert(isNeg(V));
618 auto *I = cast<Instruction>(V);
619 if (I->getOpcode() == Instruction::FNeg)
620 return I->getOperand(0);
621
622 return I->getOperand(1);
623}
624
625bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
626 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
627 if (Graph.collectPotentialReductions(B))
628 Graph.identifyReductionNodes();
629
630 for (auto &I : *B)
631 Graph.identifyNodes(&I);
632
633 if (Graph.checkNodes()) {
634 Graph.replaceNodes();
635 return true;
636 }
637
638 return false;
639}
640
641ComplexDeinterleavingGraph::CompositeNode *
642ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
643 Instruction *Real, Instruction *Imag,
644 std::pair<Value *, Value *> &PartialMatch) {
645 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
646 << "\n");
647
648 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
649 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
650 return nullptr;
651 }
652
653 if ((Real->getOpcode() != Instruction::FMul &&
654 Real->getOpcode() != Instruction::Mul) ||
655 (Imag->getOpcode() != Instruction::FMul &&
656 Imag->getOpcode() != Instruction::Mul)) {
658 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
659 return nullptr;
660 }
661
662 Value *R0 = Real->getOperand(0);
663 Value *R1 = Real->getOperand(1);
664 Value *I0 = Imag->getOperand(0);
665 Value *I1 = Imag->getOperand(1);
666
667 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
668 // rotations and use the operand.
669 unsigned Negs = 0;
670 Value *Op;
671 if (match(R0, m_Neg(m_Value(Op)))) {
672 Negs |= 1;
673 R0 = Op;
674 } else if (match(R1, m_Neg(m_Value(Op)))) {
675 Negs |= 1;
676 R1 = Op;
677 }
678
679 if (isNeg(I0)) {
680 Negs |= 2;
681 Negs ^= 1;
682 I0 = Op;
683 } else if (match(I1, m_Neg(m_Value(Op)))) {
684 Negs |= 2;
685 Negs ^= 1;
686 I1 = Op;
687 }
688
690
691 Value *CommonOperand;
692 Value *UncommonRealOp;
693 Value *UncommonImagOp;
694
695 if (R0 == I0 || R0 == I1) {
696 CommonOperand = R0;
697 UncommonRealOp = R1;
698 } else if (R1 == I0 || R1 == I1) {
699 CommonOperand = R1;
700 UncommonRealOp = R0;
701 } else {
702 LLVM_DEBUG(dbgs() << " - No equal operand\n");
703 return nullptr;
704 }
705
706 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
707 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
708 Rotation == ComplexDeinterleavingRotation::Rotation_270)
709 std::swap(UncommonRealOp, UncommonImagOp);
710
711 // Between identifyPartialMul and here we need to have found a complete valid
712 // pair from the CommonOperand of each part.
713 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
714 Rotation == ComplexDeinterleavingRotation::Rotation_180)
715 PartialMatch.first = CommonOperand;
716 else
717 PartialMatch.second = CommonOperand;
718
719 if (!PartialMatch.first || !PartialMatch.second) {
720 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
721 return nullptr;
722 }
723
724 CompositeNode *CommonNode =
725 identifyNode(PartialMatch.first, PartialMatch.second);
726 if (!CommonNode) {
727 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
728 return nullptr;
729 }
730
731 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
732 if (!UncommonNode) {
733 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
734 return nullptr;
735 }
736
737 CompositeNode *Node = prepareCompositeNode(
738 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
739 Node->Rotation = Rotation;
740 Node->addOperand(CommonNode);
741 Node->addOperand(UncommonNode);
742 return submitCompositeNode(Node);
743}
744
745ComplexDeinterleavingGraph::CompositeNode *
746ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
747 Instruction *Imag) {
748 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
749 << "\n");
750
751 // Determine rotation
752 auto IsAdd = [](unsigned Op) {
753 return Op == Instruction::FAdd || Op == Instruction::Add;
754 };
755 auto IsSub = [](unsigned Op) {
756 return Op == Instruction::FSub || Op == Instruction::Sub;
757 };
759 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
760 Rotation = ComplexDeinterleavingRotation::Rotation_0;
761 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
762 Rotation = ComplexDeinterleavingRotation::Rotation_90;
763 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
764 Rotation = ComplexDeinterleavingRotation::Rotation_180;
765 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
766 Rotation = ComplexDeinterleavingRotation::Rotation_270;
767 else {
768 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
769 return nullptr;
770 }
771
772 if (isa<FPMathOperator>(Real) &&
773 (!Real->getFastMathFlags().allowContract() ||
774 !Imag->getFastMathFlags().allowContract())) {
775 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
776 return nullptr;
777 }
778
779 Value *CR = Real->getOperand(0);
780 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
781 if (!RealMulI)
782 return nullptr;
783 Value *CI = Imag->getOperand(0);
784 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
785 if (!ImagMulI)
786 return nullptr;
787
788 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
789 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
790 return nullptr;
791 }
792
793 Value *R0 = RealMulI->getOperand(0);
794 Value *R1 = RealMulI->getOperand(1);
795 Value *I0 = ImagMulI->getOperand(0);
796 Value *I1 = ImagMulI->getOperand(1);
797
798 Value *CommonOperand;
799 Value *UncommonRealOp;
800 Value *UncommonImagOp;
801
802 if (R0 == I0 || R0 == I1) {
803 CommonOperand = R0;
804 UncommonRealOp = R1;
805 } else if (R1 == I0 || R1 == I1) {
806 CommonOperand = R1;
807 UncommonRealOp = R0;
808 } else {
809 LLVM_DEBUG(dbgs() << " - No equal operand\n");
810 return nullptr;
811 }
812
813 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
814 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
815 Rotation == ComplexDeinterleavingRotation::Rotation_270)
816 std::swap(UncommonRealOp, UncommonImagOp);
817
818 std::pair<Value *, Value *> PartialMatch(
819 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
820 Rotation == ComplexDeinterleavingRotation::Rotation_180)
821 ? CommonOperand
822 : nullptr,
823 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
824 Rotation == ComplexDeinterleavingRotation::Rotation_270)
825 ? CommonOperand
826 : nullptr);
827
828 auto *CRInst = dyn_cast<Instruction>(CR);
829 auto *CIInst = dyn_cast<Instruction>(CI);
830
831 if (!CRInst || !CIInst) {
832 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
833 return nullptr;
834 }
835
836 CompositeNode *CNode =
837 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
838 if (!CNode) {
839 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
840 return nullptr;
841 }
842
843 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
844 if (!UncommonRes) {
845 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
846 return nullptr;
847 }
848
849 assert(PartialMatch.first && PartialMatch.second);
850 CompositeNode *CommonRes =
851 identifyNode(PartialMatch.first, PartialMatch.second);
852 if (!CommonRes) {
853 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
854 return nullptr;
855 }
856
857 CompositeNode *Node = prepareCompositeNode(
858 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
859 Node->Rotation = Rotation;
860 Node->addOperand(CommonRes);
861 Node->addOperand(UncommonRes);
862 Node->addOperand(CNode);
863 return submitCompositeNode(Node);
864}
865
866ComplexDeinterleavingGraph::CompositeNode *
867ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
868 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
869
870 // Determine rotation
872 if ((Real->getOpcode() == Instruction::FSub &&
873 Imag->getOpcode() == Instruction::FAdd) ||
874 (Real->getOpcode() == Instruction::Sub &&
875 Imag->getOpcode() == Instruction::Add))
876 Rotation = ComplexDeinterleavingRotation::Rotation_90;
877 else if ((Real->getOpcode() == Instruction::FAdd &&
878 Imag->getOpcode() == Instruction::FSub) ||
879 (Real->getOpcode() == Instruction::Add &&
880 Imag->getOpcode() == Instruction::Sub))
881 Rotation = ComplexDeinterleavingRotation::Rotation_270;
882 else {
883 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
884 return nullptr;
885 }
886
887 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
888 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
889 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
890 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
891
892 if (!AR || !AI || !BR || !BI) {
893 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
894 return nullptr;
895 }
896
897 CompositeNode *ResA = identifyNode(AR, AI);
898 if (!ResA) {
899 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
900 return nullptr;
901 }
902 CompositeNode *ResB = identifyNode(BR, BI);
903 if (!ResB) {
904 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
905 return nullptr;
906 }
907
908 CompositeNode *Node =
909 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
910 Node->Rotation = Rotation;
911 Node->addOperand(ResA);
912 Node->addOperand(ResB);
913 return submitCompositeNode(Node);
914}
915
917 unsigned OpcA = A->getOpcode();
918 unsigned OpcB = B->getOpcode();
919
920 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
921 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
922 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
923 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
924}
925
927 auto Pattern =
929
930 return match(A, Pattern) && match(B, Pattern);
931}
932
934 switch (I->getOpcode()) {
935 case Instruction::FAdd:
936 case Instruction::FSub:
937 case Instruction::FMul:
938 case Instruction::FNeg:
939 case Instruction::Add:
940 case Instruction::Sub:
941 case Instruction::Mul:
942 return true;
943 default:
944 return false;
945 }
946}
947
// Tries to describe Vals as a Symmetric node: an operation applied identically
// to the real and imaginary parts. Every real and imaginary value must be an
// instruction (cast<> is used, so callers must guarantee this) with the same
// opcode as the first real value and, for FP math operators, the same
// fast-math flags. Operand 0 (and operand 1 for binary operators) of each
// pair is then identified recursively. Returns the submitted node, or nullptr
// if any check or recursive identification fails.
948ComplexDeinterleavingGraph::CompositeNode *
949ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
950 auto *FirstReal = cast<Instruction>(Vals[0].Real);
951 unsigned FirstOpc = FirstReal->getOpcode();
952 for (auto &V : Vals) {
953 auto *Real = cast<Instruction>(V.Real);
954 auto *Imag = cast<Instruction>(V.Imag);
// All parts must share the opcode of the first real instruction.
955 if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
956 return nullptr;
957
// NOTE(review): listing lines 958-959 are missing here; the `return nullptr`
// below is presumably guarded by a condition that was dropped from this
// listing -- confirm against the upstream file.
960 return nullptr;
961
// Fast-math flags are only compared when the first real instruction is an
// FP math operator.
962 if (isa<FPMathOperator>(FirstReal))
963 if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
964 Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
965 return nullptr;
966 }
967
// Pair up operand 0 of every real/imaginary instruction and identify it.
968 ComplexValues OpVals;
969 for (auto &V : Vals) {
970 auto *R0 = cast<Instruction>(V.Real)->getOperand(0);
971 auto *I0 = cast<Instruction>(V.Imag)->getOperand(0);
972 OpVals.push_back({R0, I0});
973 }
974
975 CompositeNode *Op0 = identifyNode(OpVals);
976 CompositeNode *Op1 = nullptr;
977 if (Op0 == nullptr)
978 return nullptr;
979
// Binary operators additionally need their second operands identified.
980 if (FirstReal->isBinaryOp()) {
981 OpVals.clear();
982 for (auto &V : Vals) {
983 auto *R1 = cast<Instruction>(V.Real)->getOperand(1);
984 auto *I1 = cast<Instruction>(V.Imag)->getOperand(1);
985 OpVals.push_back({R1, I1});
986 }
987 Op1 = identifyNode(OpVals);
988 if (Op1 == nullptr)
989 return nullptr;
990 }
991
// Record the opcode (and FP flags, when applicable) on the new node.
992 auto Node =
993 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
994 Node->Opcode = FirstReal->getOpcode();
995 if (isa<FPMathOperator>(FirstReal))
996 Node->Flags = FirstReal->getFastMathFlags();
997
998 Node->addOperand(Op0);
999 if (FirstReal->isBinaryOp())
1000 Node->addOperand(Op1);
1001
1002 return submitCompositeNode(Node);
1003}
1004
// Tries to recognise V as the root of a complex dot product built from
// vector_partial_reduce_add intrinsic calls and to build a CDot node for it.
// The rotation is taken from which pattern matches; rotations 90 and 180 share
// a pattern and are told apart by which operand order identifyNode accepts.
// V must be an Instruction (cast<> is used). Returns the submitted node, or
// nullptr when the target rejects the operation, no pattern matches, or the
// operand types are not as expected.
// NOTE(review): several lines of this function (listing lines 1007, 1038,
// 1043 and 1057) are missing from this listing, including the start of the
// target-support check and parts of the matchers -- confirm against the
// upstream file before relying on the exact patterns.
1005ComplexDeinterleavingGraph::CompositeNode *
1006ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
1008 ComplexDeinterleavingOperation::CDot, V->getType())) {
1009 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
1010 "operation CDot with the type "
1011 << *V->getType() << "\n");
1012 return nullptr;
1013 }
1014
1015 auto *Inst = cast<Instruction>(V);
// NOTE(review): dereferences user_begin() without checking that Inst has any
// users; presumably callers only pass values that have users -- verify.
1016 auto *RealUser = cast<Instruction>(*Inst->user_begin());
1017
// The node is prepared before matching; the early returns below leave it
// unsubmitted.
1018 CompositeNode *CN =
1019 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
1020
1021 CompositeNode *ANode = nullptr;
1022
1023 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1024
1025 Value *AReal = nullptr;
1026 Value *AImag = nullptr;
1027 Value *BReal = nullptr;
1028 Value *BImag = nullptr;
1029 Value *Phi = nullptr;
1030
// Strips a single cast instruction, returning its operand; other values are
// returned unchanged.
1031 auto UnwrapCast = [](Value *V) -> Value * {
1032 if (auto *CI = dyn_cast<CastInst>(V))
1033 return CI->getOperand(0);
1034 return V;
1035 };
1036
1037 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
1039 m_Mul(m_Value(BReal), m_Value(AReal))),
1040 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
1041
1042 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
1044 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
1045 m_Mul(m_Value(BImag), m_Value(AReal)));
1046
1047 if (match(Inst, PatternRot0)) {
1048 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1049 } else if (match(Inst, PatternRot270)) {
1050 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1051 } else {
1052 Value *A0, *A1;
1053 // The rotations 90 and 180 share the same operation pattern, so inspect the
1054 // order of the operands, identifying where the real and imaginary
1055 // components of A go, to discern between the aforementioned rotations.
1056 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
1058 m_Mul(m_Value(BReal), m_Value(A0))),
1059 m_Mul(m_Value(BImag), m_Value(A1)));
1060
1061 if (!match(Inst, PatternRot90Rot180))
1062 return nullptr;
1063
1064 A0 = UnwrapCast(A0);
1065 A1 = UnwrapCast(A1);
1066
1067 // Test if A0 is real/A1 is imag
1068 ANode = identifyNode(A0, A1);
1069 if (!ANode) {
1070 // Test if A0 is imag/A1 is real
1071 ANode = identifyNode(A1, A0);
1072 // Unable to identify operand components, thus unable to identify rotation
1073 if (!ANode)
1074 return nullptr;
1075 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1076 AReal = A1;
1077 AImag = A0;
1078 } else {
1079 AReal = A0;
1080 AImag = A1;
1081 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1082 }
1083 }
1084
// Look through one level of casts on every multiplication operand.
1085 AReal = UnwrapCast(AReal);
1086 AImag = UnwrapCast(AImag);
1087 BReal = UnwrapCast(BReal);
1088 BImag = UnwrapCast(BImag);
1089
// Each unwrapped operand must have the type produced by
// getSubdividedVectorType(VTy, 2) for the result type VTy.
1090 VectorType *VTy = cast<VectorType>(V->getType());
1091 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1092 if (AReal->getType() != ExpectedOperandTy)
1093 return nullptr;
1094 if (AImag->getType() != ExpectedOperandTy)
1095 return nullptr;
1096 if (BReal->getType() != ExpectedOperandTy)
1097 return nullptr;
1098 if (BImag->getType() != ExpectedOperandTy)
1099 return nullptr;
1100
// Bails out only when neither Phi nor RealUser has the result type.
// NOTE(review): Phi is dereferenced here; this assumes every matched pattern
// bound it (the binding for two of the patterns sits on lines missing from
// this listing) -- confirm.
1101 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1102 return nullptr;
1103
1104 CompositeNode *Node = identifyNode(AReal, AImag);
1105
1106 // In the case that a node was identified to figure out the rotation, ensure
1107 // that trying to identify a node with AReal and AImag post-unwrap results in
1108 // the same node
1109 if (ANode && Node != ANode) {
1110 LLVM_DEBUG(
1111 dbgs()
1112 << "Identified node is different from previously identified node. "
1113 "Unable to confidently generate a complex operation node\n");
1114 return nullptr;
1115 }
1116
// Operands are added in the order A, B, accumulator. The identifyNode results
// are added without a null check here.
1117 CN->addOperand(Node);
1118 CN->addOperand(identifyNode(BReal, BImag));
1119 CN->addOperand(identifyNode(Phi, RealUser));
1120
1121 return submitCompositeNode(CN);
1122}
1123
1124ComplexDeinterleavingGraph::CompositeNode *
1125ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1126 // Partial reductions don't support non-vector types, so check these first
1127 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1128 return nullptr;
1129
1130 if (!R->hasUseList() || !I->hasUseList())
1131 return nullptr;
1132
1133 auto CommonUser =
1134 findCommonBetweenCollections<Value *>(R->users(), I->users());
1135 if (!CommonUser)
1136 return nullptr;
1137
1138 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1139 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1140 return nullptr;
1141
1142 if (CompositeNode *CN = identifyDotProduct(IInst))
1143 return CN;
1144
1145 return nullptr;
1146}
1147
1148ComplexDeinterleavingGraph::CompositeNode *
1149ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
1150 auto It = CachedResult.find(Vals);
1151 if (It != CachedResult.end()) {
1152 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1153 return It->second;
1154 }
1155
1156 if (Vals.size() == 1) {
1157 assert(Factor == 2 && "Can only handle interleave factors of 2");
1158 Value *R = Vals[0].Real;
1159 Value *I = Vals[0].Imag;
1160 if (CompositeNode *CN = identifyPartialReduction(R, I))
1161 return CN;
1162 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1163 if (!IsReduction && R->getType() != I->getType())
1164 return nullptr;
1165 }
1166
1167 if (CompositeNode *CN = identifySplat(Vals))
1168 return CN;
1169
1170 for (auto &V : Vals) {
1171 auto *Real = dyn_cast<Instruction>(V.Real);
1172 auto *Imag = dyn_cast<Instruction>(V.Imag);
1173 if (!Real || !Imag)
1174 return nullptr;
1175 }
1176
1177 if (CompositeNode *CN = identifyDeinterleave(Vals))
1178 return CN;
1179
1180 if (Vals.size() == 1) {
1181 assert(Factor == 2 && "Can only handle interleave factors of 2");
1182 auto *Real = dyn_cast<Instruction>(Vals[0].Real);
1183 auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
1184 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1185 return CN;
1186
1187 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1188 return CN;
1189
1190 auto *VTy = cast<VectorType>(Real->getType());
1191 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1192
1193 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1194 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1195 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1196 ComplexDeinterleavingOperation::CAdd, NewVTy);
1197
1198 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1199 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1200 return CN;
1201 }
1202
1203 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1204 if (CompositeNode *CN = identifyAdd(Real, Imag))
1205 return CN;
1206 }
1207
1208 if (HasCMulSupport && HasCAddSupport) {
1209 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1210 return CN;
1211 }
1212 }
1213 }
1214
1215 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1216 return CN;
1217
1218 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1219 CachedResult[Vals] = nullptr;
1220 return nullptr;
1221}
1222
// Tries to identify a complex operation whose real and imaginary parts are
// sums of products and addends that may have been reassociated. Both roots
// must be FAdd/FSub/FNeg/Add/Sub; FP roots must carry identical fast-math
// flags that allow reassociation. The expression trees under Real and Imag are
// flattened into signed products and addends, the products are combined via
// identifyMultiplications and the remaining addends via identifyAdditions.
// On success the final node's single value pair is set to (Real, Imag) and the
// node is submitted and returned; otherwise nullptr is returned.
1223ComplexDeinterleavingGraph::CompositeNode *
1224ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1225 Instruction *Imag) {
1226 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1227 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1228 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1229 Opcode == Instruction::Sub;
1230 };
1231
1232 if (!IsOperationSupported(Real->getOpcode()) ||
1233 !IsOperationSupported(Imag->getOpcode()))
1234 return nullptr;
1235
// Flags stays empty for integer roots; it holds the shared fast-math flags
// for FP roots.
1236 std::optional<FastMathFlags> Flags;
1237 if (isa<FPMathOperator>(Real)) {
1238 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1239 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1240 "not identical\n");
1241 return nullptr;
1242 }
1243
1244 Flags = Real->getFastMathFlags();
1245 if (!Flags->allowReassoc()) {
1246 LLVM_DEBUG(
1247 dbgs()
1248 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1249 return nullptr;
1250 }
1251 }
1252
1253 // Collect multiplications and addend instructions from the given instruction
1254 // while traversing its operands. Additionally, verify that all instructions
1255 // have the same fast math flags.
1256 auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
1257 AddendList &Addends) -> bool {
// Each worklist entry pairs a value with the sign it contributes with.
1258 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1259 SmallPtrSet<Value *, 8> Visited;
1260 while (!Worklist.empty()) {
1261 auto [V, IsPositive] = Worklist.pop_back_val();
1262 if (!Visited.insert(V).second)
1263 continue;
1264
// NOTE(review): listing line 1265, which declares `I`, is missing from this
// listing; from its uses below `I` is presumably V viewed as an Instruction
// (null for non-instructions) -- confirm against the upstream file.
1266 if (!I) {
1267 Addends.emplace_back(V, IsPositive);
1268 continue;
1269 }
1270
1271 // If an instruction has more than one user, it indicates that it either
1272 // has an external user, which will be later checked by the checkNodes
1273 // function, or it is a subexpression utilized by multiple expressions. In
1274 // the latter case, we will attempt to separately identify the complex
1275 // operation from here in order to create a shared
1276 // ComplexDeinterleavingCompositeNode.
1277 if (I != Insn && I->hasNUsesOrMore(2)) {
1278 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1279 Addends.emplace_back(I, IsPositive);
1280 continue;
1281 }
1282 switch (I->getOpcode()) {
1283 case Instruction::FAdd:
1284 case Instruction::Add:
1285 Worklist.emplace_back(I->getOperand(1), IsPositive);
1286 Worklist.emplace_back(I->getOperand(0), IsPositive);
1287 break;
1288 case Instruction::FSub:
// The subtrahend contributes with the opposite sign.
1289 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1290 Worklist.emplace_back(I->getOperand(0), IsPositive);
1291 break;
1292 case Instruction::Sub:
1293 if (isNeg(I)) {
1294 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1295 } else {
1296 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1297 Worklist.emplace_back(I->getOperand(0), IsPositive);
1298 }
1299 break;
1300 case Instruction::FMul:
1301 case Instruction::Mul: {
// A negated factor is replaced by its operand and flips the product's sign.
1302 Value *A, *B;
1303 if (isNeg(I->getOperand(0))) {
1304 A = getNegOperand(I->getOperand(0));
1305 IsPositive = !IsPositive;
1306 } else {
1307 A = I->getOperand(0);
1308 }
1309
1310 if (isNeg(I->getOperand(1))) {
1311 B = getNegOperand(I->getOperand(1));
1312 IsPositive = !IsPositive;
1313 } else {
1314 B = I->getOperand(1);
1315 }
1316 Muls.push_back(Product{A, B, IsPositive});
1317 break;
1318 }
1319 case Instruction::FNeg:
1320 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1321 break;
1322 default:
// Any other instruction is a leaf addend; `continue` skips the flag check
// below.
1323 Addends.emplace_back(I, IsPositive);
1324 continue;
1325 }
1326
// Only reached for the opcodes handled above; their flags must equal the
// root flags when the roots are FP.
1327 if (Flags && I->getFastMathFlags() != *Flags) {
1328 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1329 "inconsistent with the root instructions' flags: "
1330 << *I << "\n");
1331 return false;
1332 }
1333 }
1334 return true;
1335 };
1336
1337 SmallVector<Product> RealMuls, ImagMuls;
1338 AddendList RealAddends, ImagAddends;
1339 if (!Collect(Real, RealMuls, RealAddends) ||
1340 !Collect(Imag, ImagMuls, ImagAddends))
1341 return nullptr;
1342
1343 if (RealAddends.size() != ImagAddends.size())
1344 return nullptr;
1345
1346 CompositeNode *FinalNode = nullptr;
1347 if (!RealMuls.empty() || !ImagMuls.empty()) {
1348 // If there are multiplicands, extract positive addend and use it as an
1349 // accumulator
1350 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1351 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1352 if (!FinalNode)
1353 return nullptr;
1354 }
1355
1356 // Identify and process remaining additions
1357 if (!RealAddends.empty() || !ImagAddends.empty()) {
1358 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1359 if (!FinalNode)
1360 return nullptr;
1361 }
1362 assert(FinalNode && "FinalNode can not be nullptr here");
1363 assert(FinalNode->Vals.size() == 1);
1364 // Set the Real and Imag fields of the final node and submit it
1365 FinalNode->Vals[0].Real = Real;
1366 FinalNode->Vals[0].Imag = Imag;
1367 submitCompositeNode(FinalNode);
1368 return FinalNode;
1369}
1370
1371bool ComplexDeinterleavingGraph::collectPartialMuls(
1372 ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
1373 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1374 // Helper function to extract a common operand from two products
1375 auto FindCommonInstruction = [](const Product &Real,
1376 const Product &Imag) -> Value * {
1377 if (Real.Multiplicand == Imag.Multiplicand ||
1378 Real.Multiplicand == Imag.Multiplier)
1379 return Real.Multiplicand;
1380
1381 if (Real.Multiplier == Imag.Multiplicand ||
1382 Real.Multiplier == Imag.Multiplier)
1383 return Real.Multiplier;
1384
1385 return nullptr;
1386 };
1387
1388 // Iterating over real and imaginary multiplications to find common operands
1389 // If a common operand is found, a partial multiplication candidate is created
1390 // and added to the candidates vector The function returns false if no common
1391 // operands are found for any product
1392 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1393 bool FoundCommon = false;
1394 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1395 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1396 if (!Common)
1397 continue;
1398
1399 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1400 : RealMuls[i].Multiplicand;
1401 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1402 : ImagMuls[j].Multiplicand;
1403
1404 auto Node = identifyNode(A, B);
1405 if (Node) {
1406 FoundCommon = true;
1407 PartialMulCandidates.push_back({Common, Node, i, j, false});
1408 }
1409
1410 Node = identifyNode(B, A);
1411 if (Node) {
1412 FoundCommon = true;
1413 PartialMulCandidates.push_back({Common, Node, i, j, true});
1414 }
1415 }
1416 if (!FoundCommon)
1417 return false;
1418 }
1419 return true;
1420}
1421
1422ComplexDeinterleavingGraph::CompositeNode *
1423ComplexDeinterleavingGraph::identifyMultiplications(
1424 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1425 CompositeNode *Accumulator = nullptr) {
1426 if (RealMuls.size() != ImagMuls.size())
1427 return nullptr;
1428
1430 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1431 return nullptr;
1432
1433 // Map to store common instruction to node pointers
1434 DenseMap<Value *, CompositeNode *> CommonToNode;
1435 SmallVector<bool> Processed(Info.size(), false);
1436 for (unsigned I = 0; I < Info.size(); ++I) {
1437 if (Processed[I])
1438 continue;
1439
1440 PartialMulCandidate &InfoA = Info[I];
1441 for (unsigned J = I + 1; J < Info.size(); ++J) {
1442 if (Processed[J])
1443 continue;
1444
1445 PartialMulCandidate &InfoB = Info[J];
1446 auto *InfoReal = &InfoA;
1447 auto *InfoImag = &InfoB;
1448
1449 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1450 if (!NodeFromCommon) {
1451 std::swap(InfoReal, InfoImag);
1452 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1453 }
1454 if (!NodeFromCommon)
1455 continue;
1456
1457 CommonToNode[InfoReal->Common] = NodeFromCommon;
1458 CommonToNode[InfoImag->Common] = NodeFromCommon;
1459 Processed[I] = true;
1460 Processed[J] = true;
1461 }
1462 }
1463
1464 SmallVector<bool> ProcessedReal(RealMuls.size(), false);
1465 SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
1466 CompositeNode *Result = Accumulator;
1467 for (auto &PMI : Info) {
1468 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1469 continue;
1470
1471 auto It = CommonToNode.find(PMI.Common);
1472 // TODO: Process independent complex multiplications. Cases like this:
1473 // A.real() * B where both A and B are complex numbers.
1474 if (It == CommonToNode.end()) {
1475 LLVM_DEBUG({
1476 dbgs() << "Unprocessed independent partial multiplication:\n";
1477 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1478 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1479 << " multiplied by " << *Mul->Multiplicand << "\n";
1480 });
1481 return nullptr;
1482 }
1483
1484 auto &RealMul = RealMuls[PMI.RealIdx];
1485 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1486
1487 auto NodeA = It->second;
1488 auto NodeB = PMI.Node;
1489 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1490 // The following table illustrates the relationship between multiplications
1491 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1492 // can see:
1493 //
1494 // Rotation | Real | Imag |
1495 // ---------+--------+--------+
1496 // 0 | x * u | x * v |
1497 // 90 | -y * v | y * u |
1498 // 180 | -x * u | -x * v |
1499 // 270 | y * v | -y * u |
1500 //
1501 // Check if the candidate can indeed be represented by partial
1502 // multiplication
1503 // TODO: Add support for multiplication by complex one
1504 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1505 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1506 continue;
1507
1508 // Determine the rotation based on the multiplications
1510 if (IsMultiplicandReal) {
1511 // Detect 0 and 180 degrees rotation
1512 if (RealMul.IsPositive && ImagMul.IsPositive)
1514 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1516 else
1517 continue;
1518
1519 } else {
1520 // Detect 90 and 270 degrees rotation
1521 if (!RealMul.IsPositive && ImagMul.IsPositive)
1523 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1525 else
1526 continue;
1527 }
1528
1529 LLVM_DEBUG({
1530 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1531 dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
1532 dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
1533 dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
1534 dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
1535 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1536 });
1537
1538 CompositeNode *NodeMul = prepareCompositeNode(
1539 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1540 NodeMul->Rotation = Rotation;
1541 NodeMul->addOperand(NodeA);
1542 NodeMul->addOperand(NodeB);
1543 if (Result)
1544 NodeMul->addOperand(Result);
1545 submitCompositeNode(NodeMul);
1546 Result = NodeMul;
1547 ProcessedReal[PMI.RealIdx] = true;
1548 ProcessedImag[PMI.ImagIdx] = true;
1549 }
1550
1551 // Ensure all products have been processed, if not return nullptr.
1552 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1553 !all_of(ProcessedImag, [](bool V) { return V; })) {
1554
1555 // Dump debug information about which partial multiplications are not
1556 // processed.
1557 LLVM_DEBUG({
1558 dbgs() << "Unprocessed products (Real):\n";
1559 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1560 if (!ProcessedReal[i])
1561 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1562 << *RealMuls[i].Multiplier << " multiplied by "
1563 << *RealMuls[i].Multiplicand << "\n";
1564 }
1565 dbgs() << "Unprocessed products (Imag):\n";
1566 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1567 if (!ProcessedImag[i])
1568 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1569 << *ImagMuls[i].Multiplier << " multiplied by "
1570 << *ImagMuls[i].Multiplicand << "\n";
1571 }
1572 });
1573 return nullptr;
1574 }
1575
1576 return Result;
1577}
1578
// Folds the remaining real and imaginary addends into a chain of nodes, one
// addend pair at a time. The two lists must have the same size. Each real
// addend is matched with the first imaginary addend for which identifyNode
// succeeds; the signs of the pair select a rotation (++ -> 0, -+ -> 90,
// -- -> 180, +- -> 270), and for rotations 90 and 270 the pair is identified
// with real and imaginary swapped. Matched addends are erased from both
// lists. Flags, when set, selects FAdd/FSub (with those flags) over Add/Sub
// for the symmetric nodes. Returns the last node of the chain, or nullptr if
// the sizes differ, no starting node is available, or a real addend cannot be
// paired.
// NOTE(review): listing lines 1589, 1604, 1630 and 1640 are missing from this
// listing (the accumulator assignment, the declaration of `Rotation`, and the
// conditions selecting the add/sub branches) -- confirm against the upstream
// file.
1579ComplexDeinterleavingGraph::CompositeNode *
1580ComplexDeinterleavingGraph::identifyAdditions(
1581 AddendList &RealAddends, AddendList &ImagAddends,
1582 std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
1583 if (RealAddends.size() != ImagAddends.size())
1584 return nullptr;
1585
1586 CompositeNode *Result = nullptr;
1587 // If we have accumulator use it as first addend
1588 if (Accumulator)
1590 // Otherwise find an element with both positive real and imaginary parts.
1591 else
1592 Result = extractPositiveAddend(RealAddends, ImagAddends);
1593
1594 if (!Result)
1595 return nullptr;
1596
// Always take the first remaining real addend and search for its partner.
1597 while (!RealAddends.empty()) {
1598 auto ItR = RealAddends.begin();
1599 auto [R, IsPositiveR] = *ItR;
1600
1601 bool FoundImag = false;
1602 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1603 auto [I, IsPositiveI] = *ItI;
1605 if (IsPositiveR && IsPositiveI)
1606 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1607 else if (!IsPositiveR && IsPositiveI)
1608 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1609 else if (!IsPositiveR && !IsPositiveI)
1610 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1611 else
1612 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1613
// For rotations 90 and 270 the pair is identified with the roles swapped.
1614 CompositeNode *AddNode = nullptr;
1615 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1616 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1617 AddNode = identifyNode(R, I);
1618 } else {
1619 AddNode = identifyNode(I, R);
1620 }
1621 if (AddNode) {
1622 LLVM_DEBUG({
1623 dbgs() << "Identified addition:\n";
1624 dbgs().indent(4) << "X: " << *R << "\n";
1625 dbgs().indent(4) << "Y: " << *I << "\n";
1626 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1627 });
1628
// Three node shapes follow: a symmetric add, a symmetric sub, or a CAdd
// carrying the rotation. NOTE(review): the conditions choosing between them
// are on lines missing from this listing.
1629 CompositeNode *TmpNode = nullptr;
1631 TmpNode = prepareCompositeNode(
1632 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1633 if (Flags) {
1634 TmpNode->Opcode = Instruction::FAdd;
1635 TmpNode->Flags = *Flags;
1636 } else {
1637 TmpNode->Opcode = Instruction::Add;
1638 }
1639 } else if (Rotation ==
1641 TmpNode = prepareCompositeNode(
1642 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1643 if (Flags) {
1644 TmpNode->Opcode = Instruction::FSub;
1645 TmpNode->Flags = *Flags;
1646 } else {
1647 TmpNode->Opcode = Instruction::Sub;
1648 }
1649 } else {
1650 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1651 nullptr, nullptr);
1652 TmpNode->Rotation = Rotation;
1653 }
1654
// The running result is operand 0, the newly identified addend operand 1.
1655 TmpNode->addOperand(Result);
1656 TmpNode->addOperand(AddNode);
1657 submitCompositeNode(TmpNode);
1658 Result = TmpNode;
// Both iterators are erased and the inner loop is left immediately, so
// neither is used again after erasure.
1659 RealAddends.erase(ItR);
1660 ImagAddends.erase(ItI);
1661 FoundImag = true;
1662 break;
1663 }
1664 }
1665 if (!FoundImag)
1666 return nullptr;
1667 }
1668 return Result;
1669}
1670
1671ComplexDeinterleavingGraph::CompositeNode *
1672ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1673 AddendList &ImagAddends) {
1674 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1675 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1676 auto [R, IsPositiveR] = *ItR;
1677 auto [I, IsPositiveI] = *ItI;
1678 if (IsPositiveR && IsPositiveI) {
1679 auto Result = identifyNode(R, I);
1680 if (Result) {
1681 RealAddends.erase(ItR);
1682 ImagAddends.erase(ItI);
1683 return Result;
1684 }
1685 }
1686 }
1687 }
1688 return nullptr;
1689}
1690
1691bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1692 // This potential root instruction might already have been recognized as
1693 // reduction. Because RootToNode maps both Real and Imaginary parts to
1694 // CompositeNode we should choose only one either Real or Imag instruction to
1695 // use as an anchor for generating complex instruction.
1696 auto It = RootToNode.find(RootI);
1697 if (It != RootToNode.end()) {
1698 auto RootNode = It->second;
1699 assert(RootNode->Operation ==
1700 ComplexDeinterleavingOperation::ReductionOperation ||
1701 RootNode->Operation ==
1702 ComplexDeinterleavingOperation::ReductionSingle);
1703 assert(RootNode->Vals.size() == 1 &&
1704 "Cannot handle reductions involving multiple complex values");
1705 // Find out which part, Real or Imag, comes later, and only if we come to
1706 // the latest part, add it to OrderedRoots.
1707 auto *R = cast<Instruction>(RootNode->Vals[0].Real);
1708 auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(RootNode->Vals[0].Imag)
1709 : nullptr;
1710
1711 Instruction *ReplacementAnchor;
1712 if (I)
1713 ReplacementAnchor = R->comesBefore(I) ? I : R;
1714 else
1715 ReplacementAnchor = R;
1716
1717 if (ReplacementAnchor != RootI)
1718 return false;
1719 OrderedRoots.push_back(RootI);
1720 return true;
1721 }
1722
1723 auto RootNode = identifyRoot(RootI);
1724 if (!RootNode)
1725 return false;
1726
1727 LLVM_DEBUG({
1728 Function *F = RootI->getFunction();
1729 BasicBlock *B = RootI->getParent();
1730 dbgs() << "Complex deinterleaving graph for " << F->getName()
1731 << "::" << B->getName() << ".\n";
1732 dump(dbgs());
1733 dbgs() << "\n";
1734 });
1735 RootToNode[RootI] = RootNode;
1736 OrderedRoots.push_back(RootI);
1737 return true;
1738}
1739
1740bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1741 bool FoundPotentialReduction = false;
1742 if (Factor != 2)
1743 return false;
1744
1745 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1746 if (!Br || Br->getNumSuccessors() != 2)
1747 return false;
1748
1749 // Identify simple one-block loop
1750 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1751 return false;
1752
1753 for (auto &PHI : B->phis()) {
1754 if (PHI.getNumIncomingValues() != 2)
1755 continue;
1756
1757 if (!PHI.getType()->isVectorTy())
1758 continue;
1759
1760 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1761 if (!ReductionOp)
1762 continue;
1763
1764 // Check if final instruction is reduced outside of current block
1765 Instruction *FinalReduction = nullptr;
1766 auto NumUsers = 0u;
1767 for (auto *U : ReductionOp->users()) {
1768 ++NumUsers;
1769 if (U == &PHI)
1770 continue;
1771 FinalReduction = dyn_cast<Instruction>(U);
1772 }
1773
1774 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1775 isa<PHINode>(FinalReduction))
1776 continue;
1777
1778 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1779 BackEdge = B;
1780 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1781 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1782 Incoming = PHI.getIncomingBlock(IncomingIdx);
1783 FoundPotentialReduction = true;
1784
1785 // If the initial value of PHINode is an Instruction, consider it a leaf
1786 // value of a complex deinterleaving graph.
1787 if (auto *InitPHI =
1788 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1789 FinalInstructions.insert(InitPHI);
1790 }
1791 return FoundPotentialReduction;
1792}
1793
// Tries to turn the candidates recorded in ReductionInfo into reduction root
// nodes. For each unprocessed candidate i, every later unprocessed candidate
// j of the same type is tried as its partner (in both real/imaginary orders);
// on success a ReductionOperation node is registered for both instructions in
// RootToNode. If candidate i remains unpaired, has at least two operands and
// its final reduction has integer type, its first two operands are tried as a
// complex value and, on success, a ReductionSingle node is registered.
// RealPHI, ImagPHI and PHIsFound are set before each identifyNode call and
// RealPHI/ImagPHI are reset to nullptr on exit. A result is accepted only
// when PHIsFound is true after identification.
1794void ComplexDeinterleavingGraph::identifyReductionNodes() {
1795 assert(Factor == 2 && "Cannot handle multiple complex values");
1796
// Snapshot the reduction instructions so they can be addressed by index.
1797 SmallVector<bool> Processed(ReductionInfo.size(), false);
1798 SmallVector<Instruction *> OperationInstruction;
1799 for (auto &P : ReductionInfo)
1800 OperationInstruction.push_back(P.first);
1801
1802 // Identify a complex computation by evaluating two reduction operations that
1803 // potentially could be involved
1804 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1805 if (Processed[i])
1806 continue;
1807 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1808 if (Processed[j])
1809 continue;
1810 auto *Real = OperationInstruction[i];
1811 auto *Imag = OperationInstruction[j];
1812 if (Real->getType() != Imag->getType())
1813 continue;
1814
1815 RealPHI = ReductionInfo[Real].first;
1816 ImagPHI = ReductionInfo[Imag].first;
1817 PHIsFound = false;
1818 auto Node = identifyNode(Real, Imag);
// Retry with the roles swapped; PHIsFound is not reset between the attempts.
1819 if (!Node) {
1820 std::swap(Real, Imag);
1821 std::swap(RealPHI, ImagPHI);
1822 Node = identifyNode(Real, Imag);
1823 }
1824
1825 // If a node is identified and reduction PHINode is used in the chain of
1826 // operations, mark its operation instructions as used to prevent
1827 // re-identification and attach the node to the real part
1828 if (Node && PHIsFound) {
1829 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1830 << *Real << " / " << *Imag << "\n");
1831 Processed[i] = true;
1832 Processed[j] = true;
1833 auto RootNode = prepareCompositeNode(
1834 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1835 RootNode->addOperand(Node);
1836 RootToNode[Real] = RootNode;
1837 RootToNode[Imag] = RootNode;
1838 submitCompositeNode(RootNode);
1839 break;
1840 }
1841 }
1842
// Fallback: try candidate i on its own as a single reduction.
1843 auto *Real = OperationInstruction[i];
1844 // We want to check that we have 2 operands, but the function attributes
1845 // being counted as operands bloats this value.
1846 if (Processed[i] || Real->getNumOperands() < 2)
1847 continue;
1848
1849 // Can only combine integer reductions at the moment.
1850 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1851 continue;
1852
1853 RealPHI = ReductionInfo[Real].first;
1854 ImagPHI = nullptr;
1855 PHIsFound = false;
1856 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1857 if (Node && PHIsFound) {
1858 LLVM_DEBUG(
1859 dbgs() << "Identified single reduction starting from instruction: "
1860 << *Real << "/" << *ReductionInfo[Real].second << "\n");
1861
1862 // Reducing to a single vector is not supported, only permit reducing down
1863 // to scalar values.
1864 // Doing this here will leave the prior node in the graph,
1865 // however with no uses the node will be unreachable by the replacement
1866 // process. That along with the usage outside the graph should prevent the
1867 // replacement process from kicking off at all for this graph.
1868 // TODO Add support for reducing to a single vector value
1869 if (ReductionInfo[Real].second->getType()->isVectorTy())
1870 continue;
1871
1872 Processed[i] = true;
1873 auto RootNode = prepareCompositeNode(
1874 ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1875 RootNode->addOperand(Node);
1876 RootToNode[Real] = RootNode;
1877 submitCompositeNode(RootNode);
1878 }
1879 }
1880
1881 RealPHI = nullptr;
1882 ImagPHI = nullptr;
1883}
1884
1885bool ComplexDeinterleavingGraph::checkNodes() {
1886 bool FoundDeinterleaveNode = false;
1887 for (CompositeNode *N : CompositeNodes) {
1888 if (!N->areOperandsValid())
1889 return false;
1890
1891 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1892 FoundDeinterleaveNode = true;
1893 }
1894
1895 // We need a deinterleave node in order to guarantee that we're working with
1896 // complex numbers.
1897 if (!FoundDeinterleaveNode) {
1898 LLVM_DEBUG(
1899 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1900 "guarantee safety during graph transformation.\n");
1901 return false;
1902 }
1903
1904 // Collect all instructions from roots to leaves
1905 SmallPtrSet<Instruction *, 16> AllInstructions;
1906 SmallVector<Instruction *, 8> Worklist;
1907 for (auto &Pair : RootToNode)
1908 Worklist.push_back(Pair.first);
1909
1910 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1911 // chains
1912 while (!Worklist.empty()) {
1913 auto *I = Worklist.pop_back_val();
1914
1915 if (!AllInstructions.insert(I).second)
1916 continue;
1917
1918 for (Value *Op : I->operands()) {
1919 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1920 if (!FinalInstructions.count(I))
1921 Worklist.emplace_back(OpI);
1922 }
1923 }
1924 }
1925
1926 // Find instructions that have users outside of chain
1927 for (auto *I : AllInstructions) {
1928 // Skip root nodes
1929 if (RootToNode.count(I))
1930 continue;
1931
1932 for (User *U : I->users()) {
1933 if (AllInstructions.count(cast<Instruction>(U)))
1934 continue;
1935
1936 // Found an instruction that is not used by XCMLA/XCADD chain
1937 Worklist.emplace_back(I);
1938 break;
1939 }
1940 }
1941
1942 // If any instructions are found to be used outside, find and remove roots
1943 // that somehow connect to those instructions.
1944 SmallPtrSet<Instruction *, 16> Visited;
1945 while (!Worklist.empty()) {
1946 auto *I = Worklist.pop_back_val();
1947 if (!Visited.insert(I).second)
1948 continue;
1949
1950 // Found an impacted root node. Removing it from the nodes to be
1951 // deinterleaved
1952 if (RootToNode.count(I)) {
1953 LLVM_DEBUG(dbgs() << "Instruction " << *I
1954 << " could be deinterleaved but its chain of complex "
1955 "operations have an outside user\n");
1956 RootToNode.erase(I);
1957 }
1958
1959 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1960 continue;
1961
1962 for (User *U : I->users())
1963 Worklist.emplace_back(cast<Instruction>(U));
1964
1965 for (Value *Op : I->operands()) {
1966 if (auto *OpI = dyn_cast<Instruction>(Op))
1967 Worklist.emplace_back(OpI);
1968 }
1969 }
1970 return !RootToNode.empty();
1971}
1972
1973ComplexDeinterleavingGraph::CompositeNode *
1974ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1975 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1977 Intrinsic->getIntrinsicID())
1978 return nullptr;
1979
1980 ComplexValues Vals;
1981 for (unsigned I = 0; I < Factor; I += 2) {
1982 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(I));
1983 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(I + 1));
1984 if (!Real || !Imag)
1985 return nullptr;
1986 Vals.push_back({Real, Imag});
1987 }
1988
1989 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
1990 if (!Node1)
1991 return nullptr;
1992 return Node1;
1993 }
1994
1995 // TODO: We could also add support for fixed-width interleave factors of 4
1996 // and above, but currently for symmetric operations the interleaves and
1997 // deinterleaves are already removed by VectorCombine. If we extend this to
1998 // permit complex multiplications, reductions, etc. then we should also add
1999 // support for fixed-width here.
2000 if (Factor != 2)
2001 return nullptr;
2002
2003 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
2004 if (!SVI)
2005 return nullptr;
2006
2007 // Look for a shufflevector that takes separate vectors of the real and
2008 // imaginary components and recombines them into a single vector.
2009 if (!isInterleavingMask(SVI->getShuffleMask()))
2010 return nullptr;
2011
2012 Instruction *Real;
2013 Instruction *Imag;
2014 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
2015 return nullptr;
2016
2017 return identifyNode(Real, Imag);
2018}
2019
2020ComplexDeinterleavingGraph::CompositeNode *
2021ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
2022 Instruction *II = nullptr;
2023
2024 // Must be at least one complex value.
2025 auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
2026 Instruction *ExpectedInsn) -> ExtractValueInst * {
2027 auto *EVI = dyn_cast<ExtractValueInst>(V);
2028 if (!EVI || EVI->getNumIndices() != 1 ||
2029 EVI->getIndices()[0] != ExpectedIdx ||
2030 !isa<Instruction>(EVI->getAggregateOperand()) ||
2031 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2032 return nullptr;
2033 return EVI;
2034 };
2035
2036 for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
2037 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
2038 if (RealEVI && Idx == 0)
2040 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
2041 II = nullptr;
2042 break;
2043 }
2044 }
2045
2046 if (auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(II)) {
2047 if (IntrinsicII->getIntrinsicID() !=
2049 return nullptr;
2050
2051 // The remaining should match too.
2052 CompositeNode *PlaceholderNode = prepareCompositeNode(
2054 PlaceholderNode->ReplacementNode = II->getOperand(0);
2055 for (auto &V : Vals) {
2056 FinalInstructions.insert(cast<Instruction>(V.Real));
2057 FinalInstructions.insert(cast<Instruction>(V.Imag));
2058 }
2059 return submitCompositeNode(PlaceholderNode);
2060 }
2061
2062 if (Vals.size() != 1)
2063 return nullptr;
2064
2065 Value *Real = Vals[0].Real;
2066 Value *Imag = Vals[0].Imag;
2067 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
2068 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
2069 if (!RealShuffle || !ImagShuffle) {
2070 if (RealShuffle || ImagShuffle)
2071 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
2072 return nullptr;
2073 }
2074
2075 Value *RealOp1 = RealShuffle->getOperand(1);
2076 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
2077 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
2078 return nullptr;
2079 }
2080 Value *ImagOp1 = ImagShuffle->getOperand(1);
2081 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
2082 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
2083 return nullptr;
2084 }
2085
2086 Value *RealOp0 = RealShuffle->getOperand(0);
2087 Value *ImagOp0 = ImagShuffle->getOperand(0);
2088
2089 if (RealOp0 != ImagOp0) {
2090 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
2091 return nullptr;
2092 }
2093
2094 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2095 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2096 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
2097 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
2098 return nullptr;
2099 }
2100
2101 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2102 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
2103 return nullptr;
2104 }
2105
2106 // Type checking, the shuffle type should be a vector type of the same
2107 // scalar type, but half the size
2108 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2109 Value *Op = Shuffle->getOperand(0);
2110 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
2111 auto *OpTy = cast<FixedVectorType>(Op->getType());
2112
2113 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2114 return false;
2115 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2116 return false;
2117
2118 return true;
2119 };
2120
2121 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
2122 if (!CheckType(Shuffle))
2123 return false;
2124
2125 ArrayRef<int> Mask = Shuffle->getShuffleMask();
2126 int Last = *Mask.rbegin();
2127
2128 Value *Op = Shuffle->getOperand(0);
2129 auto *OpTy = cast<FixedVectorType>(Op->getType());
2130 int NumElements = OpTy->getNumElements();
2131
2132 // Ensure that the deinterleaving shuffle only pulls from the first
2133 // shuffle operand.
2134 return Last < NumElements;
2135 };
2136
2137 if (RealShuffle->getType() != ImagShuffle->getType()) {
2138 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
2139 return nullptr;
2140 }
2141 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2142 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
2143 return nullptr;
2144 }
2145 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2146 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
2147 return nullptr;
2148 }
2149
2150 CompositeNode *PlaceholderNode =
2152 RealShuffle, ImagShuffle);
2153 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2154 FinalInstructions.insert(RealShuffle);
2155 FinalInstructions.insert(ImagShuffle);
2156 return submitCompositeNode(PlaceholderNode);
2157}
2158
2159ComplexDeinterleavingGraph::CompositeNode *
2160ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
2161 auto IsSplat = [](Value *V) -> bool {
2162 // Fixed-width vector with constants
2164 return true;
2165
2166 if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
2167 return isa<VectorType>(V->getType());
2168
2169 VectorType *VTy;
2170 ArrayRef<int> Mask;
2171 // Splats are represented differently depending on whether the repeated
2172 // value is a constant or an Instruction
2173 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2174 if (Const->getOpcode() != Instruction::ShuffleVector)
2175 return false;
2176 VTy = cast<VectorType>(Const->getType());
2177 Mask = Const->getShuffleMask();
2178 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2179 VTy = Shuf->getType();
2180 Mask = Shuf->getShuffleMask();
2181 } else {
2182 return false;
2183 }
2184
2185 // When the data type is <1 x Type>, it's not possible to differentiate
2186 // between the ComplexDeinterleaving::Deinterleave and
2187 // ComplexDeinterleaving::Splat operations.
2188 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2189 return false;
2190
2191 return all_equal(Mask) && Mask[0] == 0;
2192 };
2193
2194 // The splats must meet the following requirements:
2195 // 1. Must either be all instructions or all values.
2196 // 2. Non-constant splats must live in the same block.
2197 if (auto *FirstValAsInstruction = dyn_cast<Instruction>(Vals[0].Real)) {
2198 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2199 for (auto &V : Vals) {
2200 if (!IsSplat(V.Real) || !IsSplat(V.Imag))
2201 return nullptr;
2202
2203 auto *Real = dyn_cast<Instruction>(V.Real);
2204 auto *Imag = dyn_cast<Instruction>(V.Imag);
2205 if (!Real || !Imag || Real->getParent() != FirstBB ||
2206 Imag->getParent() != FirstBB)
2207 return nullptr;
2208 }
2209 } else {
2210 for (auto &V : Vals) {
2211 if (!IsSplat(V.Real) || !IsSplat(V.Imag) || isa<Instruction>(V.Real) ||
2212 isa<Instruction>(V.Imag))
2213 return nullptr;
2214 }
2215 }
2216
2217 for (auto &V : Vals) {
2218 auto *Real = dyn_cast<Instruction>(V.Real);
2219 auto *Imag = dyn_cast<Instruction>(V.Imag);
2220 if (Real && Imag) {
2221 FinalInstructions.insert(Real);
2222 FinalInstructions.insert(Imag);
2223 }
2224 }
2225 CompositeNode *PlaceholderNode =
2226 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2227 return submitCompositeNode(PlaceholderNode);
2228}
2229
2230ComplexDeinterleavingGraph::CompositeNode *
2231ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2232 Instruction *Imag) {
2233 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2234 return nullptr;
2235
2236 PHIsFound = true;
2237 CompositeNode *PlaceholderNode = prepareCompositeNode(
2238 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2239 return submitCompositeNode(PlaceholderNode);
2240}
2241
2242ComplexDeinterleavingGraph::CompositeNode *
2243ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2244 Instruction *Imag) {
2245 auto *SelectReal = dyn_cast<SelectInst>(Real);
2246 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2247 if (!SelectReal || !SelectImag)
2248 return nullptr;
2249
2250 Instruction *MaskA, *MaskB;
2251 Instruction *AR, *AI, *RA, *BI;
2252 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2253 m_Instruction(RA))) ||
2254 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2255 m_Instruction(BI))))
2256 return nullptr;
2257
2258 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2259 return nullptr;
2260
2261 if (!MaskA->getType()->isVectorTy())
2262 return nullptr;
2263
2264 auto NodeA = identifyNode(AR, AI);
2265 if (!NodeA)
2266 return nullptr;
2267
2268 auto NodeB = identifyNode(RA, BI);
2269 if (!NodeB)
2270 return nullptr;
2271
2272 CompositeNode *PlaceholderNode = prepareCompositeNode(
2273 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2274 PlaceholderNode->addOperand(NodeA);
2275 PlaceholderNode->addOperand(NodeB);
2276 FinalInstructions.insert(MaskA);
2277 FinalInstructions.insert(MaskB);
2278 return submitCompositeNode(PlaceholderNode);
2279}
2280
2281static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2282 std::optional<FastMathFlags> Flags,
2283 Value *InputA, Value *InputB) {
2284 Value *I;
2285 switch (Opcode) {
2286 case Instruction::FNeg:
2287 I = B.CreateFNeg(InputA);
2288 break;
2289 case Instruction::FAdd:
2290 I = B.CreateFAdd(InputA, InputB);
2291 break;
2292 case Instruction::Add:
2293 I = B.CreateAdd(InputA, InputB);
2294 break;
2295 case Instruction::FSub:
2296 I = B.CreateFSub(InputA, InputB);
2297 break;
2298 case Instruction::Sub:
2299 I = B.CreateSub(InputA, InputB);
2300 break;
2301 case Instruction::FMul:
2302 I = B.CreateFMul(InputA, InputB);
2303 break;
2304 case Instruction::Mul:
2305 I = B.CreateMul(InputA, InputB);
2306 break;
2307 default:
2308 llvm_unreachable("Incorrect symmetric opcode");
2309 }
2310 if (Flags)
2311 cast<Instruction>(I)->setFastMathFlags(*Flags);
2312 return I;
2313}
2314
2315Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2316 CompositeNode *Node) {
2317 if (Node->ReplacementNode)
2318 return Node->ReplacementNode;
2319
2320 auto ReplaceOperandIfExist = [&](CompositeNode *Node,
2321 unsigned Idx) -> Value * {
2322 return Node->Operands.size() > Idx
2323 ? replaceNode(Builder, Node->Operands[Idx])
2324 : nullptr;
2325 };
2326
2327 Value *ReplacementNode = nullptr;
2328 switch (Node->Operation) {
2329 case ComplexDeinterleavingOperation::CDot: {
2330 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2331 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2332 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2333 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2334 "Node inputs need to be of the same type"));
2335 ReplacementNode = TL->createComplexDeinterleavingIR(
2336 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
2337 break;
2338 }
2339 case ComplexDeinterleavingOperation::CAdd:
2340 case ComplexDeinterleavingOperation::CMulPartial:
2341 case ComplexDeinterleavingOperation::Symmetric: {
2342 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2343 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2344 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2345 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2346 "Node inputs need to be of the same type"));
2348 (Input0->getType() == Accumulator->getType() &&
2349 "Accumulator and input need to be of the same type"));
2350 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2351 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
2352 Input0, Input1);
2353 else
2354 ReplacementNode = TL->createComplexDeinterleavingIR(
2355 Builder, Node->Operation, Node->Rotation, Input0, Input1,
2356 Accumulator);
2357 break;
2358 }
2359 case ComplexDeinterleavingOperation::Deinterleave:
2360 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2361 break;
2362 case ComplexDeinterleavingOperation::Splat: {
2364 for (auto &V : Node->Vals) {
2365 Ops.push_back(V.Real);
2366 Ops.push_back(V.Imag);
2367 }
2368 auto *R = dyn_cast<Instruction>(Node->Vals[0].Real);
2369 auto *I = dyn_cast<Instruction>(Node->Vals[0].Imag);
2370 if (R && I) {
2371 // Splats that are not constant are interleaved where they are located
2372 Instruction *InsertPoint = R;
2373 for (auto V : Node->Vals) {
2374 if (InsertPoint->comesBefore(cast<Instruction>(V.Real)))
2375 InsertPoint = cast<Instruction>(V.Real);
2376 if (InsertPoint->comesBefore(cast<Instruction>(V.Imag)))
2377 InsertPoint = cast<Instruction>(V.Imag);
2378 }
2379 InsertPoint = InsertPoint->getNextNode();
2380 IRBuilder<> IRB(InsertPoint);
2381 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2382 } else {
2383 ReplacementNode = Builder.CreateVectorInterleave(Ops);
2384 }
2385 break;
2386 }
2387 case ComplexDeinterleavingOperation::ReductionPHI: {
2388 // If Operation is ReductionPHI, a new empty PHINode is created.
2389 // It is filled later when the ReductionOperation is processed.
2390 auto *OldPHI = cast<PHINode>(Node->Vals[0].Real);
2391 auto *VTy = cast<VectorType>(Node->Vals[0].Real->getType());
2392 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2393 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
2394 OldToNewPHI[OldPHI] = NewPHI;
2395 ReplacementNode = NewPHI;
2396 break;
2397 }
2398 case ComplexDeinterleavingOperation::ReductionSingle:
2399 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2400 processReductionSingle(ReplacementNode, Node);
2401 break;
2402 case ComplexDeinterleavingOperation::ReductionOperation:
2403 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2404 processReductionOperation(ReplacementNode, Node);
2405 break;
2406 case ComplexDeinterleavingOperation::ReductionSelect: {
2407 auto *MaskReal = cast<Instruction>(Node->Vals[0].Real)->getOperand(0);
2408 auto *MaskImag = cast<Instruction>(Node->Vals[0].Imag)->getOperand(0);
2409 auto *A = replaceNode(Builder, Node->Operands[0]);
2410 auto *B = replaceNode(Builder, Node->Operands[1]);
2411 auto *NewMask = Builder.CreateVectorInterleave({MaskReal, MaskImag});
2412 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
2413 break;
2414 }
2415 }
2416
2417 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2418 NumComplexTransformations += 1;
2419 Node->ReplacementNode = ReplacementNode;
2420 return ReplacementNode;
2421}
2422
2423void ComplexDeinterleavingGraph::processReductionSingle(
2424 Value *OperationReplacement, CompositeNode *Node) {
2425 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2426 auto *OldPHI = ReductionInfo[Real].first;
2427 auto *NewPHI = OldToNewPHI[OldPHI];
2428 auto *VTy = cast<VectorType>(Real->getType());
2429 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2430
2431 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2432
2433 IRBuilder<> Builder(Incoming->getTerminator());
2434
2435 Value *NewInit = nullptr;
2436 if (auto *C = dyn_cast<Constant>(Init)) {
2437 if (C->isZeroValue())
2438 NewInit = Constant::getNullValue(NewVTy);
2439 }
2440
2441 if (!NewInit)
2442 NewInit =
2443 Builder.CreateVectorInterleave({Init, Constant::getNullValue(VTy)});
2444
2445 NewPHI->addIncoming(NewInit, Incoming);
2446 NewPHI->addIncoming(OperationReplacement, BackEdge);
2447
2448 auto *FinalReduction = ReductionInfo[Real].second;
2449 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2450
2451 auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2452 FinalReduction->replaceAllUsesWith(AddReduce);
2453}
2454
2455void ComplexDeinterleavingGraph::processReductionOperation(
2456 Value *OperationReplacement, CompositeNode *Node) {
2457 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2458 auto *Imag = cast<Instruction>(Node->Vals[0].Imag);
2459 auto *OldPHIReal = ReductionInfo[Real].first;
2460 auto *OldPHIImag = ReductionInfo[Imag].first;
2461 auto *NewPHI = OldToNewPHI[OldPHIReal];
2462
2463 // We have to interleave initial origin values coming from IncomingBlock
2464 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2465 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2466
2467 IRBuilder<> Builder(Incoming->getTerminator());
2468 auto *NewInit = Builder.CreateVectorInterleave({InitReal, InitImag});
2469
2470 NewPHI->addIncoming(NewInit, Incoming);
2471 NewPHI->addIncoming(OperationReplacement, BackEdge);
2472
2473 // Deinterleave complex vector outside of loop so that it can be finally
2474 // reduced
2475 auto *FinalReductionReal = ReductionInfo[Real].second;
2476 auto *FinalReductionImag = ReductionInfo[Imag].second;
2477
2478 Builder.SetInsertPoint(
2479 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2480 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2481 OperationReplacement->getType(),
2482 OperationReplacement);
2483
2484 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2485 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2486
2487 Builder.SetInsertPoint(FinalReductionImag);
2488 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2489 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2490}
2491
2492void ComplexDeinterleavingGraph::replaceNodes() {
2493 SmallVector<Instruction *, 16> DeadInstrRoots;
2494 for (auto *RootInstruction : OrderedRoots) {
2495 // Check if this potential root went through check process and we can
2496 // deinterleave it
2497 if (!RootToNode.count(RootInstruction))
2498 continue;
2499
2500 IRBuilder<> Builder(RootInstruction);
2501 auto RootNode = RootToNode[RootInstruction];
2502 Value *R = replaceNode(Builder, RootNode);
2503
2504 if (RootNode->Operation ==
2505 ComplexDeinterleavingOperation::ReductionOperation) {
2506 auto *RootReal = cast<Instruction>(RootNode->Vals[0].Real);
2507 auto *RootImag = cast<Instruction>(RootNode->Vals[0].Imag);
2508 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2509 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2510 DeadInstrRoots.push_back(RootReal);
2511 DeadInstrRoots.push_back(RootImag);
2512 } else if (RootNode->Operation ==
2513 ComplexDeinterleavingOperation::ReductionSingle) {
2514 auto *RootInst = cast<Instruction>(RootNode->Vals[0].Real);
2515 auto &Info = ReductionInfo[RootInst];
2516 Info.first->removeIncomingValue(BackEdge);
2517 DeadInstrRoots.push_back(Info.second);
2518 } else {
2519 assert(R && "Unable to find replacement for RootInstruction");
2520 DeadInstrRoots.push_back(RootInstruction);
2521 RootInstruction->replaceAllUsesWith(R);
2522 }
2523 }
2524
2525 for (auto *I : DeadInstrRoots)
2527}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Analysis containing CSE Info
Definition CSEInfo.cpp:27
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
#define DEBUG_TYPE
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
This file implements a map that provides insertion order iteration.
#define T
uint64_t IntrinsicInst * II
#define P(N)
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
This file describes how to lower LLVM code to machine code.
This pass exposes codegen information to IR-level passes.
BinaryOperator * Mul
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition ArrayRef.h:147
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:167
iterator end()
Definition DenseMap.h:81
bool allowContract() const
Definition FMF.h:69
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition IRBuilder.h:2626
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
size_type size() const
Definition MapVector.h:56
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:273
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:546
An opaque object representing a hash code.
Definition Hashing.h:76
const ParentTy * getParent() const
Definition ilist_node.h:34
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition DWP.cpp:477
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1725
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:533
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
@ Other
Any other memory.
Definition ModRef.h:68
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1758
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1897
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list is empty.
Definition STLExtras.h:2108
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:592
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:869
#define N
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given valu...