LLVM 17.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21//
22// Replacement:
23// This step traverses the graph built up by identification, delegating to the
24// target to validate and generate the correct intrinsics, and plumbs them
25// together connecting each end of the new intrinsics graph to the existing
26// use-def chain. This step is assumed to finish successfully, as all
27// information is expected to be correct by this point.
28//
29//
30// Internal data structure:
31// ComplexDeinterleavingGraph:
32// Keeps references to all the valid CompositeNodes formed as part of the
33// transformation, and every Instruction contained within said nodes. It also
34// holds onto a reference to the root Instruction, and the root node that should
35// replace it.
36//
37// ComplexDeinterleavingCompositeNode:
38// A CompositeNode represents a single transformation point; each node should
39// transform into a single complex instruction (ignoring vector splitting, which
40// would generate more instructions per node). They are identified in a
41// depth-first manner, traversing and identifying the operands of each
42// instruction in the order they appear in the IR.
43// Each node maintains a reference to its Real and Imaginary instructions,
44// as well as any additional instructions that make up the identified operation
45// (Internal instructions should only have uses within their containing node).
46// A Node also contains the rotation and operation type that it represents.
47// Operands contains pointers to other CompositeNodes, acting as the edges in
48// the graph. ReplacementValue is the transformed Value* that has been emitted
49// to the IR.
50//
51// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
52// ReplacementValue fields of that Node are relevant, where the ReplacementValue
53// should be pre-populated.
54//
55//===----------------------------------------------------------------------===//
56
58#include "llvm/ADT/Statistic.h"
64#include "llvm/IR/IRBuilder.h"
68#include <algorithm>
69
70using namespace llvm;
71using namespace PatternMatch;
72
73#define DEBUG_TYPE "complex-deinterleaving"
74
75STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
76
78 "enable-complex-deinterleaving",
79 cl::desc("Enable generation of complex instructions"), cl::init(true),
81
82/// Checks the given mask, and determines whether said mask is interleaving.
83///
84/// To be interleaving, a mask must alternate between `i` and `i + (Length /
85/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
86/// 4x vector interleaving mask would be <0, 2, 1, 3>).
87static bool isInterleavingMask(ArrayRef<int> Mask);
88
89/// Checks the given mask, and determines whether said mask is deinterleaving.
90///
91/// To be deinterleaving, a mask must increment in steps of 2, and either start
92/// with 0 or 1.
93/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
94/// <1, 3, 5, 7>).
95static bool isDeinterleavingMask(ArrayRef<int> Mask);
96
97namespace {
98
99class ComplexDeinterleavingLegacyPass : public FunctionPass {
100public:
101 static char ID;
102
103 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
104 : FunctionPass(ID), TM(TM) {
107 }
108
109 StringRef getPassName() const override {
110 return "Complex Deinterleaving Pass";
111 }
112
113 bool runOnFunction(Function &F) override;
114 void getAnalysisUsage(AnalysisUsage &AU) const override {
116 AU.setPreservesCFG();
117 }
118
119private:
120 const TargetMachine *TM;
121};
122
123class ComplexDeinterleavingGraph;
124struct ComplexDeinterleavingCompositeNode {
125
126 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
128 : Operation(Op), Real(R), Imag(I) {}
129
130private:
131 friend class ComplexDeinterleavingGraph;
132 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
133 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
134
135public:
137 Instruction *Real;
138 Instruction *Imag;
139
140 // Instructions that should only exist within this node, there should be no
141 // users of these instructions outside the node. An example of these would be
142 // the multiply instructions of a partial multiply operation.
143 SmallVector<Instruction *> InternalInstructions;
146 Value *ReplacementNode = nullptr;
147
148 void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
149 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
150
151 bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
152
153 void dump() { dump(dbgs()); }
154 void dump(raw_ostream &OS) {
155 auto PrintValue = [&](Value *V) {
156 if (V) {
157 OS << "\"";
158 V->print(OS, true);
159 OS << "\"\n";
160 } else
161 OS << "nullptr\n";
162 };
163 auto PrintNodeRef = [&](RawNodePtr Ptr) {
164 if (Ptr)
165 OS << Ptr << "\n";
166 else
167 OS << "nullptr\n";
168 };
169
170 OS << "- CompositeNode: " << this << "\n";
171 OS << " Real: ";
172 PrintValue(Real);
173 OS << " Imag: ";
174 PrintValue(Imag);
175 OS << " ReplacementNode: ";
176 PrintValue(ReplacementNode);
177 OS << " Operation: " << (int)Operation << "\n";
178 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
179 OS << " Operands: \n";
180 for (const auto &Op : Operands) {
181 OS << " - ";
182 PrintNodeRef(Op);
183 }
184 OS << " InternalInstructions:\n";
185 for (const auto &I : InternalInstructions) {
186 OS << " - \"";
187 I->print(OS, true);
188 OS << "\"\n";
189 }
190 }
191};
192
193class ComplexDeinterleavingGraph {
194public:
195 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
196 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
197 explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
198
199private:
200 const TargetLowering *TL;
201 Instruction *RootValue;
202 NodePtr RootNode;
203 SmallVector<NodePtr> CompositeNodes;
204 SmallPtrSet<Instruction *, 16> AllInstructions;
205
206 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
208 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
209 I);
210 }
211
212 NodePtr submitCompositeNode(NodePtr Node) {
213 CompositeNodes.push_back(Node);
214 AllInstructions.insert(Node->Real);
215 AllInstructions.insert(Node->Imag);
216 for (auto *I : Node->InternalInstructions)
217 AllInstructions.insert(I);
218 return Node;
219 }
220
221 NodePtr getContainingComposite(Value *R, Value *I) {
222 for (const auto &CN : CompositeNodes) {
223 if (CN->Real == R && CN->Imag == I)
224 return CN;
225 }
226 return nullptr;
227 }
228
229 /// Identifies a complex partial multiply pattern and its rotation, based on
230 /// the following patterns
231 ///
232 /// 0: r: cr + ar * br
233 /// i: ci + ar * bi
234 /// 90: r: cr - ai * bi
235 /// i: ci + ai * br
236 /// 180: r: cr - ar * br
237 /// i: ci - ar * bi
238 /// 270: r: cr + ai * bi
239 /// i: ci - ai * br
240 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
241
242 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
243 /// is partially known from identifyPartialMul, filling in the other half of
244 /// the complex pair.
245 NodePtr identifyNodeWithImplicitAdd(
247 std::pair<Instruction *, Instruction *> &CommonOperandI);
248
249 /// Identifies a complex add pattern and its rotation, based on the following
250 /// patterns.
251 ///
252 /// 90: r: ar - bi
253 /// i: ai + br
254 /// 270: r: ar + bi
255 /// i: ai - br
256 NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
257
258 NodePtr identifyNode(Instruction *I, Instruction *J);
259
260 Value *replaceNode(RawNodePtr Node);
261
262public:
263 void dump() { dump(dbgs()); }
264 void dump(raw_ostream &OS) {
265 for (const auto &Node : CompositeNodes)
266 Node->dump(OS);
267 }
268
269 /// Returns false if the deinterleaving operation should be cancelled for the
270 /// current graph.
271 bool identifyNodes(Instruction *RootI);
272
273 /// Perform the actual replacement of the underlying instruction graph.
  /// Must only be called after identifyNodes() has returned true for this
  /// graph; the replacement itself is assumed to always succeed.
