LLVM 17.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Each node is expected to be validated upon creation, and any
19// validation error should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas fixed-width vectors are recognized in both forms:
25// as shufflevector instructions and as the intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
63#include "llvm/ADT/Statistic.h"
69#include "llvm/IR/IRBuilder.h"
74#include <algorithm>
75
76using namespace llvm;
77using namespace PatternMatch;
78
// Component name used by LLVM's debug-output and statistics machinery
// (LLVM_DEBUG / STATISTIC) in this file.
79#define DEBUG_TYPE "complex-deinterleaving"
80
// Counter reported with -stats; presumably incremented once per successful
// rewrite (the increment site is not visible in this part of the file).
81STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82
// Command-line switch for the transform, enabled by default (cl::init(true)).
// NOTE(review): the declaration line(s) of this option are missing from this
// copy of the file.
84 "enable-complex-deinterleaving",
85 cl::desc("Enable generation of complex instructions"), cl::init(true),
87
88/// Checks whether \p Mask is an interleaving shuffle mask.
89///
90/// To be interleaving, a mask of length `Length` (== Mask.size()) must be
91/// even-sized, alternate between `i` and `i + Length / 2`, and cover every
92/// index in `[0..Length)` exactly once (e.g. <0, 2, 1, 3> for 4 elements).
/// \returns true if \p Mask has this shape.
93static bool isInterleavingMask(ArrayRef<int> Mask);
94
95/// Checks whether \p Mask is a deinterleaving shuffle mask.
96///
97/// To be deinterleaving, a mask must increment in steps of 2 and start with
98/// either 0 (even, real lanes) or 1 (odd, imaginary lanes).
99/// (e.g. deinterleaving an 8-element vector uses either <0, 2, 4, 6> or
100/// <1, 3, 5, 7>).
/// \returns true if \p Mask has this shape.
101static bool isDeinterleavingMask(ArrayRef<int> Mask);
102
103namespace {
104
/// Legacy pass-manager wrapper: for each function it obtains the
/// TargetLowering from \p TM and the TargetLibraryInfo from the wrapper
/// analysis, then runs ComplexDeinterleaving on the function.
105class ComplexDeinterleavingLegacyPass : public FunctionPass {
106public:
107 static char ID;
108
// TM is null when default-constructed; a TargetMachine is needed to obtain
// the TargetLowering in runOnFunction.
109 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
110 : FunctionPass(ID), TM(TM) {
113 }
114
115 StringRef getPassName() const override {
116 return "Complex Deinterleaving Pass";
117 }
118
119 bool runOnFunction(Function &F) override;
// The transform only rewrites instructions inside basic blocks, so the CFG is
// preserved.
// NOTE(review): one line of this method is missing from this copy (presumably
// the TargetLibraryInfoWrapperPass requirement, since runOnFunction calls
// getAnalysis on it).
120 void getAnalysisUsage(AnalysisUsage &AU) const override {
122 AU.setPreservesCFG();
123 }
124
125private:
// Target used to look up per-function subtarget lowering information.
126 const TargetMachine *TM;
127};
128
129class ComplexDeinterleavingGraph;
/// A node of the complex-deinterleaving graph: a (Real, Imag) instruction pair
/// that is expected to lower to a single complex operation, together with the
/// operation kind, its rotation and the operand nodes it consumes.
/// NOTE(review): several member declarations (Operation, Flags, the start of
/// Rotation, Operands) and one constructor parameter line are missing from
/// this copy of the file.
130struct ComplexDeinterleavingCompositeNode {
131
132 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
134 : Operation(Op), Real(R), Imag(I) {}
135
136private:
137 friend class ComplexDeinterleavingGraph;
138 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
139 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
140
141public:
// Instructions producing the real and imaginary halves of this node.
143 Instruction *Real;
144 Instruction *Imag;
145
146 // These two members (Opcode and the fast-math Flags) are required
147 // exclusively for generating ComplexDeinterleavingOperation::Symmetric
// operations.
148 unsigned Opcode;
150
152 ComplexDeinterleavingRotation::Rotation_0;
// Value emitted for this node during replacement; null until then.
154 Value *ReplacementNode = nullptr;
155
// Operands stores only the raw pointer (Node.get()); the owning shared_ptr is
// expected to be kept alive by the graph (submitted nodes live in its
// CompositeNodes list).
156 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
157
// Debug printers. Rotation is printed in degrees (enum value * 90).
158 void dump() { dump(dbgs()); }
159 void dump(raw_ostream &OS) {
160 auto PrintValue = [&](Value *V) {
161 if (V) {
162 OS << "\"";
163 V->print(OS, true);
164 OS << "\"\n";
165 } else
166 OS << "nullptr\n";
167 };
168 auto PrintNodeRef = [&](RawNodePtr Ptr) {
169 if (Ptr)
170 OS << Ptr << "\n";
171 else
172 OS << "nullptr\n";
173 };
174
175 OS << "- CompositeNode: " << this << "\n";
176 OS << " Real: ";
177 PrintValue(Real);
178 OS << " Imag: ";
179 PrintValue(Imag);
180 OS << " ReplacementNode: ";
181 PrintValue(ReplacementNode);
182 OS << " Operation: " << (int)Operation << "\n";
183 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
184 OS << " Operands: \n";
185 for (const auto &Op : Operands) {
186 OS << " - ";
187 PrintNodeRef(Op);
188 }
189 }
190};
191
/// Per-basic-block graph of ComplexDeinterleavingCompositeNodes. Nodes are
/// identified starting from candidate root instructions (identifyNodes), the
/// graph is validated (checkNodes) and then rewritten (replaceNodes).
192class ComplexDeinterleavingGraph {
193public:
// One multiplication term, Multiplier * Multiplicand, contributing with a
// positive sign when IsPositive is true and a negative sign otherwise.
194 struct Product {
195 Instruction *Multiplier;
196 Instruction *Multiplicand;
197 bool IsPositive;
198 };
199
// An addend instruction paired with its sign (true = positive).
200 using Addend = std::pair<Instruction *, bool>;
201 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
202 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
203
204 // Helper struct for holding info about potential partial multiplication
205 // candidates: the shared operand, the node formed by the remaining operands,
// the indices of the real/imaginary products, and whether that node was
// formed with its operands swapped.
206 struct PartialMulCandidate {
207 Instruction *Common;
208 NodePtr Node;
209 unsigned RealIdx;
210 unsigned ImagIdx;
211 bool IsNodeInverted;
212 };
213
214 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
215 const TargetLibraryInfo *TLI)
216 : TL(TL), TLI(TLI) {}
217
218private:
219 const TargetLowering *TL = nullptr;
220 const TargetLibraryInfo *TLI = nullptr;
// Every submitted node; searched by getContainingComposite.
221 SmallVector<NodePtr> CompositeNodes;
222
223 SmallPtrSet<Instruction *, 16> FinalInstructions;
224
225 /// Root instructions are instructions from which complex computation starts
226 std::map<Instruction *, NodePtr> RootToNode;
227
228 /// Topologically sorted root instructions
// NOTE(review): the member documented by the comment above is missing from
// this copy of the file.
230
// Creates a node WITHOUT registering it in CompositeNodes.
231 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
233 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
234 I);
235 }
236
// Registers Node in CompositeNodes (making it findable by
// getContainingComposite) and returns it.
237 NodePtr submitCompositeNode(NodePtr Node) {
238 CompositeNodes.push_back(Node);
239 return Node;
240 }
241
// Linear search for a submitted node whose Real and Imag match exactly;
// returns nullptr if there is none.
242 NodePtr getContainingComposite(Value *R, Value *I) {
243 for (const auto &CN : CompositeNodes) {
244 if (CN->Real == R && CN->Imag == I)
245 return CN;
246 }
247 return nullptr;
248 }
249
250 /// Identifies a complex partial multiply pattern and its rotation, based on
251 /// the following patterns
252 ///
253 /// 0: r: cr + ar * br
254 /// i: ci + ar * bi
255 /// 90: r: cr - ai * bi
256 /// i: ci + ai * br
257 /// 180: r: cr - ar * br
258 /// i: ci - ar * bi
259 /// 270: r: cr + ai * bi
260 /// i: ci - ai * br
261 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
262
263 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
264 /// is partially known from identifyPartialMul, filling in the other half of
265 /// the complex pair.
266 NodePtr identifyNodeWithImplicitAdd(
268 std::pair<Instruction *, Instruction *> &CommonOperandI);
269
270 /// Identifies a complex add pattern and its rotation, based on the following
271 /// patterns.
272 ///
273 /// 90: r: ar - bi
274 /// i: ai + br
275 /// 270: r: ar + bi
276 /// i: ai - br
277 NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
278 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
279
280 NodePtr identifyNode(Instruction *I, Instruction *J);
281
282 /// Determine if a sum of complex numbers can be formed from \p RealAddends
283 /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
284 /// Return nullptr if it is not possible to construct a complex number.
285 /// \p Flags are needed to generate symmetric Add and Sub operations.
286 NodePtr identifyAdditions(std::list<Addend> &RealAddends,
287 std::list<Addend> &ImagAddends, FastMathFlags Flags,
288 NodePtr Accumulator);
289
290 /// Extract one addend that has both real and imaginary parts positive.
291 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
292 std::list<Addend> &ImagAddends);
293
294 /// Determine if a sum of multiplications of complex numbers can be formed
295 /// from \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the
296 /// result to it. Return nullptr if it is not possible to construct a complex
/// number.
297 NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
298 std::vector<Product> &ImagMuls,
299 NodePtr Accumulator);
300
301 /// Go through pairs of multiplications (one Real and one Imag), find all
302 /// possible candidates for partial multiplication and put them into \p
303 /// Candidates. Returns false if some product in \p RealMuls has no partner
/// in \p ImagMuls that shares an operand and whose remaining operands form an
/// identifiable node; returns true otherwise.
304 bool collectPartialMuls(const std::vector<Product> &RealMuls,
305 const std::vector<Product> &ImagMuls,
306 std::vector<PartialMulCandidate> &Candidates);
307
308 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
309 /// the order of complex computation operations may be significantly altered,
310 /// and the real and imaginary parts may not be executed in parallel. This
311 /// function takes this into consideration and employs a more general approach
312 /// to identify complex computations. Initially, it gathers all the addends
313 /// and multiplicands and then constructs a complex expression from them.
314 NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
315
316 NodePtr identifyRoot(Instruction *I);
317
318 /// Identifies the Deinterleave operation applied to a vector containing
319 /// complex numbers. There are two ways to represent the Deinterleave
320 /// operation:
321 /// * Using two shufflevectors with even indices for \p Real instruction and
322 /// odd indices for \p Imag instructions (only for fixed-width vectors)
323 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
324 /// intrinsic (for both fixed and scalable vectors)
325 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
326
327 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
328
329public:
// Debug printers: dump every submitted node.
330 void dump() { dump(dbgs()); }
331 void dump(raw_ostream &OS) {
332 for (const auto &Node : CompositeNodes)
333 Node->dump(OS);
334 }
335
336 /// Returns false if the deinterleaving operation should be cancelled for the
337 /// current graph.
338 bool identifyNodes(Instruction *RootI);
339
340 /// Check that every instruction, from the roots to the leaves, has internal
341 /// uses.
342 bool checkNodes();
343
344 /// Perform the actual replacement of the underlying instruction graph.
345 void replaceNodes();
346};
347
348class ComplexDeinterleaving {
349public:
350 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
351 : TL(tl), TLI(tli) {}
352 bool runOnFunction(Function &F);
353
354private:
355 bool evaluateBasicBlock(BasicBlock *B);
356
357 const TargetLowering *TL = nullptr;
358 const TargetLibraryInfo *TLI = nullptr;
359};
360
361} // namespace
362
// Pass identifier; the legacy pass manager identifies the pass by the address
// of this variable.
363char ComplexDeinterleavingLegacyPass::ID = 0;
364
// Registers the legacy pass under DEBUG_TYPE.
// NOTE(review): the final line of INITIALIZE_PASS_END is missing from this
// copy of the file.
365INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
366 "Complex Deinterleaving", false, false)
367INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
369
// New pass-manager entry point: runs ComplexDeinterleaving and reports all
// analyses preserved when nothing changed.
// NOTE(review): the signature lines and the lines building PA (the
// PreservedAnalyses returned after a change) are missing from this copy.
372 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
373 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
374 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
375 return PreservedAnalyses::all();
376
379 return PA;
380}
381
// Factory for the legacy pass.
// NOTE(review): the signature line is missing from this copy of the file.
383 return new ComplexDeinterleavingLegacyPass(TM);
384}
385
386bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
387 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
388 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
389 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
390}
391
// Runs the transform over every basic block of F. It exits early when the
// transform has been explicitly disabled (presumably via the
// enable-complex-deinterleaving option) or when the target does not support
// complex number lowering. Returns true if any block was changed.
// NOTE(review): the `if (...) { LLVM_DEBUG(` lines opening both early exits
// are missing from this copy of the file.
392bool ComplexDeinterleaving::runOnFunction(Function &F) {
395 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
396 return false;
397 }
398
401 dbgs() << "Complex deinterleaving has been disabled, target does "
402 "not support lowering of complex number operations.\n");
403 return false;
404 }
405
// Each basic block is evaluated independently with its own graph.
406 bool Changed = false;
407 for (auto &B : F)
408 Changed |= evaluateBasicBlock(&B);
409
410 return Changed;
411}
412
414 // If the size is not even, it's not an interleaving mask
415 if ((Mask.size() & 1))
416 return false;
417
418 int HalfNumElements = Mask.size() / 2;
419 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
420 int MaskIdx = Idx * 2;
421 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
422 return false;
423 }
424
425 return true;
426}
427
429 int Offset = Mask[0];
430 int HalfNumElements = Mask.size() / 2;
431
432 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
433 if (Mask[Idx] != (Idx * 2) + Offset)
434 return false;
435 }
436
437 return true;
438}
439
440bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
441 ComplexDeinterleavingGraph Graph(TL, TLI);
442 for (auto &I : *B)
443 Graph.identifyNodes(&I);
444
445 if (Graph.checkNodes()) {
446 Graph.replaceNodes();
447 return true;
448 }
449
450 return false;
451}
452
// Matches the "implicit add" half of a partial complex multiply. Real and Imag
// must be single-use fmuls that share a common operand; one fneg operand on
// each side is looked through. The common operand is stored into PartialMatch
// (first slot for rotations 0/180, second slot for 90/270). Once both slots
// are known, the common and uncommon operand pairs are identified and a
// CMulPartial node is built.
453ComplexDeinterleavingGraph::NodePtr
454ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
455 Instruction *Real, Instruction *Imag,
456 std::pair<Instruction *, Instruction *> &PartialMatch) {
457 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
458 << "\n");
459
460 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
461 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
462 return nullptr;
463 }
464
465 if (Real->getOpcode() != Instruction::FMul ||
466 Imag->getOpcode() != Instruction::FMul) {
467 LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n");
468 return nullptr;
469 }
470
471 Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
472 Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
473 Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
474 Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
475 if (!R0 || !R1 || !I0 || !I1) {
476 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
477 return nullptr;
478 }
479
480 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
481 // rotations and use the operand.
482 unsigned Negs = 0;
// NOTE(review): the declaration of FNegs is missing from this copy of the file.
484 if (R0->getOpcode() == Instruction::FNeg ||
485 R1->getOpcode() == Instruction::FNeg) {
486 Negs |= 1;
487 if (R0->getOpcode() == Instruction::FNeg) {
488 FNegs.push_back(R0);
489 R0 = dyn_cast<Instruction>(R0->getOperand(0));
490 } else {
491 FNegs.push_back(R1);
492 R1 = dyn_cast<Instruction>(R1->getOperand(0));
493 }
494 if (!R0 || !R1)
495 return nullptr;
496 }
497 if (I0->getOpcode() == Instruction::FNeg ||
498 I1->getOpcode() == Instruction::FNeg) {
499 Negs |= 2;
500 Negs ^= 1;
501 if (I0->getOpcode() == Instruction::FNeg) {
502 FNegs.push_back(I0);
503 I0 = dyn_cast<Instruction>(I0->getOperand(0));
504 } else {
505 FNegs.push_back(I1);
506 I1 = dyn_cast<Instruction>(I1->getOperand(0));
507 }
508 if (!I0 || !I1)
509 return nullptr;
510 }
511
// NOTE(review): the line computing Rotation (presumably derived from Negs) is
// missing from this copy of the file.
513
514 Instruction *CommonOperand;
515 Instruction *UncommonRealOp;
516 Instruction *UncommonImagOp;
517
// Find the operand shared by both multiplies; the other operand of each side
// is its "uncommon" factor.
518 if (R0 == I0 || R0 == I1) {
519 CommonOperand = R0;
520 UncommonRealOp = R1;
521 } else if (R1 == I0 || R1 == I1) {
522 CommonOperand = R1;
523 UncommonRealOp = R0;
524 } else {
525 LLVM_DEBUG(dbgs() << " - No equal operand\n");
526 return nullptr;
527 }
528
529 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
// For rotations 90/270 the real and imaginary uncommon factors trade places.
530 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
531 Rotation == ComplexDeinterleavingRotation::Rotation_270)
532 std::swap(UncommonRealOp, UncommonImagOp);
533
534 // Between identifyPartialMul and here we need to have found a complete valid
535 // pair from the CommonOperand of each part.
536 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
537 Rotation == ComplexDeinterleavingRotation::Rotation_180)
538 PartialMatch.first = CommonOperand;
539 else
540 PartialMatch.second = CommonOperand;
541
542 if (!PartialMatch.first || !PartialMatch.second) {
543 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
544 return nullptr;
545 }
546
547 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
548 if (!CommonNode) {
549 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
550 return nullptr;
551 }
552
553 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
554 if (!UncommonNode) {
555 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
556 return nullptr;
557 }
558
559 NodePtr Node = prepareCompositeNode(
560 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
561 Node->Rotation = Rotation;
562 Node->addOperand(CommonNode);
563 Node->addOperand(UncommonNode);
564 return submitCompositeNode(Node);
565}
566
// Matches one step of a partial complex multiply. Real and Imag have the form
// `c +/- mul`, with the accumulator in operand 0 and a single-use multiply in
// operand 1, and must carry the 'contract' fast-math flag. The fadd/fsub
// combination selects the rotation (see the pattern table on the
// declaration). The accumulators are matched via identifyNodeWithImplicitAdd.
// The resulting CMulPartial node has operands (common pair, uncommon pair,
// accumulator node).
567ComplexDeinterleavingGraph::NodePtr
568ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
569 Instruction *Imag) {
570 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
571 << "\n");
572 // Determine rotation
// NOTE(review): the declaration of Rotation is missing from this copy of the
// file.
574 if (Real->getOpcode() == Instruction::FAdd &&
575 Imag->getOpcode() == Instruction::FAdd)
576 Rotation = ComplexDeinterleavingRotation::Rotation_0;
577 else if (Real->getOpcode() == Instruction::FSub &&
578 Imag->getOpcode() == Instruction::FAdd)
579 Rotation = ComplexDeinterleavingRotation::Rotation_90;
580 else if (Real->getOpcode() == Instruction::FSub &&
581 Imag->getOpcode() == Instruction::FSub)
582 Rotation = ComplexDeinterleavingRotation::Rotation_180;
583 else if (Real->getOpcode() == Instruction::FAdd &&
584 Imag->getOpcode() == Instruction::FSub)
585 Rotation = ComplexDeinterleavingRotation::Rotation_270;
586 else {
587 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
588 return nullptr;
589 }
590
// Folding the multiply and the add into one complex multiply-accumulate is a
// contraction, so both instructions must allow it.
591 if (!Real->getFastMathFlags().allowContract() ||
592 !Imag->getFastMathFlags().allowContract()) {
593 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
594 return nullptr;
595 }
596
597 Value *CR = Real->getOperand(0);
598 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
599 if (!RealMulI)
600 return nullptr;
601 Value *CI = Imag->getOperand(0);
602 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
603 if (!ImagMulI)
604 return nullptr;
605
606 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
607 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
608 return nullptr;
609 }
610
611 Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
612 Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
613 Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
614 Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
615 if (!R0 || !R1 || !I0 || !I1) {
616 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
617 return nullptr;
618 }
619
620 Instruction *CommonOperand;
621 Instruction *UncommonRealOp;
622 Instruction *UncommonImagOp;
623
624 if (R0 == I0 || R0 == I1) {
625 CommonOperand = R0;
626 UncommonRealOp = R1;
627 } else if (R1 == I0 || R1 == I1) {
628 CommonOperand = R1;
629 UncommonRealOp = R0;
630 } else {
631 LLVM_DEBUG(dbgs() << " - No equal operand\n");
632 return nullptr;
633 }
634
635 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
636 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
637 Rotation == ComplexDeinterleavingRotation::Rotation_270)
638 std::swap(UncommonRealOp, UncommonImagOp);
639
// Seed the slot that corresponds to this rotation with the common operand;
// identifyNodeWithImplicitAdd is expected to fill in the other slot.
640 std::pair<Instruction *, Instruction *> PartialMatch(
641 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
642 Rotation == ComplexDeinterleavingRotation::Rotation_180)
643 ? CommonOperand
644 : nullptr,
645 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
646 Rotation == ComplexDeinterleavingRotation::Rotation_270)
647 ? CommonOperand
648 : nullptr);
649
650 auto *CRInst = dyn_cast<Instruction>(CR);
651 auto *CIInst = dyn_cast<Instruction>(CI);
652
653 if (!CRInst || !CIInst) {
654 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
655 return nullptr;
656 }
657
658 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
659 if (!CNode) {
660 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
661 return nullptr;
662 }
663
664 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
665 if (!UncommonRes) {
666 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
667 return nullptr;
668 }
669
// identifyNodeWithImplicitAdd only succeeds after both slots were filled.
670 assert(PartialMatch.first && PartialMatch.second);
671 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
672 if (!CommonRes) {
673 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
674 return nullptr;
675 }
676
677 NodePtr Node = prepareCompositeNode(
678 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
679 Node->Rotation = Rotation;
680 Node->addOperand(CommonRes);
681 Node->addOperand(UncommonRes);
682 Node->addOperand(CNode);
683 return submitCompositeNode(Node);
684}
685
// Matches a complex addition with rotation 90 (sub/add) or 270 (add/sub), for
// both floating-point and integer opcodes. Operand 0 of Real/Imag forms the A
// pair (AR, AI). The B pair takes its real part from Imag's operand 1 and its
// imaginary part from Real's operand 1. Both pairs must themselves be
// identifiable nodes.
686ComplexDeinterleavingGraph::NodePtr
687ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
688 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
689
690 // Determine rotation
// NOTE(review): the declaration of Rotation is missing from this copy of the
// file.
692 if ((Real->getOpcode() == Instruction::FSub &&
693 Imag->getOpcode() == Instruction::FAdd) ||
694 (Real->getOpcode() == Instruction::Sub &&
695 Imag->getOpcode() == Instruction::Add))
696 Rotation = ComplexDeinterleavingRotation::Rotation_90;
697 else if ((Real->getOpcode() == Instruction::FAdd &&
698 Imag->getOpcode() == Instruction::FSub) ||
699 (Real->getOpcode() == Instruction::Add &&
700 Imag->getOpcode() == Instruction::Sub))
701 Rotation = ComplexDeinterleavingRotation::Rotation_270;
702 else {
703 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
704 return nullptr;
705 }
706
707 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
708 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
709 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
710 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
711
712 if (!AR || !AI || !BR || !BI) {
713 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
714 return nullptr;
715 }
716
717 NodePtr ResA = identifyNode(AR, AI);
718 if (!ResA) {
719 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
720 return nullptr;
721 }
722 NodePtr ResB = identifyNode(BR, BI);
723 if (!ResB) {
724 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
725 return nullptr;
726 }
727
728 NodePtr Node =
729 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
730 Node->Rotation = Rotation;
731 Node->addOperand(ResA);
732 Node->addOperand(ResB);
733 return submitCompositeNode(Node);
734}
735
// Helper used by identifyNode before trying identifyAdd: true if A and B are
// one add and one sub, in either order (floating-point or integer). This is
// the only shape identifyAdd accepts.
// NOTE(review): the signature line of this helper is missing from this copy
// of the file.
737 unsigned OpcA = A->getOpcode();
738 unsigned OpcB = B->getOpcode();
739
740 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
741 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
742 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
743 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
744}
745
// Helper used by identifyNode before trying identifyPartialMul: true if both A
// and B match the same PatternMatch Pattern.
// NOTE(review): the signature line and the definition of Pattern are missing
// from this copy of the file. Presumably the pattern describes the
// accumulate-of-multiply shape that identifyPartialMul expects; confirm
// against upstream.
747 auto Pattern =
749
750 return match(A, Pattern) && match(B, Pattern);
751}
752
// Opcode whitelist: returns true only for fadd, fsub, fmul and fneg.
// NOTE(review): the signature line of this helper is missing from this copy
// of the file. It is presumably consulted by identifySymmetricOperation,
// whose guard line is also missing here.
754 switch (I->getOpcode()) {
755 case Instruction::FAdd:
756 case Instruction::FSub:
757 case Instruction::FMul:
758 case Instruction::FNeg:
759 return true;
760 default:
761 return false;
762 }
763}
764
// Matches Real and Imag that apply the same opcode to operand pairs which are
// themselves identifiable nodes: one pair for unary operations, two for binary
// operations. Floating-point operations must also carry identical fast-math
// flags. The resulting Symmetric node records the opcode (and the flags for FP
// operations) for replacement.
765ComplexDeinterleavingGraph::NodePtr
766ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
767 Instruction *Imag) {
768 if (Real->getOpcode() != Imag->getOpcode())
769 return nullptr;
770
// NOTE(review): the guard condition for the following `return nullptr;` is
// missing from this copy of the file. As written, everything below it is
// unreachable.
773 return nullptr;
774
775 auto *R0 = dyn_cast<Instruction>(Real->getOperand(0));
776 auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
777
778 if (!R0 || !I0)
779 return nullptr;
780
781 NodePtr Op0 = identifyNode(R0, I0);
782 NodePtr Op1 = nullptr;
783 if (Op0 == nullptr)
784 return nullptr;
785
// Binary operations need their second operand pair identified as well.
786 if (Real->isBinaryOp()) {
787 auto *R1 = dyn_cast<Instruction>(Real->getOperand(1));
788 auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
789 if (!R1 || !I1)
790 return nullptr;
791
792 Op1 = identifyNode(R1, I1);
793 if (Op1 == nullptr)
794 return nullptr;
795 }
796
797 if (isa<FPMathOperator>(Real) &&
798 Real->getFastMathFlags() != Imag->getFastMathFlags())
799 return nullptr;
800
801 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
802 Real, Imag);
803 Node->Opcode = Real->getOpcode();
804 if (isa<FPMathOperator>(Real))
805 Node->Flags = Real->getFastMathFlags();
806
807 Node->addOperand(Op0);
808 if (Real->isBinaryOp())
809 Node->addOperand(Op1);
810
811 return submitCompositeNode(Node);
812}
813
814ComplexDeinterleavingGraph::NodePtr
815ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
816 LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
817 if (NodePtr CN = getContainingComposite(Real, Imag)) {
818 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
819 return CN;
820 }
821
822 if (NodePtr CN = identifyDeinterleave(Real, Imag))
823 return CN;
824
825 auto *VTy = cast<VectorType>(Real->getType());
826 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
827
828 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
829 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
830 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
831 ComplexDeinterleavingOperation::CAdd, NewVTy);
832
833 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
834 if (NodePtr CN = identifyPartialMul(Real, Imag))
835 return CN;
836 }
837
838 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
839 if (NodePtr CN = identifyAdd(Real, Imag))
840 return CN;
841 }
842
843 if (HasCMulSupport && HasCAddSupport) {
844 if (NodePtr CN = identifyReassocNodes(Real, Imag))
845 return CN;
846 }
847
848 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
849 return CN;
850
851 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
852 return nullptr;
853}
854
// Reassociation-aware matcher. Real and Imag must be fadd/fsub/fneg with
// identical fast-math flags that include 'reassoc'. Each side is flattened
// into signed products and signed addends. These are then recombined: first
// the complex multiplications (accumulating onto one positive addend, if
// any), then the remaining complex additions. The final node is re-labelled
// with Real/Imag and submitted.
855ComplexDeinterleavingGraph::NodePtr
856ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
857 Instruction *Imag) {
858 if ((Real->getOpcode() != Instruction::FAdd &&
859 Real->getOpcode() != Instruction::FSub &&
860 Real->getOpcode() != Instruction::FNeg) ||
861 (Imag->getOpcode() != Instruction::FAdd &&
862 Imag->getOpcode() != Instruction::FSub &&
863 Imag->getOpcode() != Instruction::FNeg))
864 return nullptr;
865
// NOTE(review): the `LLVM_DEBUG(` opening lines of both diagnostics below, and
// the declaration of Flags (presumably Real's fast-math flags), are missing
// from this copy of the file.
866 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
868 dbgs()
869 << "The flags in Real and Imaginary instructions are not identical\n");
870 return nullptr;
871 }
872
874 if (!Flags.allowReassoc()) {
876 dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n");
877 return nullptr;
878 }
879
880 // Collect multiplications and addend instructions from the given instruction
881 // while traversing its operands. Additionally, verify that all instructions
882 // have the same fast math flags.
883 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
884 std::list<Addend> &Addends) -> bool {
// NOTE(review): the declarations of Worklist and Visited are missing from this
// copy of the file.
887 while (!Worklist.empty()) {
888 auto [V, IsPositive] = Worklist.back();
889 Worklist.pop_back();
890 if (!Visited.insert(V).second)
891 continue;
892
893 Instruction *I = dyn_cast<Instruction>(V);
894 if (!I)
895 return false;
896
897 // If an instruction has more than one user, it indicates that it either
898 // has an external user, which will be later checked by the checkNodes
899 // function, or it is a subexpression utilized by multiple expressions. In
900 // the latter case, we will attempt to separately identify the complex
901 // operation from here in order to create a shared
902 // ComplexDeinterleavingCompositeNode.
903 if (I != Insn && I->getNumUses() > 1) {
904 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
905 Addends.emplace_back(I, IsPositive);
906 continue;
907 }
908
// Sign tracking: fadd keeps both signs, fsub flips the sign of operand 1, and
// fneg flips the sign of its operand. Operand 0 is pushed last so that it is
// popped (visited) first.
909 if (I->getOpcode() == Instruction::FAdd) {
910 Worklist.emplace_back(I->getOperand(1), IsPositive);
911 Worklist.emplace_back(I->getOperand(0), IsPositive);
912 } else if (I->getOpcode() == Instruction::FSub) {
913 Worklist.emplace_back(I->getOperand(1), !IsPositive);
914 Worklist.emplace_back(I->getOperand(0), IsPositive);
915 } else if (I->getOpcode() == Instruction::FMul) {
// An fneg on either multiply operand is folded into the product's sign.
916 auto *A = dyn_cast<Instruction>(I->getOperand(0));
917 if (A && A->getOpcode() == Instruction::FNeg) {
918 A = dyn_cast<Instruction>(A->getOperand(0));
919 IsPositive = !IsPositive;
920 }
921 if (!A)
922 return false;
923 auto *B = dyn_cast<Instruction>(I->getOperand(1));
924 if (B && B->getOpcode() == Instruction::FNeg) {
925 B = dyn_cast<Instruction>(B->getOperand(0));
926 IsPositive = !IsPositive;
927 }
928 if (!B)
929 return false;
930 Muls.push_back(Product{A, B, IsPositive});
931 } else if (I->getOpcode() == Instruction::FNeg) {
932 Worklist.emplace_back(I->getOperand(0), !IsPositive);
933 } else {
934 Addends.emplace_back(I, IsPositive);
935 continue;
936 }
937
938 if (I->getFastMathFlags() != Flags) {
939 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
940 "inconsistent with the root instructions' flags: "
941 << *I << "\n");
942 return false;
943 }
944 }
945 return true;
946 };
947
948 std::vector<Product> RealMuls, ImagMuls;
949 std::list<Addend> RealAddends, ImagAddends;
950 if (!Collect(Real, RealMuls, RealAddends) ||
951 !Collect(Imag, ImagMuls, ImagAddends))
952 return nullptr;
953
// Real and imaginary addends are paired up, so their counts must match.
954 if (RealAddends.size() != ImagAddends.size())
955 return nullptr;
956
957 NodePtr FinalNode;
958 if (!RealMuls.empty() || !ImagMuls.empty()) {
959 // If there are multiplicands, extract positive addend and use it as an
960 // accumulator
961 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
962 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
963 if (!FinalNode)
964 return nullptr;
965 }
966
967 // Identify and process remaining additions
968 if (!RealAddends.empty() || !ImagAddends.empty()) {
969 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
970 if (!FinalNode)
971 return nullptr;
972 }
973
974 // Set the Real and Imag fields of the final node and submit it
975 FinalNode->Real = Real;
976 FinalNode->Imag = Imag;
977 submitCompositeNode(FinalNode);
978 return FinalNode;
979}
980
981bool ComplexDeinterleavingGraph::collectPartialMuls(
982 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
983 std::vector<PartialMulCandidate> &PartialMulCandidates) {
984 // Helper function to extract a common operand from two products
985 auto FindCommonInstruction = [](const Product &Real,
986 const Product &Imag) -> Instruction * {
987 if (Real.Multiplicand == Imag.Multiplicand ||
988 Real.Multiplicand == Imag.Multiplier)
989 return Real.Multiplicand;
990
991 if (Real.Multiplier == Imag.Multiplicand ||
992 Real.Multiplier == Imag.Multiplier)
993 return Real.Multiplier;
994
995 return nullptr;
996 };
997
998 // Iterating over real and imaginary multiplications to find common operands
999 // If a common operand is found, a partial multiplication candidate is created
1000 // and added to the candidates vector The function returns false if no common
1001 // operands are found for any product
1002 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1003 bool FoundCommon = false;
1004 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1005 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1006 if (!Common)
1007 continue;
1008
1009 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1010 : RealMuls[i].Multiplicand;
1011 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1012 : ImagMuls[j].Multiplicand;
1013
1014 bool Inverted = false;
1015 auto Node = identifyNode(A, B);
1016 if (!Node) {
1017 std::swap(A, B);
1018 Inverted = true;
1019 Node = identifyNode(A, B);
1020 }
1021 if (!Node)
1022 continue;
1023
1024 FoundCommon = true;
1025 PartialMulCandidates.push_back({Common, Node, i, j, Inverted});
1026 }
1027 if (!FoundCommon)
1028 return false;
1029 }
1030 return true;
1031}
1032
1033ComplexDeinterleavingGraph::NodePtr
1034ComplexDeinterleavingGraph::identifyMultiplications(
1035 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1036 NodePtr Accumulator = nullptr) {
1037 if (RealMuls.size() != ImagMuls.size())
1038 return nullptr;
1039
1040 std::vector<PartialMulCandidate> Info;
1041 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1042 return nullptr;
1043
1044 // Map to store common instruction to node pointers
1045 std::map<Instruction *, NodePtr> CommonToNode;
1046 std::vector<bool> Processed(Info.size(), false);
1047 for (unsigned I = 0; I < Info.size(); ++I) {
1048 if (Processed[I])
1049 continue;
1050
1051 PartialMulCandidate &InfoA = Info[I];
1052 for (unsigned J = I + 1; J < Info.size(); ++J) {
1053 if (Processed[J])
1054 continue;
1055
1056 PartialMulCandidate &InfoB = Info[J];
1057 auto *InfoReal = &InfoA;
1058 auto *InfoImag = &InfoB;
1059
1060 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1061 if (!NodeFromCommon) {
1062 std::swap(InfoReal, InfoImag);
1063 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1064 }
1065 if (!NodeFromCommon)
1066 continue;
1067
1068 CommonToNode[InfoReal->Common] = NodeFromCommon;
1069 CommonToNode[InfoImag->Common] = NodeFromCommon;
1070 Processed[I] = true;
1071 Processed[J] = true;
1072 }
1073 }
1074
1075 std::vector<bool> ProcessedReal(RealMuls.size(), false);
1076 std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1077 NodePtr Result = Accumulator;
1078 for (auto &PMI : Info) {
1079 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1080 continue;
1081
1082 auto It = CommonToNode.find(PMI.Common);
1083 // TODO: Process independent complex multiplications. Cases like this:
1084 // A.real() * B where both A and B are complex numbers.
1085 if (It == CommonToNode.end()) {
1086 LLVM_DEBUG({
1087 dbgs() << "Unprocessed independent partial multiplication:\n";
1088 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1089 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1090 << " multiplied by " << *Mul->Multiplicand << "\n";
1091 });
1092 return nullptr;
1093 }
1094
1095 auto &RealMul = RealMuls[PMI.RealIdx];
1096 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1097
1098 auto NodeA = It->second;
1099 auto NodeB = PMI.Node;
1100 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1101 // The following table illustrates the relationship between multiplications
1102 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1103 // can see:
1104 //
1105 // Rotation | Real | Imag |
1106 // ---------+--------+--------+
1107 // 0 | x * u | x * v |
1108 // 90 | -y * v | y * u |
1109 // 180 | -x * u | -x * v |
1110 // 270 | y * v | -y * u |
1111 //
1112 // Check if the candidate can indeed be represented by partial
1113 // multiplication
1114 // TODO: Add support for multiplication by complex one
1115 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1116 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1117 continue;
1118
1119 // Determine the rotation based on the multiplications
1121 if (IsMultiplicandReal) {
1122 // Detect 0 and 180 degrees rotation
1123 if (RealMul.IsPositive && ImagMul.IsPositive)
1125 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1127 else
1128 continue;
1129
1130 } else {
1131 // Detect 90 and 270 degrees rotation
1132 if (!RealMul.IsPositive && ImagMul.IsPositive)
1134 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1136 else
1137 continue;
1138 }
1139
1140 LLVM_DEBUG({
1141 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1142 dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1143 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1144 dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1145 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1146 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1147 });
1148
1149 NodePtr NodeMul = prepareCompositeNode(
1150 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1151 NodeMul->Rotation = Rotation;
1152 NodeMul->addOperand(NodeA);
1153 NodeMul->addOperand(NodeB);
1154 if (Result)
1155 NodeMul->addOperand(Result);
1156 submitCompositeNode(NodeMul);
1157 Result = NodeMul;
1158 ProcessedReal[PMI.RealIdx] = true;
1159 ProcessedImag[PMI.ImagIdx] = true;
1160 }
1161
1162 // Ensure all products have been processed, if not return nullptr.
1163 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1164 !all_of(ProcessedImag, [](bool V) { return V; })) {
1165
1166 // Dump debug information about which partial multiplications are not
1167 // processed.
1168 LLVM_DEBUG({
1169 dbgs() << "Unprocessed products (Real):\n";
1170 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1171 if (!ProcessedReal[i])
1172 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1173 << *RealMuls[i].Multiplier << " multiplied by "
1174 << *RealMuls[i].Multiplicand << "\n";
1175 }
1176 dbgs() << "Unprocessed products (Imag):\n";
1177 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1178 if (!ProcessedImag[i])
1179 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1180 << *ImagMuls[i].Multiplier << " multiplied by "
1181 << *ImagMuls[i].Multiplicand << "\n";
1182 }
1183 });
1184 return nullptr;
1185 }
1186
1187 return Result;
1188}
1189
1190ComplexDeinterleavingGraph::NodePtr
1191ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
1192 std::list<Addend> &ImagAddends,
1193 FastMathFlags Flags,
1194 NodePtr Accumulator = nullptr) {
1195 if (RealAddends.size() != ImagAddends.size())
1196 return nullptr;
1197
1198 NodePtr Result;
1199 // If we have accumulator use it as first addend
1200 if (Accumulator)
1202 // Otherwise find an element with both positive real and imaginary parts.
1203 else
1204 Result = extractPositiveAddend(RealAddends, ImagAddends);
1205
1206 if (!Result)
1207 return nullptr;
1208
1209 while (!RealAddends.empty()) {
1210 auto ItR = RealAddends.begin();
1211 auto [R, IsPositiveR] = *ItR;
1212
1213 bool FoundImag = false;
1214 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1215 auto [I, IsPositiveI] = *ItI;
1217 if (IsPositiveR && IsPositiveI)
1218 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1219 else if (!IsPositiveR && IsPositiveI)
1220 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1221 else if (!IsPositiveR && !IsPositiveI)
1222 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1223 else
1224 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1225
1226 NodePtr AddNode;
1227 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1228 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1229 AddNode = identifyNode(R, I);
1230 } else {
1231 AddNode = identifyNode(I, R);
1232 }
1233 if (AddNode) {
1234 LLVM_DEBUG({
1235 dbgs() << "Identified addition:\n";
1236 dbgs().indent(4) << "X: " << *R << "\n";
1237 dbgs().indent(4) << "Y: " << *I << "\n";
1238 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1239 });
1240
1241 NodePtr TmpNode;
1243 TmpNode = prepareCompositeNode(
1244 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1245 TmpNode->Opcode = Instruction::FAdd;
1246 TmpNode->Flags = Flags;
1247 } else if (Rotation ==
1249 TmpNode = prepareCompositeNode(
1250 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1251 TmpNode->Opcode = Instruction::FSub;
1252 TmpNode->Flags = Flags;
1253 } else {
1254 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1255 nullptr, nullptr);
1256 TmpNode->Rotation = Rotation;
1257 }
1258
1259 TmpNode->addOperand(Result);
1260 TmpNode->addOperand(AddNode);
1261 submitCompositeNode(TmpNode);
1262 Result = TmpNode;
1263 RealAddends.erase(ItR);
1264 ImagAddends.erase(ItI);
1265 FoundImag = true;
1266 break;
1267 }
1268 }
1269 if (!FoundImag)
1270 return nullptr;
1271 }
1272 return Result;
1273}
1274
1275ComplexDeinterleavingGraph::NodePtr
1276ComplexDeinterleavingGraph::extractPositiveAddend(
1277 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1278 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1279 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1280 auto [R, IsPositiveR] = *ItR;
1281 auto [I, IsPositiveI] = *ItI;
1282 if (IsPositiveR && IsPositiveI) {
1283 auto Result = identifyNode(R, I);
1284 if (Result) {
1285 RealAddends.erase(ItR);
1286 ImagAddends.erase(ItI);
1287 return Result;
1288 }
1289 }
1290 }
1291 }
1292 return nullptr;
1293}
1294
1295bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1296 auto RootNode = identifyRoot(RootI);
1297 if (!RootNode)
1298 return false;
1299
1300 LLVM_DEBUG({
1301 Function *F = RootI->getFunction();
1302 BasicBlock *B = RootI->getParent();
1303 dbgs() << "Complex deinterleaving graph for " << F->getName()
1304 << "::" << B->getName() << ".\n";
1305 dump(dbgs());
1306 dbgs() << "\n";
1307 });
1308 RootToNode[RootI] = RootNode;
1309 OrderedRoots.push_back(RootI);
1310 return true;
1311}
1312
1313bool ComplexDeinterleavingGraph::checkNodes() {
1314 // Collect all instructions from roots to leaves
1315 SmallPtrSet<Instruction *, 16> AllInstructions;
1317 for (auto *I : OrderedRoots)
1318 Worklist.push_back(I);
1319
1320 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1321 // chains
1322 while (!Worklist.empty()) {
1323 auto *I = Worklist.back();
1324 Worklist.pop_back();
1325
1326 if (!AllInstructions.insert(I).second)
1327 continue;
1328
1329 for (Value *Op : I->operands()) {
1330 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1331 if (!FinalInstructions.count(I))
1332 Worklist.emplace_back(OpI);
1333 }
1334 }
1335 }
1336
1337 // Find instructions that have users outside of chain
1338 SmallVector<Instruction *, 2> OuterInstructions;
1339 for (auto *I : AllInstructions) {
1340 // Skip root nodes
1341 if (RootToNode.count(I))
1342 continue;
1343
1344 for (User *U : I->users()) {
1345 if (AllInstructions.count(cast<Instruction>(U)))
1346 continue;
1347
1348 // Found an instruction that is not used by XCMLA/XCADD chain
1349 Worklist.emplace_back(I);
1350 break;
1351 }
1352 }
1353
1354 // If any instructions are found to be used outside, find and remove roots
1355 // that somehow connect to those instructions.
1357 while (!Worklist.empty()) {
1358 auto *I = Worklist.back();
1359 Worklist.pop_back();
1360 if (!Visited.insert(I).second)
1361 continue;
1362
1363 // Found an impacted root node. Removing it from the nodes to be
1364 // deinterleaved
1365 if (RootToNode.count(I)) {
1366 LLVM_DEBUG(dbgs() << "Instruction " << *I
1367 << " could be deinterleaved but its chain of complex "
1368 "operations have an outside user\n");
1369 RootToNode.erase(I);
1370 }
1371
1372 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1373 continue;
1374
1375 for (User *U : I->users())
1376 Worklist.emplace_back(cast<Instruction>(U));
1377
1378 for (Value *Op : I->operands()) {
1379 if (auto *OpI = dyn_cast<Instruction>(Op))
1380 Worklist.emplace_back(OpI);
1381 }
1382 }
1383 return !RootToNode.empty();
1384}
1385
1386ComplexDeinterleavingGraph::NodePtr
1387ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1388 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1389 if (Intrinsic->getIntrinsicID() !=
1390 Intrinsic::experimental_vector_interleave2)
1391 return nullptr;
1392
1393 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1394 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1395 if (!Real || !Imag)
1396 return nullptr;
1397
1398 return identifyNode(Real, Imag);
1399 }
1400
1401 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1402 if (!SVI)
1403 return nullptr;
1404
1405 // Look for a shufflevector that takes separate vectors of the real and
1406 // imaginary components and recombines them into a single vector.
1407 if (!isInterleavingMask(SVI->getShuffleMask()))
1408 return nullptr;
1409
1410 Instruction *Real;
1411 Instruction *Imag;
1412 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1413 return nullptr;
1414
1415 return identifyNode(Real, Imag);
1416}
1417
1418ComplexDeinterleavingGraph::NodePtr
1419ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1420 Instruction *Imag) {
1421 Instruction *I = nullptr;
1422 Value *FinalValue = nullptr;
1423 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1424 match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1425 match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1426 m_Value(FinalValue)))) {
1427 NodePtr PlaceholderNode = prepareCompositeNode(
1429 PlaceholderNode->ReplacementNode = FinalValue;
1430 FinalInstructions.insert(Real);
1431 FinalInstructions.insert(Imag);
1432 return submitCompositeNode(PlaceholderNode);
1433 }
1434
1435 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1436 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1437 if (!RealShuffle || !ImagShuffle) {
1438 if (RealShuffle || ImagShuffle)
1439 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1440 return nullptr;
1441 }
1442
1443 Value *RealOp1 = RealShuffle->getOperand(1);
1444 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1445 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1446 return nullptr;
1447 }
1448 Value *ImagOp1 = ImagShuffle->getOperand(1);
1449 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1450 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1451 return nullptr;
1452 }
1453
1454 Value *RealOp0 = RealShuffle->getOperand(0);
1455 Value *ImagOp0 = ImagShuffle->getOperand(0);
1456
1457 if (RealOp0 != ImagOp0) {
1458 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1459 return nullptr;
1460 }
1461
1462 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1463 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1464 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1465 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1466 return nullptr;
1467 }
1468
1469 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1470 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1471 return nullptr;
1472 }
1473
1474 // Type checking, the shuffle type should be a vector type of the same
1475 // scalar type, but half the size
1476 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1477 Value *Op = Shuffle->getOperand(0);
1478 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1479 auto *OpTy = cast<FixedVectorType>(Op->getType());
1480
1481 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1482 return false;
1483 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1484 return false;
1485
1486 return true;
1487 };
1488
1489 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1490 if (!CheckType(Shuffle))
1491 return false;
1492
1493 ArrayRef<int> Mask = Shuffle->getShuffleMask();
1494 int Last = *Mask.rbegin();
1495
1496 Value *Op = Shuffle->getOperand(0);
1497 auto *OpTy = cast<FixedVectorType>(Op->getType());
1498 int NumElements = OpTy->getNumElements();
1499
1500 // Ensure that the deinterleaving shuffle only pulls from the first
1501 // shuffle operand.
1502 return Last < NumElements;
1503 };
1504
1505 if (RealShuffle->getType() != ImagShuffle->getType()) {
1506 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1507 return nullptr;
1508 }
1509 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1510 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1511 return nullptr;
1512 }
1513 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1514 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1515 return nullptr;
1516 }
1517
1518 NodePtr PlaceholderNode =
1520 RealShuffle, ImagShuffle);
1521 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1522 FinalInstructions.insert(RealShuffle);
1523 FinalInstructions.insert(ImagShuffle);
1524 return submitCompositeNode(PlaceholderNode);
1525}
1526
1527static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1528 FastMathFlags Flags, Value *InputA,
1529 Value *InputB) {
1530 Value *I;
1531 switch (Opcode) {
1532 case Instruction::FNeg:
1533 I = B.CreateFNeg(InputA);
1534 break;
1535 case Instruction::FAdd:
1536 I = B.CreateFAdd(InputA, InputB);
1537 break;
1538 case Instruction::FSub:
1539 I = B.CreateFSub(InputA, InputB);
1540 break;
1541 case Instruction::FMul:
1542 I = B.CreateFMul(InputA, InputB);
1543 break;
1544 default:
1545 llvm_unreachable("Incorrect symmetric opcode");
1546 }
1547 cast<Instruction>(I)->setFastMathFlags(Flags);
1548 return I;
1549}
1550
1551Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1552 RawNodePtr Node) {
1553 if (Node->ReplacementNode)
1554 return Node->ReplacementNode;
1555
1556 Value *Input0 = replaceNode(Builder, Node->Operands[0]);
1557 Value *Input1 = Node->Operands.size() > 1
1558 ? replaceNode(Builder, Node->Operands[1])
1559 : nullptr;
1560 Value *Accumulator = Node->Operands.size() > 2
1561 ? replaceNode(Builder, Node->Operands[2])
1562 : nullptr;
1563 if (Input1)
1564 assert(Input0->getType() == Input1->getType() &&
1565 "Node inputs need to be of the same type");
1566
1567 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1568 Node->ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode,
1569 Node->Flags, Input0, Input1);
1570 else
1571 Node->ReplacementNode = TL->createComplexDeinterleavingIR(
1572 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
1573
1574 assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
1575 NumComplexTransformations += 1;
1576 return Node->ReplacementNode;
1577}
1578
1579void ComplexDeinterleavingGraph::replaceNodes() {
1580 SmallVector<Instruction *, 16> DeadInstrRoots;
1581 for (auto *RootInstruction : OrderedRoots) {
1582 // Check if this potential root went through check process and we can
1583 // deinterleave it
1584 if (!RootToNode.count(RootInstruction))
1585 continue;
1586
1587 IRBuilder<> Builder(RootInstruction);
1588 auto RootNode = RootToNode[RootInstruction];
1589 Value *R = replaceNode(Builder, RootNode.get());
1590 assert(R && "Unable to find replacement for RootInstruction");
1591 DeadInstrRoots.push_back(RootInstruction);
1592 RootInstruction->replaceAllUsesWith(R);
1593 }
1594
1595 for (auto *I : DeadInstrRoots)
1597}
SmallVector< AArch64_IMM::ImmInsnModel, 4 > Insn
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
assume Assume Builder
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, FastMathFlags Flags, Value *InputA, Value *InputB)
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
Complex Deinterleaving
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
static bool isInstructionPairMul(Instruction *A, Instruction *B)
#define DEBUG_TYPE
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
static bool runOnFunction(Function &F, bool PostInlining)
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
PowerPC Reduce CR logical Operation
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
raw_pwrite_stream & OS
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(const unsigned char *MatcherTable, unsigned &MatcherIndex, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file describes how to lower LLVM code to machine code.
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
@ Flags
Definition: TextStubV5.cpp:93
BinaryOperator * Mul
DEMANGLE_DUMP_METHOD void dump() const
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:620
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:265
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:56
Convenience struct for specifying and reasoning about fast-math flags.
Definition: FMF.h:20
bool allowContract() const
Definition: FMF.h:70
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
Definition: IRBuilder.h:94
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2570
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:933
bool isBinaryOp() const
Definition: Instruction.h:173
const BasicBlock * getParent() const
Definition: Instruction.h:90
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:74
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:168
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
void preserve()
Mark an analysis as preserved.
Definition: PassManager.h:173
This instruction constructs a fixed permutation of two input vectors.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:383
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:450
bool empty() const
Definition: SmallVector.h:94
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:941
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:78
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:119
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ BR
Control flow instructions. These all have token chains.
Definition: ISDOpcodes.h:1021
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:84
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
Definition: PatternMatch.h:716
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:772
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:445
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition: DWP.cpp:440
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1819
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:529
void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860