77using namespace PatternMatch;
79#define DEBUG_TYPE "complex-deinterleaving"
81STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
84 "enable-complex-deinterleaving",
105class ComplexDeinterleavingLegacyPass :
public FunctionPass {
109 ComplexDeinterleavingLegacyPass(
const TargetMachine *TM =
nullptr)
116 return "Complex Deinterleaving Pass";
129class ComplexDeinterleavingGraph;
130struct ComplexDeinterleavingCompositeNode {
137 friend class ComplexDeinterleavingGraph;
138 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
139 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
152 ComplexDeinterleavingRotation::Rotation_0;
154 Value *ReplacementNode =
nullptr;
160 auto PrintValue = [&](
Value *
V) {
168 auto PrintNodeRef = [&](RawNodePtr
Ptr) {
175 OS <<
"- CompositeNode: " <<
this <<
"\n";
180 OS <<
" ReplacementNode: ";
181 PrintValue(ReplacementNode);
183 OS <<
" Rotation: " << ((int)Rotation * 90) <<
"\n";
184 OS <<
" Operands: \n";
192class ComplexDeinterleavingGraph {
200 using Addend = std::pair<Instruction *, bool>;
201 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
202 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
206 struct PartialMulCandidate {
216 : TL(TL), TLI(TLI) {}
226 std::map<Instruction *, NodePtr> RootToNode;
233 return std::make_shared<ComplexDeinterleavingCompositeNode>(
Operation, R,
237 NodePtr submitCompositeNode(NodePtr
Node) {
242 NodePtr getContainingComposite(
Value *R,
Value *
I) {
243 for (
const auto &CN : CompositeNodes) {
244 if (CN->Real == R && CN->Imag ==
I)
266 NodePtr identifyNodeWithImplicitAdd(
268 std::pair<Instruction *, Instruction *> &CommonOperandI);
286 NodePtr identifyAdditions(std::list<Addend> &RealAddends,
291 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
292 std::list<Addend> &ImagAddends);
297 NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
298 std::vector<Product> &ImagMuls,
304 bool collectPartialMuls(
const std::vector<Product> &RealMuls,
305 const std::vector<Product> &ImagMuls,
306 std::vector<PartialMulCandidate> &Candidates);
332 for (
const auto &
Node : CompositeNodes)
348class ComplexDeinterleaving {
351 : TL(tl), TLI(tli) {}
363char ComplexDeinterleavingLegacyPass::ID = 0;
366 "Complex Deinterleaving",
false,
false)
374 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(
F))
383 return new ComplexDeinterleavingLegacyPass(
TM);
386bool ComplexDeinterleavingLegacyPass::runOnFunction(
Function &
F) {
387 const auto *TL =
TM->getSubtargetImpl(
F)->getTargetLowering();
388 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
389 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
392bool ComplexDeinterleaving::runOnFunction(
Function &
F) {
395 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
401 dbgs() <<
"Complex deinterleaving has been disabled, target does "
402 "not support lowering of complex number operations.\n");
406 bool Changed =
false;
408 Changed |= evaluateBasicBlock(&
B);
415 if ((Mask.size() & 1))
418 int HalfNumElements = Mask.size() / 2;
419 for (
int Idx = 0;
Idx < HalfNumElements; ++
Idx) {
420 int MaskIdx =
Idx * 2;
421 if (Mask[MaskIdx] !=
Idx || Mask[MaskIdx + 1] != (
Idx + HalfNumElements))
430 int HalfNumElements = Mask.size() / 2;
432 for (
int Idx = 1;
Idx < HalfNumElements; ++
Idx) {
440bool ComplexDeinterleaving::evaluateBasicBlock(
BasicBlock *
B) {
441 ComplexDeinterleavingGraph Graph(TL, TLI);
443 Graph.identifyNodes(&
I);
445 if (Graph.checkNodes()) {
446 Graph.replaceNodes();
453ComplexDeinterleavingGraph::NodePtr
454ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
456 std::pair<Instruction *, Instruction *> &PartialMatch) {
457 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
465 if (Real->
getOpcode() != Instruction::FMul ||
466 Imag->
getOpcode() != Instruction::FMul) {
467 LLVM_DEBUG(
dbgs() <<
" - Real or imaginary instruction is not fmul\n");
475 if (!R0 || !R1 || !I0 || !I1) {
484 if (R0->
getOpcode() == Instruction::FNeg ||
487 if (R0->
getOpcode() == Instruction::FNeg) {
489 R0 = dyn_cast<Instruction>(R0->
getOperand(0));
492 R1 = dyn_cast<Instruction>(R1->
getOperand(0));
497 if (I0->
getOpcode() == Instruction::FNeg ||
498 I1->getOpcode() == Instruction::FNeg) {
501 if (I0->
getOpcode() == Instruction::FNeg) {
503 I0 = dyn_cast<Instruction>(I0->
getOperand(0));
506 I1 = dyn_cast<Instruction>(
I1->getOperand(0));
518 if (R0 == I0 || R0 == I1) {
521 }
else if (R1 == I0 || R1 == I1) {
529 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
530 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
531 Rotation == ComplexDeinterleavingRotation::Rotation_270)
532 std::swap(UncommonRealOp, UncommonImagOp);
536 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
537 Rotation == ComplexDeinterleavingRotation::Rotation_180)
538 PartialMatch.first = CommonOperand;
540 PartialMatch.second = CommonOperand;
542 if (!PartialMatch.first || !PartialMatch.second) {
547 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
553 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
559 NodePtr
Node = prepareCompositeNode(
560 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
561 Node->Rotation = Rotation;
562 Node->addOperand(CommonNode);
563 Node->addOperand(UncommonNode);
564 return submitCompositeNode(
Node);
567ComplexDeinterleavingGraph::NodePtr
568ComplexDeinterleavingGraph::identifyPartialMul(
Instruction *Real,
570 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
574 if (Real->
getOpcode() == Instruction::FAdd &&
576 Rotation = ComplexDeinterleavingRotation::Rotation_0;
577 else if (Real->
getOpcode() == Instruction::FSub &&
579 Rotation = ComplexDeinterleavingRotation::Rotation_90;
580 else if (Real->
getOpcode() == Instruction::FSub &&
582 Rotation = ComplexDeinterleavingRotation::Rotation_180;
583 else if (Real->
getOpcode() == Instruction::FAdd &&
585 Rotation = ComplexDeinterleavingRotation::Rotation_270;
593 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
615 if (!R0 || !R1 || !I0 || !I1) {
624 if (R0 == I0 || R0 == I1) {
627 }
else if (R1 == I0 || R1 == I1) {
635 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
636 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
637 Rotation == ComplexDeinterleavingRotation::Rotation_270)
638 std::swap(UncommonRealOp, UncommonImagOp);
640 std::pair<Instruction *, Instruction *> PartialMatch(
641 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
642 Rotation == ComplexDeinterleavingRotation::Rotation_180)
645 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
646 Rotation == ComplexDeinterleavingRotation::Rotation_270)
650 auto *CRInst = dyn_cast<Instruction>(CR);
651 auto *CIInst = dyn_cast<Instruction>(CI);
653 if (!CRInst || !CIInst) {
654 LLVM_DEBUG(
dbgs() <<
" - Common operands are not instructions.\n");
658 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
664 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
670 assert(PartialMatch.first && PartialMatch.second);
671 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
677 NodePtr
Node = prepareCompositeNode(
678 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
679 Node->Rotation = Rotation;
680 Node->addOperand(CommonRes);
681 Node->addOperand(UncommonRes);
682 Node->addOperand(CNode);
683 return submitCompositeNode(
Node);
686ComplexDeinterleavingGraph::NodePtr
688 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
692 if ((Real->
getOpcode() == Instruction::FSub &&
693 Imag->
getOpcode() == Instruction::FAdd) ||
694 (Real->
getOpcode() == Instruction::Sub &&
696 Rotation = ComplexDeinterleavingRotation::Rotation_90;
697 else if ((Real->
getOpcode() == Instruction::FAdd &&
698 Imag->
getOpcode() == Instruction::FSub) ||
699 (Real->
getOpcode() == Instruction::Add &&
701 Rotation = ComplexDeinterleavingRotation::Rotation_270;
703 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
707 auto *AR = dyn_cast<Instruction>(Real->
getOperand(0));
708 auto *BI = dyn_cast<Instruction>(Real->
getOperand(1));
709 auto *AI = dyn_cast<Instruction>(Imag->
getOperand(0));
712 if (!AR || !AI || !BR || !BI) {
717 NodePtr ResA = identifyNode(AR, AI);
719 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
722 NodePtr ResB = identifyNode(BR, BI);
724 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
729 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
730 Node->Rotation = Rotation;
731 Node->addOperand(ResA);
732 Node->addOperand(ResB);
733 return submitCompositeNode(
Node);
737 unsigned OpcA =
A->getOpcode();
738 unsigned OpcB =
B->getOpcode();
740 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
741 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
742 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
743 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
754 switch (
I->getOpcode()) {
755 case Instruction::FAdd:
756 case Instruction::FSub:
757 case Instruction::FMul:
758 case Instruction::FNeg:
765ComplexDeinterleavingGraph::NodePtr
766ComplexDeinterleavingGraph::identifySymmetricOperation(
Instruction *Real,
775 auto *R0 = dyn_cast<Instruction>(Real->
getOperand(0));
776 auto *I0 = dyn_cast<Instruction>(Imag->
getOperand(0));
781 NodePtr Op0 = identifyNode(R0, I0);
782 NodePtr Op1 =
nullptr;
787 auto *R1 = dyn_cast<Instruction>(Real->
getOperand(1));
792 Op1 = identifyNode(R1, I1);
797 if (isa<FPMathOperator>(Real) &&
801 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
804 if (isa<FPMathOperator>(Real))
807 Node->addOperand(Op0);
809 Node->addOperand(Op1);
811 return submitCompositeNode(
Node);
814ComplexDeinterleavingGraph::NodePtr
816 LLVM_DEBUG(
dbgs() <<
"identifyNode on " << *Real <<
" / " << *Imag <<
"\n");
817 if (NodePtr CN = getContainingComposite(Real, Imag)) {
822 if (NodePtr CN = identifyDeinterleave(Real, Imag))
825 auto *VTy = cast<VectorType>(Real->
getType());
826 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
829 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
831 ComplexDeinterleavingOperation::CAdd, NewVTy);
834 if (NodePtr CN = identifyPartialMul(Real, Imag))
839 if (NodePtr CN = identifyAdd(Real, Imag))
843 if (HasCMulSupport && HasCAddSupport) {
844 if (NodePtr CN = identifyReassocNodes(Real, Imag))
848 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
855ComplexDeinterleavingGraph::NodePtr
856ComplexDeinterleavingGraph::identifyReassocNodes(
Instruction *Real,
858 if ((Real->
getOpcode() != Instruction::FAdd &&
859 Real->
getOpcode() != Instruction::FSub &&
860 Real->
getOpcode() != Instruction::FNeg) ||
861 (Imag->
getOpcode() != Instruction::FAdd &&
862 Imag->
getOpcode() != Instruction::FSub &&
869 <<
"The flags in Real and Imaginary instructions are not identical\n");
874 if (!
Flags.allowReassoc()) {
876 dbgs() <<
"the 'Reassoc' attribute is missing in the FastMath flags\n");
884 std::list<Addend> &Addends) ->
bool {
887 while (!Worklist.
empty()) {
888 auto [
V, IsPositive] = Worklist.
back();
890 if (!Visited.
insert(V).second)
903 if (
I !=
Insn &&
I->getNumUses() > 1) {
904 LLVM_DEBUG(
dbgs() <<
"Found potential sub-expression: " << *
I <<
"\n");
905 Addends.emplace_back(
I, IsPositive);
909 if (
I->getOpcode() == Instruction::FAdd) {
912 }
else if (
I->getOpcode() == Instruction::FSub) {
915 }
else if (
I->getOpcode() == Instruction::FMul) {
916 auto *
A = dyn_cast<Instruction>(
I->getOperand(0));
917 if (
A &&
A->getOpcode() == Instruction::FNeg) {
918 A = dyn_cast<Instruction>(
A->getOperand(0));
919 IsPositive = !IsPositive;
923 auto *
B = dyn_cast<Instruction>(
I->getOperand(1));
924 if (
B &&
B->getOpcode() == Instruction::FNeg) {
925 B = dyn_cast<Instruction>(
B->getOperand(0));
926 IsPositive = !IsPositive;
930 Muls.push_back(Product{
A,
B, IsPositive});
931 }
else if (
I->getOpcode() == Instruction::FNeg) {
934 Addends.emplace_back(
I, IsPositive);
938 if (
I->getFastMathFlags() != Flags) {
940 "inconsistent with the root instructions' flags: "
948 std::vector<Product> RealMuls, ImagMuls;
949 std::list<Addend> RealAddends, ImagAddends;
950 if (!Collect(Real, RealMuls, RealAddends) ||
951 !Collect(Imag, ImagMuls, ImagAddends))
954 if (RealAddends.size() != ImagAddends.size())
958 if (!RealMuls.empty() || !ImagMuls.empty()) {
961 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
962 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
968 if (!RealAddends.empty() || !ImagAddends.empty()) {
969 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
975 FinalNode->Real = Real;
976 FinalNode->Imag = Imag;
977 submitCompositeNode(FinalNode);
981bool ComplexDeinterleavingGraph::collectPartialMuls(
982 const std::vector<Product> &RealMuls,
const std::vector<Product> &ImagMuls,
983 std::vector<PartialMulCandidate> &PartialMulCandidates) {
985 auto FindCommonInstruction = [](
const Product &Real,
987 if (Real.Multiplicand == Imag.Multiplicand ||
988 Real.Multiplicand == Imag.Multiplier)
989 return Real.Multiplicand;
991 if (Real.Multiplier == Imag.Multiplicand ||
992 Real.Multiplier == Imag.Multiplier)
993 return Real.Multiplier;
1002 for (
unsigned i = 0; i < RealMuls.size(); ++i) {
1003 bool FoundCommon =
false;
1004 for (
unsigned j = 0;
j < ImagMuls.size(); ++
j) {
1005 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1009 auto *
A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1010 : RealMuls[i].Multiplicand;
1011 auto *
B = ImagMuls[
j].Multiplicand == Common ? ImagMuls[
j].Multiplier
1012 : ImagMuls[
j].Multiplicand;
1014 bool Inverted =
false;
1015 auto Node = identifyNode(
A,
B);
1019 Node = identifyNode(
A,
B);
1025 PartialMulCandidates.push_back({Common,
Node, i,
j, Inverted});
1033ComplexDeinterleavingGraph::NodePtr
1034ComplexDeinterleavingGraph::identifyMultiplications(
1035 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1037 if (RealMuls.size() != ImagMuls.size())
1040 std::vector<PartialMulCandidate>
Info;
1041 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1045 std::map<Instruction *, NodePtr> CommonToNode;
1046 std::vector<bool> Processed(
Info.size(),
false);
1047 for (
unsigned I = 0;
I <
Info.size(); ++
I) {
1051 PartialMulCandidate &InfoA =
Info[
I];
1052 for (
unsigned J =
I + 1; J <
Info.size(); ++J) {
1056 PartialMulCandidate &InfoB =
Info[J];
1057 auto *InfoReal = &InfoA;
1058 auto *InfoImag = &InfoB;
1060 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1061 if (!NodeFromCommon) {
1063 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1065 if (!NodeFromCommon)
1068 CommonToNode[InfoReal->Common] = NodeFromCommon;
1069 CommonToNode[InfoImag->Common] = NodeFromCommon;
1070 Processed[
I] =
true;
1071 Processed[J] =
true;
1075 std::vector<bool> ProcessedReal(RealMuls.size(),
false);
1076 std::vector<bool> ProcessedImag(ImagMuls.size(),
false);
1078 for (
auto &PMI : Info) {
1079 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1082 auto It = CommonToNode.find(PMI.Common);
1085 if (It == CommonToNode.end()) {
1087 dbgs() <<
"Unprocessed independent partial multiplication:\n";
1088 for (
auto *
Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1090 <<
" multiplied by " << *
Mul->Multiplicand <<
"\n";
1095 auto &RealMul = RealMuls[PMI.RealIdx];
1096 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1098 auto NodeA = It->second;
1099 auto NodeB = PMI.Node;
1100 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1115 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1116 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1121 if (IsMultiplicandReal) {
1123 if (RealMul.IsPositive && ImagMul.IsPositive)
1125 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1132 if (!RealMul.IsPositive && ImagMul.IsPositive)
1134 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1141 dbgs() <<
"Identified partial multiplication (X, Y) * (U, V):\n";
1142 dbgs().
indent(4) <<
"X: " << *NodeA->Real <<
"\n";
1143 dbgs().
indent(4) <<
"Y: " << *NodeA->Imag <<
"\n";
1144 dbgs().
indent(4) <<
"U: " << *NodeB->Real <<
"\n";
1145 dbgs().
indent(4) <<
"V: " << *NodeB->Imag <<
"\n";
1146 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1149 NodePtr NodeMul = prepareCompositeNode(
1150 ComplexDeinterleavingOperation::CMulPartial,
nullptr,
nullptr);
1151 NodeMul->Rotation = Rotation;
1152 NodeMul->addOperand(NodeA);
1153 NodeMul->addOperand(NodeB);
1155 NodeMul->addOperand(Result);
1156 submitCompositeNode(NodeMul);
1158 ProcessedReal[PMI.RealIdx] =
true;
1159 ProcessedImag[PMI.ImagIdx] =
true;
1163 if (!
all_of(ProcessedReal, [](
bool V) {
return V; }) ||
1164 !
all_of(ProcessedImag, [](
bool V) {
return V; })) {
1169 dbgs() <<
"Unprocessed products (Real):\n";
1170 for (
size_t i = 0; i < ProcessedReal.size(); ++i) {
1171 if (!ProcessedReal[i])
1172 dbgs().
indent(4) << (RealMuls[i].IsPositive ?
"+" :
"-")
1173 << *RealMuls[i].Multiplier <<
" multiplied by "
1174 << *RealMuls[i].Multiplicand <<
"\n";
1176 dbgs() <<
"Unprocessed products (Imag):\n";
1177 for (
size_t i = 0; i < ProcessedImag.size(); ++i) {
1178 if (!ProcessedImag[i])
1179 dbgs().
indent(4) << (ImagMuls[i].IsPositive ?
"+" :
"-")
1180 << *ImagMuls[i].Multiplier <<
" multiplied by "
1181 << *ImagMuls[i].Multiplicand <<
"\n";
1190ComplexDeinterleavingGraph::NodePtr
1191ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
1192 std::list<Addend> &ImagAddends,
1195 if (RealAddends.size() != ImagAddends.size())
1204 Result = extractPositiveAddend(RealAddends, ImagAddends);
1209 while (!RealAddends.empty()) {
1210 auto ItR = RealAddends.begin();
1211 auto [
R, IsPositiveR] = *ItR;
1213 bool FoundImag =
false;
1214 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1215 auto [
I, IsPositiveI] = *ItI;
1217 if (IsPositiveR && IsPositiveI)
1218 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1219 else if (!IsPositiveR && IsPositiveI)
1220 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1221 else if (!IsPositiveR && !IsPositiveI)
1222 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1224 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1227 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1228 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1229 AddNode = identifyNode(R,
I);
1231 AddNode = identifyNode(
I, R);
1235 dbgs() <<
"Identified addition:\n";
1238 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1243 TmpNode = prepareCompositeNode(
1244 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1245 TmpNode->Opcode = Instruction::FAdd;
1246 TmpNode->Flags =
Flags;
1247 }
else if (Rotation ==
1249 TmpNode = prepareCompositeNode(
1250 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1251 TmpNode->Opcode = Instruction::FSub;
1252 TmpNode->Flags =
Flags;
1254 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1256 TmpNode->Rotation = Rotation;
1259 TmpNode->addOperand(Result);
1260 TmpNode->addOperand(AddNode);
1261 submitCompositeNode(TmpNode);
1263 RealAddends.erase(ItR);
1264 ImagAddends.erase(ItI);
1275ComplexDeinterleavingGraph::NodePtr
1276ComplexDeinterleavingGraph::extractPositiveAddend(
1277 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1278 for (
auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1279 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1280 auto [
R, IsPositiveR] = *ItR;
1281 auto [
I, IsPositiveI] = *ItI;
1282 if (IsPositiveR && IsPositiveI) {
1283 auto Result = identifyNode(R,
I);
1285 RealAddends.erase(ItR);
1286 ImagAddends.erase(ItI);
1295bool ComplexDeinterleavingGraph::identifyNodes(
Instruction *RootI) {
1296 auto RootNode = identifyRoot(RootI);
1303 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
1304 <<
"::" <<
B->getName() <<
".\n";
1308 RootToNode[RootI] = RootNode;
1313bool ComplexDeinterleavingGraph::checkNodes() {
1317 for (
auto *
I : OrderedRoots)
1322 while (!Worklist.
empty()) {
1323 auto *
I = Worklist.
back();
1326 if (!AllInstructions.
insert(
I).second)
1329 for (
Value *Op :
I->operands()) {
1330 if (
auto *OpI = dyn_cast<Instruction>(Op)) {
1331 if (!FinalInstructions.
count(
I))
1339 for (
auto *
I : AllInstructions) {
1341 if (RootToNode.count(
I))
1344 for (
User *U :
I->users()) {
1345 if (AllInstructions.count(cast<Instruction>(U)))
1357 while (!Worklist.
empty()) {
1358 auto *
I = Worklist.
back();
1360 if (!Visited.
insert(
I).second)
1365 if (RootToNode.count(
I)) {
1367 <<
" could be deinterleaved but its chain of complex "
1368 "operations have an outside user\n");
1369 RootToNode.erase(
I);
1372 if (!AllInstructions.count(
I) || FinalInstructions.
count(
I))
1375 for (
User *U :
I->users())
1378 for (
Value *Op :
I->operands()) {
1379 if (
auto *OpI = dyn_cast<Instruction>(Op))
1383 return !RootToNode.empty();
1386ComplexDeinterleavingGraph::NodePtr
1387ComplexDeinterleavingGraph::identifyRoot(
Instruction *RootI) {
1388 if (
auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1390 Intrinsic::experimental_vector_interleave2)
1393 auto *Real = dyn_cast<Instruction>(
Intrinsic->getOperand(0));
1394 auto *Imag = dyn_cast<Instruction>(
Intrinsic->getOperand(1));
1398 return identifyNode(Real, Imag);
1401 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1415 return identifyNode(Real, Imag);
1418ComplexDeinterleavingGraph::NodePtr
1419ComplexDeinterleavingGraph::identifyDeinterleave(
Instruction *Real,
1422 Value *FinalValue =
nullptr;
1425 match(
I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1427 NodePtr PlaceholderNode = prepareCompositeNode(
1429 PlaceholderNode->ReplacementNode = FinalValue;
1430 FinalInstructions.
insert(Real);
1431 FinalInstructions.
insert(Imag);
1432 return submitCompositeNode(PlaceholderNode);
1435 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1436 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1437 if (!RealShuffle || !ImagShuffle) {
1438 if (RealShuffle || ImagShuffle)
1439 LLVM_DEBUG(
dbgs() <<
" - There's a shuffle where there shouldn't be.\n");
1443 Value *RealOp1 = RealShuffle->getOperand(1);
1444 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1448 Value *ImagOp1 = ImagShuffle->getOperand(1);
1449 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1454 Value *RealOp0 = RealShuffle->getOperand(0);
1455 Value *ImagOp0 = ImagShuffle->getOperand(0);
1457 if (RealOp0 != ImagOp0) {
1469 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1470 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
1477 Value *
Op = Shuffle->getOperand(0);
1478 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1479 auto *OpTy = cast<FixedVectorType>(
Op->getType());
1481 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1483 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1496 Value *
Op = Shuffle->getOperand(0);
1497 auto *OpTy = cast<FixedVectorType>(
Op->getType());
1498 int NumElements = OpTy->getNumElements();
1502 return Last < NumElements;
1505 if (RealShuffle->getType() != ImagShuffle->getType()) {
1509 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1513 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1518 NodePtr PlaceholderNode =
1520 RealShuffle, ImagShuffle);
1521 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1522 FinalInstructions.
insert(RealShuffle);
1523 FinalInstructions.
insert(ImagShuffle);
1524 return submitCompositeNode(PlaceholderNode);
1532 case Instruction::FNeg:
1533 I =
B.CreateFNeg(InputA);
1535 case Instruction::FAdd:
1536 I =
B.CreateFAdd(InputA, InputB);
1538 case Instruction::FSub:
1539 I =
B.CreateFSub(InputA, InputB);
1541 case Instruction::FMul:
1542 I =
B.CreateFMul(InputA, InputB);
1547 cast<Instruction>(
I)->setFastMathFlags(
Flags);
1553 if (
Node->ReplacementNode)
1554 return Node->ReplacementNode;
1556 Value *Input0 = replaceNode(Builder,
Node->Operands[0]);
1557 Value *Input1 =
Node->Operands.size() > 1
1558 ? replaceNode(Builder,
Node->Operands[1])
1561 ? replaceNode(Builder,
Node->Operands[2])
1565 "Node inputs need to be of the same type");
1567 if (
Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1569 Node->Flags, Input0, Input1);
1574 assert(
Node->ReplacementNode &&
"Target failed to create Intrinsic call.");
1575 NumComplexTransformations += 1;
1576 return Node->ReplacementNode;
1579void ComplexDeinterleavingGraph::replaceNodes() {
1581 for (
auto *RootInstruction : OrderedRoots) {
1584 if (!RootToNode.count(RootInstruction))
1588 auto RootNode = RootToNode[RootInstruction];
1589 Value *
R = replaceNode(Builder, RootNode.get());
1590 assert(R &&
"Unable to find replacement for RootInstruction");
1591 DeadInstrRoots.
push_back(RootInstruction);
1592 RootInstruction->replaceAllUsesWith(R);
1595 for (
auto *
I : DeadInstrRoots)
SmallVector< AArch64_IMM::ImmInsnModel, 4 > Insn
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, FastMathFlags Flags, Value *InputA, Value *InputB)
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
static bool isInstructionPairMul(Instruction *A, Instruction *B)
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
static bool runOnFunction(Function &F, bool PostInlining)
mir Rename Register Operands
PowerPC Reduce CR logical Operation
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(const unsigned char *MatcherTable, unsigned &MatcherIndex, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
This file describes how to lower LLVM code to machine code.
Target-Independent Code Generator Pass Configuration Options pass.
DEMANGLE_DUMP_METHOD void dump() const
A container for analyses that lazily runs them and caches their results.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
LLVM Basic Block Representation.
Convenience struct for specifying and reasoning about fast-math flags.
bool allowContract() const
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
const BasicBlock * getParent() const
const Function * getFunction() const
Return the function this instruction belongs to.
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
This instruction constructs a fixed permutation of two input vectors.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
This class implements an extremely fast bulk output stream that can only output to a stream.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ BR
Control flow instructions. These all have token chains.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
ComplexDeinterleavingOperation
FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
ComplexDeinterleavingRotation
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.