276 void replaceNodes();
277};
278
279class ComplexDeinterleaving {
280public:
281 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
282 : TL(tl), TLI(tli) {}
283 bool runOnFunction(Function &F);
284
285private:
286 bool evaluateBasicBlock(BasicBlock *B);
287
288 const TargetLowering *TL = nullptr;
289 const TargetLibraryInfo *TLI = nullptr;
290};
291
292} // namespace
293
294char ComplexDeinterleavingLegacyPass::ID = 0;
295
296INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
297 "Complex Deinterleaving", false, false)
298INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
300
303 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
304 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
305 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
306 return PreservedAnalyses::all();
307
310 return PA;
311}
312
314 return new ComplexDeinterleavingLegacyPass(TM);
315}
316
317bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
318 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
319 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
320 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
321}
322
323bool ComplexDeinterleaving::runOnFunction(Function &F) {
326 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
327 return false;
328 }
329
332 dbgs() << "Complex deinterleaving has been disabled, target does "
333 "not support lowering of complex number operations.\n");
334 return false;
335 }
336
337 bool Changed = false;
338 for (auto &B : F)
339 Changed |= evaluateBasicBlock(&B);
340
341 return Changed;
342}
343
345 // If the size is not even, it's not an interleaving mask
346 if ((Mask.size() & 1))
347 return false;
348
349 int HalfNumElements = Mask.size() / 2;
350 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
351 int MaskIdx = Idx * 2;
352 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
353 return false;
354 }
355
356 return true;
357}
358
360 int Offset = Mask[0];
361 int HalfNumElements = Mask.size() / 2;
362
363 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
364 if (Mask[Idx] != (Idx * 2) + Offset)
365 return false;
366 }
367
368 return true;
369}
370
371bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
372 bool Changed = false;
373
374 SmallVector<Instruction *> DeadInstrRoots;
375
376 for (auto &I : *B) {
377 auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
378 if (!SVI)
379 continue;
380
381 // Look for a shufflevector that takes separate vectors of the real and
382 // imaginary components and recombines them into a single vector.
383 if (!isInterleavingMask(SVI->getShuffleMask()))
384 continue;
385
386 ComplexDeinterleavingGraph Graph(TL);
387 if (!Graph.identifyNodes(SVI))
388 continue;
389
390 Graph.replaceNodes();
391 DeadInstrRoots.push_back(SVI);
392 Changed = true;
393 }
394
395 for (const auto &I : DeadInstrRoots) {
396 if (!I || I->getParent() == nullptr)
397 continue;
399 }
400
401 return Changed;
402}
403
404ComplexDeinterleavingGraph::NodePtr
405ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
406 Instruction *Real, Instruction *Imag,
407 std::pair<Instruction *, Instruction *> &PartialMatch) {
408 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
409 << "\n");
410
411 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
412 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
413 return nullptr;
414 }
415
416 if (Real->getOpcode() != Instruction::FMul ||
417 Imag->getOpcode() != Instruction::FMul) {
418 LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n");
419 return nullptr;
420 }
421
422 Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
423 Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
424 Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
425 Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
426 if (!R0 || !R1 || !I0 || !I1) {
427 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
428 return nullptr;
429 }
430
431 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
432 // rotations and use the operand.
433 unsigned Negs = 0;
435 if (R0->getOpcode() == Instruction::FNeg ||
436 R1->getOpcode() == Instruction::FNeg) {
437 Negs |= 1;
438 if (R0->getOpcode() == Instruction::FNeg) {
439 FNegs.push_back(R0);
440 R0 = dyn_cast<Instruction>(R0->getOperand(0));
441 } else {
442 FNegs.push_back(R1);
443 R1 = dyn_cast<Instruction>(R1->getOperand(0));
444 }
445 if (!R0 || !R1)
446 return nullptr;
447 }
448 if (I0->getOpcode() == Instruction::FNeg ||
449 I1->getOpcode() == Instruction::FNeg) {
450 Negs |= 2;
451 Negs ^= 1;
452 if (I0->getOpcode() == Instruction::FNeg) {
453 FNegs.push_back(I0);
454 I0 = dyn_cast<Instruction>(I0->getOperand(0));
455 } else {
456 FNegs.push_back(I1);
457 I1 = dyn_cast<Instruction>(I1->getOperand(0));
458 }
459 if (!I0 || !I1)
460 return nullptr;
461 }
462
464
465 Instruction *CommonOperand;
466 Instruction *UncommonRealOp;
467 Instruction *UncommonImagOp;
468
469 if (R0 == I0 || R0 == I1) {
470 CommonOperand = R0;
471 UncommonRealOp = R1;
472 } else if (R1 == I0 || R1 == I1) {
473 CommonOperand = R1;
474 UncommonRealOp = R0;
475 } else {
476 LLVM_DEBUG(dbgs() << " - No equal operand\n");
477 return nullptr;
478 }
479
480 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
481 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
482 Rotation == ComplexDeinterleavingRotation::Rotation_270)
483 std::swap(UncommonRealOp, UncommonImagOp);
484
485 // Between identifyPartialMul and here we need to have found a complete valid
486 // pair from the CommonOperand of each part.
487 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
488 Rotation == ComplexDeinterleavingRotation::Rotation_180)
489 PartialMatch.first = CommonOperand;
490 else
491 PartialMatch.second = CommonOperand;
492
493 if (!PartialMatch.first || !PartialMatch.second) {
494 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
495 return nullptr;
496 }
497
498 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
499 if (!CommonNode) {
500 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
501 return nullptr;
502 }
503
504 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
505 if (!UncommonNode) {
506 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
507 return nullptr;
508 }
509
510 NodePtr Node = prepareCompositeNode(
511 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
512 Node->Rotation = Rotation;
513 Node->addOperand(CommonNode);
514 Node->addOperand(UncommonNode);
515 Node->InternalInstructions.append(FNegs);
516 return submitCompositeNode(Node);
517}
518
519ComplexDeinterleavingGraph::NodePtr
520ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
521 Instruction *Imag) {
522 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
523 << "\n");
524 // Determine rotation
526 if (Real->getOpcode() == Instruction::FAdd &&
527 Imag->getOpcode() == Instruction::FAdd)
528 Rotation = ComplexDeinterleavingRotation::Rotation_0;
529 else if (Real->getOpcode() == Instruction::FSub &&
530 Imag->getOpcode() == Instruction::FAdd)
531 Rotation = ComplexDeinterleavingRotation::Rotation_90;
532 else if (Real->getOpcode() == Instruction::FSub &&
533 Imag->getOpcode() == Instruction::FSub)
534 Rotation = ComplexDeinterleavingRotation::Rotation_180;
535 else if (Real->getOpcode() == Instruction::FAdd &&
536 Imag->getOpcode() == Instruction::FSub)
537 Rotation = ComplexDeinterleavingRotation::Rotation_270;
538 else {
539 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
540 return nullptr;
541 }
542
543 if (!Real->getFastMathFlags().allowContract() ||
544 !Imag->getFastMathFlags().allowContract()) {
545 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
546 return nullptr;
547 }
548
549 Value *CR = Real->getOperand(0);
550 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
551 if (!RealMulI)
552 return nullptr;
553 Value *CI = Imag->getOperand(0);
554 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
555 if (!ImagMulI)
556 return nullptr;
557
558 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
559 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
560 return nullptr;
561 }
562
563 Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
564 Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
565 Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
566 Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
567 if (!R0 || !R1 || !I0 || !I1) {
568 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
569 return nullptr;
570 }
571
572 Instruction *CommonOperand;
573 Instruction *UncommonRealOp;
574 Instruction *UncommonImagOp;
575
576 if (R0 == I0 || R0 == I1) {
577 CommonOperand = R0;
578 UncommonRealOp = R1;
579 } else if (R1 == I0 || R1 == I1) {
580 CommonOperand = R1;
581 UncommonRealOp = R0;
582 } else {
583 LLVM_DEBUG(dbgs() << " - No equal operand\n");
584 return nullptr;
585 }
586
587 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
588 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
589 Rotation == ComplexDeinterleavingRotation::Rotation_270)
590 std::swap(UncommonRealOp, UncommonImagOp);
591
592 std::pair<Instruction *, Instruction *> PartialMatch(
593 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
594 Rotation == ComplexDeinterleavingRotation::Rotation_180)
595 ? CommonOperand
596 : nullptr,
597 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
598 Rotation == ComplexDeinterleavingRotation::Rotation_270)
599 ? CommonOperand
600 : nullptr);
601 NodePtr CNode = identifyNodeWithImplicitAdd(
602 cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch);
603 if (!CNode) {
604 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
605 return nullptr;
606 }
607
608 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
609 if (!UncommonRes) {
610 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
611 return nullptr;
612 }
613
614 assert(PartialMatch.first && PartialMatch.second);
615 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
616 if (!CommonRes) {
617 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
618 return nullptr;
619 }
620
621 NodePtr Node = prepareCompositeNode(
622 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
623 Node->addInstruction(RealMulI);
624 Node->addInstruction(ImagMulI);
625 Node->Rotation = Rotation;
626 Node->addOperand(CommonRes);
627 Node->addOperand(UncommonRes);
628 Node->addOperand(CNode);
629 return submitCompositeNode(Node);
630}
631
632ComplexDeinterleavingGraph::NodePtr
633ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
634 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
635
636 // Determine rotation
638 if ((Real->getOpcode() == Instruction::FSub &&
639 Imag->getOpcode() == Instruction::FAdd) ||
640 (Real->getOpcode() == Instruction::Sub &&
641 Imag->getOpcode() == Instruction::Add))
642 Rotation = ComplexDeinterleavingRotation::Rotation_90;
643 else if ((Real->getOpcode() == Instruction::FAdd &&
644 Imag->getOpcode() == Instruction::FSub) ||
645 (Real->getOpcode() == Instruction::Add &&
646 Imag->getOpcode() == Instruction::Sub))
647 Rotation = ComplexDeinterleavingRotation::Rotation_270;
648 else {
649 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
650 return nullptr;
651 }
652
653 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
654 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
655 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
656 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
657
658 if (!AR || !AI || !BR || !BI) {
659 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
660 return nullptr;
661 }
662
663 NodePtr ResA = identifyNode(AR, AI);
664 if (!ResA) {
665 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
666 return nullptr;
667 }
668 NodePtr ResB = identifyNode(BR, BI);
669 if (!ResB) {
670 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
671 return nullptr;
672 }
673
674 NodePtr Node =
675 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
676 Node->Rotation = Rotation;
677 Node->addOperand(ResA);
678 Node->addOperand(ResB);
679 return submitCompositeNode(Node);
680}
681
683 unsigned OpcA = A->getOpcode();
684 unsigned OpcB = B->getOpcode();
685
686 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
687 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
688 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
689 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
690}
691
693 auto Pattern =
695
696 return match(A, Pattern) && match(B, Pattern);
697}
698
699ComplexDeinterleavingGraph::NodePtr
700ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
701 LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
702 if (NodePtr CN = getContainingComposite(Real, Imag)) {
703 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
704 return CN;
705 }
706
707 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
708 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
709 if (RealShuffle && ImagShuffle) {
710 Value *RealOp1 = RealShuffle->getOperand(1);
711 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
712 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
713 return nullptr;
714 }
715 Value *ImagOp1 = ImagShuffle->getOperand(1);
716 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
717 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
718 return nullptr;
719 }
720
721 Value *RealOp0 = RealShuffle->getOperand(0);
722 Value *ImagOp0 = ImagShuffle->getOperand(0);
723
724 if (RealOp0 != ImagOp0) {
725 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
726 return nullptr;
727 }
728
729 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
730 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
731 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
732 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
733 return nullptr;
734 }
735
736 if (RealMask[0] != 0 || ImagMask[0] != 1) {
737 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
738 return nullptr;
739 }
740
741 // Type checking, the shuffle type should be a vector type of the same
742 // scalar type, but half the size
743 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
744 Value *Op = Shuffle->getOperand(0);
745 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
746 auto *OpTy = cast<FixedVectorType>(Op->getType());
747
748 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
749 return false;
750 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
751 return false;
752
753 return true;
754 };
755
756 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
757 if (!CheckType(Shuffle))
758 return false;
759
760 ArrayRef<int> Mask = Shuffle->getShuffleMask();
761 int Last = *Mask.rbegin();
762
763 Value *Op = Shuffle->getOperand(0);
764 auto *OpTy = cast<FixedVectorType>(Op->getType());
765 int NumElements = OpTy->getNumElements();
766
767 // Ensure that the deinterleaving shuffle only pulls from the first
768 // shuffle operand.
769 return Last < NumElements;
770 };
771
772 if (RealShuffle->getType() != ImagShuffle->getType()) {
773 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
774 return nullptr;
775 }
776 if (!CheckDeinterleavingShuffle(RealShuffle)) {
777 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
778 return nullptr;
779 }
780 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
781 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
782 return nullptr;
783 }
784
785 NodePtr PlaceholderNode =
787 RealShuffle, ImagShuffle);
788 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
789 return submitCompositeNode(PlaceholderNode);
790 }
791 if (RealShuffle || ImagShuffle)
792 return nullptr;
793
794 auto *VTy = cast<FixedVectorType>(Real->getType());
795 auto *NewVTy =
796 FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2);
797
799 ComplexDeinterleavingOperation::CMulPartial, NewVTy) &&
800 isInstructionPairMul(Real, Imag)) {
801 return identifyPartialMul(Real, Imag);
802 }
803
805 ComplexDeinterleavingOperation::CAdd, NewVTy) &&
806 isInstructionPairAdd(Real, Imag)) {
807 return identifyAdd(Real, Imag);
808 }
809
810 return nullptr;
811}
812
813bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
814 Instruction *Real;
815 Instruction *Imag;
816 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
817 return false;
818
819 RootValue = RootI;
820 AllInstructions.insert(RootI);
821 RootNode = identifyNode(Real, Imag);
822
823 LLVM_DEBUG({
824 Function *F = RootI->getFunction();
825 BasicBlock *B = RootI->getParent();
826 dbgs() << "Complex deinterleaving graph for " << F->getName()
827 << "::" << B->getName() << ".\n";
828 dump(dbgs());
829 dbgs() << "\n";
830 });
831
832 // Check all instructions have internal uses
833 for (const auto &Node : CompositeNodes) {
834 if (!Node->hasAllInternalUses(AllInstructions)) {
835 LLVM_DEBUG(dbgs() << " - Invalid internal uses\n");
836 return false;
837 }
838 }
839 return RootNode != nullptr;
840}
841
842Value *ComplexDeinterleavingGraph::replaceNode(
843 ComplexDeinterleavingGraph::RawNodePtr Node) {
844 if (Node->ReplacementNode)
845 return Node->ReplacementNode;
846
847 Value *Input0 = replaceNode(Node->Operands[0]);
848 Value *Input1 = replaceNode(Node->Operands[1]);
850 Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr;
851
852 assert(Input0->getType() == Input1->getType() &&
853 "Node inputs need to be of the same type");
854
855 Node->ReplacementNode = TL->createComplexDeinterleavingIR(
856 Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
857
858 assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
859 NumComplexTransformations += 1;
860 return Node->ReplacementNode;
861}
862
863void ComplexDeinterleavingGraph::replaceNodes() {
864 Value *R = replaceNode(RootNode.get());
865 assert(R && "Unable to find replacement for RootValue");
866 RootValue->replaceAllUsesWith(R);
867}
868
869bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
870 SmallPtrSet<Instruction *, 16> &AllInstructions) {
871 if (Operation == ComplexDeinterleavingOperation::Shuffle)
872 return true;
873
874 for (auto *User : Real->users()) {
875 if (!AllInstructions.contains(cast<Instruction>(User)))
876 return false;
877 }
878 for (auto *User : Imag->users()) {
879 if (!AllInstructions.contains(cast<Instruction>(User)))
880 return false;
881 }
882 for (auto *I : InternalInstructions) {
883 for (auto *User : I->users()) {
884 if (!AllInstructions.contains(cast<Instruction>(User)))
885 return false;
886 }
887 }
888 return true;
889}
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
Complex Deinterleaving
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
static bool isInstructionPairMul(Instruction *A, Instruction *B)
#define DEBUG_TYPE
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
static bool runOnFunction(Function &F, bool PostInlining)
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
PowerPC Reduce CR logical Operation
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(const unsigned char *MatcherTable, unsigned &MatcherIndex, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file describes how to lower LLVM code to machine code.
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
DEMANGLE_DUMP_METHOD void dump() const
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:620
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:265
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:56
bool allowContract() const
Definition: FMF.h:71
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:698
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:308
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:933
const BasicBlock * getParent() const
Definition: Instruction.h:90
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:74
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:168
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
void preserve()
Mark an analysis as preserved.
Definition: PassManager.h:173
This instruction constructs a fixed permutation of two input vectors.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:389
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:450
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(Instruction *I, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:78
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:532
iterator_range< user_iterator > users()
Definition: Value.h:421
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:80
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
@ BR
Control flow instructions. These all have token chains.
Definition: ISDOpcodes.h:981
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:84
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
Definition: PatternMatch.h:716
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:445
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition: DWP.cpp:406
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:537
void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:853