82#define DEBUG_TYPE "complex-deinterleaving"
84STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
87 "enable-complex-deinterleaving",
115 Value *Real =
nullptr;
116 Value *Imag =
nullptr;
119 return Real ==
Other.Real && Imag ==
Other.Imag;
142 static bool isEqual(
const ComplexValue &LHS,
const ComplexValue &RHS) {
143 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
148template <
typename T,
typename IterT>
149std::optional<T> findCommonBetweenCollections(IterT
A, IterT
B) {
151 if (Common !=
A.end())
152 return std::make_optional(*Common);
156class ComplexDeinterleavingLegacyPass :
public FunctionPass {
160 ComplexDeinterleavingLegacyPass(
const TargetMachine *TM =
nullptr)
161 : FunctionPass(ID), TM(TM) {
166 StringRef getPassName()
const override {
167 return "Complex Deinterleaving Pass";
171 void getAnalysisUsage(AnalysisUsage &AU)
const override {
177 const TargetMachine *TM;
180class ComplexDeinterleavingGraph;
181struct ComplexDeinterleavingCompositeNode {
186 Vals.push_back({
R,
I});
191 : Operation(
Op), Vals(
Other) {}
194 friend class ComplexDeinterleavingGraph;
195 using CompositeNode = ComplexDeinterleavingCompositeNode;
196 bool OperandsValid =
true;
205 std::optional<FastMathFlags> Flags;
208 ComplexDeinterleavingRotation::Rotation_0;
210 Value *ReplacementNode =
nullptr;
214 OperandsValid =
false;
215 Operands.push_back(Node);
219 void dump(raw_ostream &OS) {
220 auto PrintValue = [&](
Value *
V) {
228 auto PrintNodeRef = [&](CompositeNode *
Ptr) {
235 OS <<
"- CompositeNode: " <<
this <<
"\n";
236 for (
unsigned I = 0;
I < Vals.size();
I++) {
237 OS <<
" Real(" <<
I <<
") : ";
238 PrintValue(Vals[
I].Real);
239 OS <<
" Imag(" <<
I <<
") : ";
240 PrintValue(Vals[
I].Imag);
242 OS <<
" ReplacementNode: ";
243 PrintValue(ReplacementNode);
244 OS <<
" Operation: " << (int)Operation <<
"\n";
245 OS <<
" Rotation: " << ((int)Rotation * 90) <<
"\n";
246 OS <<
" Operands: \n";
247 for (
const auto &
Op : Operands) {
253 bool areOperandsValid() {
return OperandsValid; }
256class ComplexDeinterleavingGraph {
264 using Addend = std::pair<Value *, bool>;
266 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
270 struct PartialMulCandidate {
278 explicit ComplexDeinterleavingGraph(
const TargetLowering *TL,
279 const TargetLibraryInfo *TLI,
281 : TL(TL), TLI(TLI), Factor(Factor) {}
284 const TargetLowering *TL =
nullptr;
285 const TargetLibraryInfo *TLI =
nullptr;
288 DenseMap<ComplexValues, CompositeNode *> CachedResult;
289 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
291 SmallPtrSet<Instruction *, 16> FinalInstructions;
294 DenseMap<Instruction *, CompositeNode *> RootToNode;
321 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
329 PHINode *RealPHI =
nullptr;
330 PHINode *ImagPHI =
nullptr;
334 bool PHIsFound =
false;
342 DenseMap<PHINode *, PHINode *> OldToNewPHI;
347 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
349 "Reduction related nodes must have Real and Imaginary parts");
350 return new (Allocator.Allocate())
351 ComplexDeinterleavingCompositeNode(
Operation, R,
I);
357 for (
auto &V : Vals) {
359 ((
Operation != ComplexDeinterleavingOperation::ReductionPHI &&
360 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
361 (
V.Real &&
V.Imag)) &&
362 "Reduction related nodes must have Real and Imaginary parts");
365 return new (Allocator.Allocate())
366 ComplexDeinterleavingCompositeNode(
Operation, Vals);
369 CompositeNode *submitCompositeNode(CompositeNode *Node) {
370 CompositeNodes.push_back(Node);
371 if (
Node->Vals[0].Real)
387 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
393 identifyNodeWithImplicitAdd(Instruction *
I, Instruction *J,
394 std::pair<Value *, Value *> &CommonOperandI);
403 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
404 CompositeNode *identifySymmetricOperation(
ComplexValues &Vals);
405 CompositeNode *identifyPartialReduction(
Value *R,
Value *
I);
406 CompositeNode *identifyDotProduct(
Value *Inst);
413 return identifyNode(Vals);
420 CompositeNode *identifyAdditions(AddendList &RealAddends,
421 AddendList &ImagAddends,
422 std::optional<FastMathFlags> Flags,
426 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
427 AddendList &ImagAddends);
432 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
433 SmallVectorImpl<Product> &ImagMuls,
441 SmallVectorImpl<PartialMulCandidate> &Candidates);
449 CompositeNode *identifyReassocNodes(Instruction *
I, Instruction *J);
451 CompositeNode *identifyRoot(Instruction *
I);
469 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
473 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
475 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
482 void processReductionOperation(
Value *OperationReplacement,
483 CompositeNode *Node);
484 void processReductionSingle(
Value *OperationReplacement, CompositeNode *Node);
488 void dump(raw_ostream &OS) {
489 for (
const auto &Node : CompositeNodes)
495 bool identifyNodes(Instruction *RootI);
500 bool collectPotentialReductions(BasicBlock *
B);
502 void identifyReductionNodes();
512class ComplexDeinterleaving {
514 ComplexDeinterleaving(
const TargetLowering *tl,
const TargetLibraryInfo *tli)
515 : TL(tl), TLI(tli) {}
519 bool evaluateBasicBlock(BasicBlock *
B,
unsigned Factor);
521 const TargetLowering *TL =
nullptr;
522 const TargetLibraryInfo *TLI =
nullptr;
527char ComplexDeinterleavingLegacyPass::ID = 0;
530 "Complex Deinterleaving",
false,
false)
536 const TargetLowering *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
547 return new ComplexDeinterleavingLegacyPass(TM);
550bool ComplexDeinterleavingLegacyPass::runOnFunction(
Function &
F) {
552 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
553 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
556bool ComplexDeinterleaving::runOnFunction(Function &
F) {
559 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
565 dbgs() <<
"Complex deinterleaving has been disabled, target does "
566 "not support lowering of complex number operations.\n");
572 Changed |= evaluateBasicBlock(&
B, 2);
577 Changed |= evaluateBasicBlock(&
B, 4);
587 if ((Mask.size() & 1))
590 int HalfNumElements = Mask.size() / 2;
591 for (
int Idx = 0; Idx < HalfNumElements; ++Idx) {
592 int MaskIdx = Idx * 2;
593 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
602 int HalfNumElements = Mask.size() / 2;
604 for (
int Idx = 1; Idx < HalfNumElements; ++Idx) {
605 if (Mask[Idx] != (Idx * 2) +
Offset)
619 if (
I->getOpcode() == Instruction::FNeg)
620 return I->getOperand(0);
622 return I->getOperand(1);
625bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *
B,
unsigned Factor) {
626 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
627 if (Graph.collectPotentialReductions(
B))
628 Graph.identifyReductionNodes();
631 Graph.identifyNodes(&
I);
633 if (Graph.checkNodes()) {
634 Graph.replaceNodes();
641ComplexDeinterleavingGraph::CompositeNode *
642ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
643 Instruction *Real, Instruction *Imag,
644 std::pair<Value *, Value *> &PartialMatch) {
645 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
653 if ((Real->
getOpcode() != Instruction::FMul &&
654 Real->
getOpcode() != Instruction::Mul) ||
655 (Imag->
getOpcode() != Instruction::FMul &&
656 Imag->
getOpcode() != Instruction::Mul)) {
658 dbgs() <<
" - Real or imaginary instruction is not fmul or mul\n");
691 Value *CommonOperand;
692 Value *UncommonRealOp;
693 Value *UncommonImagOp;
695 if (R0 == I0 || R0 == I1) {
698 }
else if (R1 == I0 || R1 == I1) {
706 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
707 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
708 Rotation == ComplexDeinterleavingRotation::Rotation_270)
709 std::swap(UncommonRealOp, UncommonImagOp);
713 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
714 Rotation == ComplexDeinterleavingRotation::Rotation_180)
715 PartialMatch.first = CommonOperand;
717 PartialMatch.second = CommonOperand;
719 if (!PartialMatch.first || !PartialMatch.second) {
724 CompositeNode *CommonNode =
725 identifyNode(PartialMatch.first, PartialMatch.second);
731 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
737 CompositeNode *
Node = prepareCompositeNode(
738 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
739 Node->Rotation = Rotation;
740 Node->addOperand(CommonNode);
741 Node->addOperand(UncommonNode);
742 return submitCompositeNode(Node);
745ComplexDeinterleavingGraph::CompositeNode *
746ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
748 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
752 auto IsAdd = [](
unsigned Op) {
753 return Op == Instruction::FAdd ||
Op == Instruction::Add;
755 auto IsSub = [](
unsigned Op) {
756 return Op == Instruction::FSub ||
Op == Instruction::Sub;
760 Rotation = ComplexDeinterleavingRotation::Rotation_0;
762 Rotation = ComplexDeinterleavingRotation::Rotation_90;
764 Rotation = ComplexDeinterleavingRotation::Rotation_180;
766 Rotation = ComplexDeinterleavingRotation::Rotation_270;
775 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
798 Value *CommonOperand;
799 Value *UncommonRealOp;
800 Value *UncommonImagOp;
802 if (R0 == I0 || R0 == I1) {
805 }
else if (R1 == I0 || R1 == I1) {
813 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
814 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
815 Rotation == ComplexDeinterleavingRotation::Rotation_270)
816 std::swap(UncommonRealOp, UncommonImagOp);
818 std::pair<Value *, Value *> PartialMatch(
819 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
820 Rotation == ComplexDeinterleavingRotation::Rotation_180)
823 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
824 Rotation == ComplexDeinterleavingRotation::Rotation_270)
831 if (!CRInst || !CIInst) {
832 LLVM_DEBUG(
dbgs() <<
" - Common operands are not instructions.\n");
836 CompositeNode *CNode =
837 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
843 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
849 assert(PartialMatch.first && PartialMatch.second);
850 CompositeNode *CommonRes =
851 identifyNode(PartialMatch.first, PartialMatch.second);
857 CompositeNode *
Node = prepareCompositeNode(
858 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
859 Node->Rotation = Rotation;
860 Node->addOperand(CommonRes);
861 Node->addOperand(UncommonRes);
862 Node->addOperand(CNode);
863 return submitCompositeNode(Node);
866ComplexDeinterleavingGraph::CompositeNode *
867ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
868 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
872 if ((Real->
getOpcode() == Instruction::FSub &&
873 Imag->
getOpcode() == Instruction::FAdd) ||
874 (Real->
getOpcode() == Instruction::Sub &&
876 Rotation = ComplexDeinterleavingRotation::Rotation_90;
877 else if ((Real->
getOpcode() == Instruction::FAdd &&
878 Imag->
getOpcode() == Instruction::FSub) ||
879 (Real->
getOpcode() == Instruction::Add &&
881 Rotation = ComplexDeinterleavingRotation::Rotation_270;
883 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
892 if (!AR || !AI || !BR || !BI) {
897 CompositeNode *ResA = identifyNode(AR, AI);
899 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
902 CompositeNode *ResB = identifyNode(BR, BI);
904 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
908 CompositeNode *
Node =
909 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
910 Node->Rotation = Rotation;
911 Node->addOperand(ResA);
912 Node->addOperand(ResB);
913 return submitCompositeNode(Node);
917 unsigned OpcA =
A->getOpcode();
918 unsigned OpcB =
B->getOpcode();
920 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
921 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
922 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
923 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
934 switch (
I->getOpcode()) {
935 case Instruction::FAdd:
936 case Instruction::FSub:
937 case Instruction::FMul:
938 case Instruction::FNeg:
939 case Instruction::Add:
940 case Instruction::Sub:
941 case Instruction::Mul:
948ComplexDeinterleavingGraph::CompositeNode *
949ComplexDeinterleavingGraph::identifySymmetricOperation(
ComplexValues &Vals) {
951 unsigned FirstOpc = FirstReal->getOpcode();
952 for (
auto &V : Vals) {
969 for (
auto &V : Vals) {
975 CompositeNode *Op0 = identifyNode(OpVals);
976 CompositeNode *Op1 =
nullptr;
980 if (FirstReal->isBinaryOp()) {
982 for (
auto &V : Vals) {
987 Op1 = identifyNode(OpVals);
993 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
994 Node->Opcode = FirstReal->getOpcode();
996 Node->Flags = FirstReal->getFastMathFlags();
998 Node->addOperand(Op0);
999 if (FirstReal->isBinaryOp())
1000 Node->addOperand(Op1);
1002 return submitCompositeNode(Node);
1005ComplexDeinterleavingGraph::CompositeNode *
1006ComplexDeinterleavingGraph::identifyDotProduct(
Value *V) {
1008 ComplexDeinterleavingOperation::CDot,
V->getType())) {
1009 LLVM_DEBUG(
dbgs() <<
"Target doesn't support complex deinterleaving "
1010 "operation CDot with the type "
1011 << *
V->getType() <<
"\n");
1019 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst,
nullptr);
1021 CompositeNode *ANode =
nullptr;
1023 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1025 Value *AReal =
nullptr;
1026 Value *AImag =
nullptr;
1027 Value *BReal =
nullptr;
1028 Value *BImag =
nullptr;
1033 return CI->getOperand(0);
1047 if (
match(Inst, PatternRot0)) {
1048 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1049 }
else if (
match(Inst, PatternRot270)) {
1050 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1061 if (!
match(Inst, PatternRot90Rot180))
1064 A0 = UnwrapCast(A0);
1065 A1 = UnwrapCast(A1);
1068 ANode = identifyNode(A0, A1);
1071 ANode = identifyNode(A1, A0);
1075 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1081 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1085 AReal = UnwrapCast(AReal);
1086 AImag = UnwrapCast(AImag);
1087 BReal = UnwrapCast(BReal);
1088 BImag = UnwrapCast(BImag);
1091 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1092 if (AReal->
getType() != ExpectedOperandTy)
1094 if (AImag->
getType() != ExpectedOperandTy)
1096 if (BReal->
getType() != ExpectedOperandTy)
1098 if (BImag->
getType() != ExpectedOperandTy)
1101 if (
Phi->getType() != VTy && RealUser->getType() != VTy)
1104 CompositeNode *
Node = identifyNode(AReal, AImag);
1109 if (ANode && Node != ANode) {
1112 <<
"Identified node is different from previously identified node. "
1113 "Unable to confidently generate a complex operation node\n");
1117 CN->addOperand(Node);
1118 CN->addOperand(identifyNode(BReal, BImag));
1119 CN->addOperand(identifyNode(Phi, RealUser));
1121 return submitCompositeNode(CN);
1124ComplexDeinterleavingGraph::CompositeNode *
1125ComplexDeinterleavingGraph::identifyPartialReduction(
Value *R,
Value *
I) {
1130 if (!
R->hasUseList() || !
I->hasUseList())
1134 findCommonBetweenCollections<Value *>(
R->users(),
I->users());
1139 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1142 if (CompositeNode *CN = identifyDotProduct(IInst))
1148ComplexDeinterleavingGraph::CompositeNode *
1149ComplexDeinterleavingGraph::identifyNode(
ComplexValues &Vals) {
1150 auto It = CachedResult.
find(Vals);
1151 if (It != CachedResult.
end()) {
1156 if (Vals.
size() == 1) {
1157 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1160 if (CompositeNode *CN = identifyPartialReduction(R,
I))
1162 bool IsReduction = RealPHI ==
R && (!ImagPHI || ImagPHI ==
I);
1163 if (!IsReduction &&
R->getType() !=
I->getType())
1167 if (CompositeNode *CN = identifySplat(Vals))
1170 for (
auto &V : Vals) {
1177 if (CompositeNode *CN = identifyDeinterleave(Vals))
1180 if (Vals.size() == 1) {
1181 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1184 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1187 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1191 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1194 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1196 ComplexDeinterleavingOperation::CAdd, NewVTy);
1199 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1204 if (CompositeNode *CN = identifyAdd(Real, Imag))
1208 if (HasCMulSupport && HasCAddSupport) {
1209 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1215 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1219 CachedResult[Vals] =
nullptr;
1223ComplexDeinterleavingGraph::CompositeNode *
1224ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1225 Instruction *Imag) {
1226 auto IsOperationSupported = [](
unsigned Opcode) ->
bool {
1227 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1228 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1229 Opcode == Instruction::Sub;
1232 if (!IsOperationSupported(Real->
getOpcode()) ||
1233 !IsOperationSupported(Imag->
getOpcode()))
1236 std::optional<FastMathFlags>
Flags;
1239 LLVM_DEBUG(
dbgs() <<
"The flags in Real and Imaginary instructions are "
1245 if (!
Flags->allowReassoc()) {
1248 <<
"the 'Reassoc' attribute is missing in the FastMath flags\n");
1257 AddendList &Addends) ->
bool {
1259 SmallPtrSet<Value *, 8> Visited;
1260 while (!Worklist.
empty()) {
1262 if (!Visited.
insert(V).second)
1267 Addends.emplace_back(V, IsPositive);
1277 if (
I != Insn &&
I->hasNUsesOrMore(2)) {
1278 LLVM_DEBUG(
dbgs() <<
"Found potential sub-expression: " << *
I <<
"\n");
1279 Addends.emplace_back(
I, IsPositive);
1282 switch (
I->getOpcode()) {
1283 case Instruction::FAdd:
1284 case Instruction::Add:
1288 case Instruction::FSub:
1292 case Instruction::Sub:
1300 case Instruction::FMul:
1301 case Instruction::Mul: {
1303 if (
isNeg(
I->getOperand(0))) {
1305 IsPositive = !IsPositive;
1307 A =
I->getOperand(0);
1310 if (
isNeg(
I->getOperand(1))) {
1312 IsPositive = !IsPositive;
1314 B =
I->getOperand(1);
1316 Muls.push_back(Product{
A,
B, IsPositive});
1319 case Instruction::FNeg:
1323 Addends.emplace_back(
I, IsPositive);
1327 if (Flags &&
I->getFastMathFlags() != *Flags) {
1329 "inconsistent with the root instructions' flags: "
1338 AddendList RealAddends, ImagAddends;
1339 if (!Collect(Real, RealMuls, RealAddends) ||
1340 !Collect(Imag, ImagMuls, ImagAddends))
1343 if (RealAddends.size() != ImagAddends.size())
1346 CompositeNode *FinalNode =
nullptr;
1347 if (!RealMuls.
empty() || !ImagMuls.
empty()) {
1350 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1351 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1357 if (!RealAddends.empty() || !ImagAddends.empty()) {
1358 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1362 assert(FinalNode &&
"FinalNode can not be nullptr here");
1363 assert(FinalNode->Vals.size() == 1);
1365 FinalNode->Vals[0].Real = Real;
1366 FinalNode->Vals[0].Imag = Imag;
1367 submitCompositeNode(FinalNode);
1371bool ComplexDeinterleavingGraph::collectPartialMuls(
1373 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1375 auto FindCommonInstruction = [](
const Product &Real,
1376 const Product &Imag) ->
Value * {
1377 if (Real.Multiplicand == Imag.Multiplicand ||
1378 Real.Multiplicand == Imag.Multiplier)
1379 return Real.Multiplicand;
1381 if (Real.Multiplier == Imag.Multiplicand ||
1382 Real.Multiplier == Imag.Multiplier)
1383 return Real.Multiplier;
1392 for (
unsigned i = 0; i < RealMuls.
size(); ++i) {
1393 bool FoundCommon =
false;
1394 for (
unsigned j = 0;
j < ImagMuls.
size(); ++
j) {
1395 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1399 auto *
A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1400 : RealMuls[i].Multiplicand;
1401 auto *
B = ImagMuls[
j].Multiplicand == Common ? ImagMuls[
j].Multiplier
1402 : ImagMuls[
j].Multiplicand;
1404 auto Node = identifyNode(
A,
B);
1410 Node = identifyNode(
B,
A);
1422ComplexDeinterleavingGraph::CompositeNode *
1423ComplexDeinterleavingGraph::identifyMultiplications(
1424 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1426 if (RealMuls.
size() != ImagMuls.
size())
1430 if (!collectPartialMuls(RealMuls, ImagMuls,
Info))
1434 DenseMap<Value *, CompositeNode *> CommonToNode;
1435 SmallVector<bool> Processed(
Info.size(),
false);
1436 for (
unsigned I = 0;
I <
Info.size(); ++
I) {
1440 PartialMulCandidate &InfoA =
Info[
I];
1441 for (
unsigned J =
I + 1; J <
Info.size(); ++J) {
1445 PartialMulCandidate &InfoB =
Info[J];
1446 auto *InfoReal = &InfoA;
1447 auto *InfoImag = &InfoB;
1449 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1450 if (!NodeFromCommon) {
1452 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1454 if (!NodeFromCommon)
1457 CommonToNode[InfoReal->Common] = NodeFromCommon;
1458 CommonToNode[InfoImag->Common] = NodeFromCommon;
1459 Processed[
I] =
true;
1460 Processed[J] =
true;
1464 SmallVector<bool> ProcessedReal(RealMuls.
size(),
false);
1465 SmallVector<bool> ProcessedImag(ImagMuls.
size(),
false);
1467 for (
auto &PMI :
Info) {
1468 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1471 auto It = CommonToNode.
find(PMI.Common);
1474 if (It == CommonToNode.
end()) {
1476 dbgs() <<
"Unprocessed independent partial multiplication:\n";
1477 for (
auto *
Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1479 <<
" multiplied by " << *
Mul->Multiplicand <<
"\n";
1484 auto &RealMul = RealMuls[PMI.RealIdx];
1485 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1487 auto NodeA = It->second;
1488 auto NodeB = PMI.Node;
1489 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1504 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1505 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1510 if (IsMultiplicandReal) {
1512 if (RealMul.IsPositive && ImagMul.IsPositive)
1514 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1521 if (!RealMul.IsPositive && ImagMul.IsPositive)
1523 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1530 dbgs() <<
"Identified partial multiplication (X, Y) * (U, V):\n";
1531 dbgs().
indent(4) <<
"X: " << *NodeA->Vals[0].Real <<
"\n";
1532 dbgs().
indent(4) <<
"Y: " << *NodeA->Vals[0].Imag <<
"\n";
1533 dbgs().
indent(4) <<
"U: " << *NodeB->Vals[0].Real <<
"\n";
1534 dbgs().
indent(4) <<
"V: " << *NodeB->Vals[0].Imag <<
"\n";
1535 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1538 CompositeNode *NodeMul = prepareCompositeNode(
1539 ComplexDeinterleavingOperation::CMulPartial,
nullptr,
nullptr);
1540 NodeMul->Rotation = Rotation;
1541 NodeMul->addOperand(NodeA);
1542 NodeMul->addOperand(NodeB);
1544 NodeMul->addOperand(Result);
1545 submitCompositeNode(NodeMul);
1547 ProcessedReal[PMI.RealIdx] =
true;
1548 ProcessedImag[PMI.ImagIdx] =
true;
1552 if (!
all_of(ProcessedReal, [](
bool V) {
return V; }) ||
1553 !
all_of(ProcessedImag, [](
bool V) {
return V; })) {
1558 dbgs() <<
"Unprocessed products (Real):\n";
1559 for (
size_t i = 0; i < ProcessedReal.size(); ++i) {
1560 if (!ProcessedReal[i])
1561 dbgs().
indent(4) << (RealMuls[i].IsPositive ?
"+" :
"-")
1562 << *RealMuls[i].Multiplier <<
" multiplied by "
1563 << *RealMuls[i].Multiplicand <<
"\n";
1565 dbgs() <<
"Unprocessed products (Imag):\n";
1566 for (
size_t i = 0; i < ProcessedImag.size(); ++i) {
1567 if (!ProcessedImag[i])
1568 dbgs().
indent(4) << (ImagMuls[i].IsPositive ?
"+" :
"-")
1569 << *ImagMuls[i].Multiplier <<
" multiplied by "
1570 << *ImagMuls[i].Multiplicand <<
"\n";
1579ComplexDeinterleavingGraph::CompositeNode *
1580ComplexDeinterleavingGraph::identifyAdditions(
1581 AddendList &RealAddends, AddendList &ImagAddends,
1582 std::optional<FastMathFlags> Flags, CompositeNode *
Accumulator =
nullptr) {
1583 if (RealAddends.size() != ImagAddends.size())
1586 CompositeNode *
Result =
nullptr;
1592 Result = extractPositiveAddend(RealAddends, ImagAddends);
1597 while (!RealAddends.empty()) {
1598 auto ItR = RealAddends.begin();
1599 auto [
R, IsPositiveR] = *ItR;
1601 bool FoundImag =
false;
1602 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1603 auto [
I, IsPositiveI] = *ItI;
1605 if (IsPositiveR && IsPositiveI)
1606 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1607 else if (!IsPositiveR && IsPositiveI)
1608 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1609 else if (!IsPositiveR && !IsPositiveI)
1610 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1612 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1614 CompositeNode *AddNode =
nullptr;
1615 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1616 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1617 AddNode = identifyNode(R,
I);
1619 AddNode = identifyNode(
I, R);
1623 dbgs() <<
"Identified addition:\n";
1626 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1629 CompositeNode *TmpNode =
nullptr;
1631 TmpNode = prepareCompositeNode(
1632 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1634 TmpNode->Opcode = Instruction::FAdd;
1635 TmpNode->Flags = *
Flags;
1637 TmpNode->Opcode = Instruction::Add;
1639 }
else if (Rotation ==
1641 TmpNode = prepareCompositeNode(
1642 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1644 TmpNode->Opcode = Instruction::FSub;
1645 TmpNode->Flags = *
Flags;
1647 TmpNode->Opcode = Instruction::Sub;
1650 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1652 TmpNode->Rotation = Rotation;
1655 TmpNode->addOperand(Result);
1656 TmpNode->addOperand(AddNode);
1657 submitCompositeNode(TmpNode);
1659 RealAddends.erase(ItR);
1660 ImagAddends.erase(ItI);
1671ComplexDeinterleavingGraph::CompositeNode *
1672ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1673 AddendList &ImagAddends) {
1674 for (
auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1675 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1676 auto [
R, IsPositiveR] = *ItR;
1677 auto [
I, IsPositiveI] = *ItI;
1678 if (IsPositiveR && IsPositiveI) {
1679 auto Result = identifyNode(R,
I);
1681 RealAddends.erase(ItR);
1682 ImagAddends.erase(ItI);
1691bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1696 auto It = RootToNode.
find(RootI);
1697 if (It != RootToNode.
end()) {
1698 auto RootNode = It->second;
1699 assert(RootNode->Operation ==
1700 ComplexDeinterleavingOperation::ReductionOperation ||
1701 RootNode->Operation ==
1702 ComplexDeinterleavingOperation::ReductionSingle);
1703 assert(RootNode->Vals.size() == 1 &&
1704 "Cannot handle reductions involving multiple complex values");
1713 ReplacementAnchor =
R->comesBefore(
I) ?
I :
R;
1715 ReplacementAnchor =
R;
1717 if (ReplacementAnchor != RootI)
1723 auto RootNode = identifyRoot(RootI);
1730 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
1731 <<
"::" <<
B->getName() <<
".\n";
1735 RootToNode[RootI] = RootNode;
1740bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *
B) {
1741 bool FoundPotentialReduction =
false;
1746 if (!Br || Br->getNumSuccessors() != 2)
1750 if (Br->getSuccessor(0) !=
B && Br->getSuccessor(1) !=
B)
1753 for (
auto &
PHI :
B->phis()) {
1754 if (
PHI.getNumIncomingValues() != 2)
1757 if (!
PHI.getType()->isVectorTy())
1767 for (
auto *U : ReductionOp->users()) {
1774 if (NumUsers != 2 || !FinalReduction || FinalReduction->
getParent() ==
B ||
1778 ReductionInfo[ReductionOp] = {&
PHI, FinalReduction};
1780 auto BackEdgeIdx =
PHI.getBasicBlockIndex(
B);
1781 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1782 Incoming =
PHI.getIncomingBlock(IncomingIdx);
1783 FoundPotentialReduction =
true;
1789 FinalInstructions.
insert(InitPHI);
1791 return FoundPotentialReduction;
1794void ComplexDeinterleavingGraph::identifyReductionNodes() {
1795 assert(Factor == 2 &&
"Cannot handle multiple complex values");
1797 SmallVector<bool> Processed(ReductionInfo.
size(),
false);
1799 for (
auto &
P : ReductionInfo)
1804 for (
size_t i = 0; i < OperationInstruction.
size(); ++i) {
1807 for (
size_t j = i + 1;
j < OperationInstruction.
size(); ++
j) {
1810 auto *Real = OperationInstruction[i];
1811 auto *Imag = OperationInstruction[
j];
1812 if (Real->getType() != Imag->
getType())
1815 RealPHI = ReductionInfo[Real].first;
1816 ImagPHI = ReductionInfo[Imag].first;
1818 auto Node = identifyNode(Real, Imag);
1822 Node = identifyNode(Real, Imag);
1828 if (Node && PHIsFound) {
1829 LLVM_DEBUG(
dbgs() <<
"Identified reduction starting from instructions: "
1830 << *Real <<
" / " << *Imag <<
"\n");
1831 Processed[i] =
true;
1832 Processed[
j] =
true;
1833 auto RootNode = prepareCompositeNode(
1834 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1835 RootNode->addOperand(Node);
1836 RootToNode[Real] = RootNode;
1837 RootToNode[Imag] = RootNode;
1838 submitCompositeNode(RootNode);
1843 auto *Real = OperationInstruction[i];
1846 if (Processed[i] || Real->getNumOperands() < 2)
1850 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1853 RealPHI = ReductionInfo[Real].first;
1856 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1857 if (Node && PHIsFound) {
1859 dbgs() <<
"Identified single reduction starting from instruction: "
1860 << *Real <<
"/" << *ReductionInfo[Real].second <<
"\n");
1869 if (ReductionInfo[Real].second->getType()->isVectorTy())
1872 Processed[i] =
true;
1873 auto RootNode = prepareCompositeNode(
1874 ComplexDeinterleavingOperation::ReductionSingle, Real,
nullptr);
1875 RootNode->addOperand(Node);
1876 RootToNode[Real] = RootNode;
1877 submitCompositeNode(RootNode);
1885bool ComplexDeinterleavingGraph::checkNodes() {
1886 bool FoundDeinterleaveNode =
false;
1887 for (CompositeNode *
N : CompositeNodes) {
1888 if (!
N->areOperandsValid())
1891 if (
N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1892 FoundDeinterleaveNode =
true;
1897 if (!FoundDeinterleaveNode) {
1899 dbgs() <<
"Couldn't find a deinterleave node within the graph, cannot "
1900 "guarantee safety during graph transformation.\n");
1905 SmallPtrSet<Instruction *, 16> AllInstructions;
1906 SmallVector<Instruction *, 8> Worklist;
1907 for (
auto &Pair : RootToNode)
1912 while (!Worklist.
empty()) {
1915 if (!AllInstructions.
insert(
I).second)
1920 if (!FinalInstructions.
count(
I))
1927 for (
auto *
I : AllInstructions) {
1929 if (RootToNode.count(
I))
1932 for (User *U :
I->users()) {
1944 SmallPtrSet<Instruction *, 16> Visited;
1945 while (!Worklist.
empty()) {
1947 if (!Visited.
insert(
I).second)
1952 if (RootToNode.count(
I)) {
1954 <<
" could be deinterleaved but its chain of complex "
1955 "operations have an outside user\n");
1956 RootToNode.erase(
I);
1959 if (!AllInstructions.count(
I) || FinalInstructions.
count(
I))
1962 for (User *U :
I->users())
1970 return !RootToNode.
empty();
1973ComplexDeinterleavingGraph::CompositeNode *
1974ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1981 for (
unsigned I = 0;
I < Factor;
I += 2) {
1989 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
2017 return identifyNode(Real, Imag);
2020ComplexDeinterleavingGraph::CompositeNode *
2021ComplexDeinterleavingGraph::identifyDeinterleave(
ComplexValues &Vals) {
2025 auto CheckExtract = [&](
Value *
V,
unsigned ExpectedIdx,
2026 Instruction *ExpectedInsn) -> ExtractValueInst * {
2028 if (!EVI || EVI->getNumIndices() != 1 ||
2029 EVI->getIndices()[0] != ExpectedIdx ||
2031 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2036 for (
unsigned Idx = 0; Idx < Vals.
size(); Idx++) {
2037 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2,
II);
2038 if (RealEVI && Idx == 0)
2040 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1,
II)) {
2047 if (IntrinsicII->getIntrinsicID() !=
2052 CompositeNode *PlaceholderNode = prepareCompositeNode(
2054 PlaceholderNode->ReplacementNode =
II->getOperand(0);
2055 for (
auto &V : Vals) {
2059 return submitCompositeNode(PlaceholderNode);
2062 if (Vals.size() != 1)
2065 Value *Real = Vals[0].Real;
2066 Value *Imag = Vals[0].Imag;
2069 if (!RealShuffle || !ImagShuffle) {
2070 if (RealShuffle || ImagShuffle)
2071 LLVM_DEBUG(
dbgs() <<
" - There's a shuffle where there shouldn't be.\n");
2075 Value *RealOp1 = RealShuffle->getOperand(1);
2080 Value *ImagOp1 = ImagShuffle->getOperand(1);
2086 Value *RealOp0 = RealShuffle->getOperand(0);
2087 Value *ImagOp0 = ImagShuffle->getOperand(0);
2089 if (RealOp0 != ImagOp0) {
2094 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2095 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2101 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2102 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
2108 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2109 Value *
Op = Shuffle->getOperand(0);
2113 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2115 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2121 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) ->
bool {
2125 ArrayRef<int>
Mask = Shuffle->getShuffleMask();
2128 Value *
Op = Shuffle->getOperand(0);
2130 int NumElements = OpTy->getNumElements();
2134 return Last < NumElements;
2137 if (RealShuffle->getType() != ImagShuffle->getType()) {
2141 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2145 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2150 CompositeNode *PlaceholderNode =
2152 RealShuffle, ImagShuffle);
2153 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2154 FinalInstructions.
insert(RealShuffle);
2155 FinalInstructions.
insert(ImagShuffle);
2156 return submitCompositeNode(PlaceholderNode);
2159ComplexDeinterleavingGraph::CompositeNode *
2160ComplexDeinterleavingGraph::identifySplat(
ComplexValues &Vals) {
2161 auto IsSplat = [](
Value *
V) ->
bool {
2174 if (
Const->getOpcode() != Instruction::ShuffleVector)
2179 VTy = Shuf->getType();
2180 Mask = Shuf->getShuffleMask();
2188 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2198 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2199 for (
auto &V : Vals) {
2200 if (!IsSplat(
V.Real) || !IsSplat(
V.Imag))
2205 if (!Real || !Imag || Real->getParent() != FirstBB ||
2206 Imag->getParent() != FirstBB)
2210 for (
auto &V : Vals) {
2217 for (
auto &V : Vals) {
2221 FinalInstructions.
insert(Real);
2222 FinalInstructions.
insert(Imag);
2225 CompositeNode *PlaceholderNode =
2226 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2227 return submitCompositeNode(PlaceholderNode);
2230ComplexDeinterleavingGraph::CompositeNode *
2231ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2232 Instruction *Imag) {
2233 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2237 CompositeNode *PlaceholderNode = prepareCompositeNode(
2238 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2239 return submitCompositeNode(PlaceholderNode);
2242ComplexDeinterleavingGraph::CompositeNode *
2243ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2244 Instruction *Imag) {
2247 if (!SelectReal || !SelectImag)
2264 auto NodeA = identifyNode(AR, AI);
2268 auto NodeB = identifyNode(
RA, BI);
2272 CompositeNode *PlaceholderNode = prepareCompositeNode(
2273 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2274 PlaceholderNode->addOperand(NodeA);
2275 PlaceholderNode->addOperand(NodeB);
2276 FinalInstructions.
insert(MaskA);
2277 FinalInstructions.
insert(MaskB);
2278 return submitCompositeNode(PlaceholderNode);
2282 std::optional<FastMathFlags> Flags,
2286 case Instruction::FNeg:
2287 I =
B.CreateFNeg(InputA);
2289 case Instruction::FAdd:
2290 I =
B.CreateFAdd(InputA, InputB);
2292 case Instruction::Add:
2293 I =
B.CreateAdd(InputA, InputB);
2295 case Instruction::FSub:
2296 I =
B.CreateFSub(InputA, InputB);
2298 case Instruction::Sub:
2299 I =
B.CreateSub(InputA, InputB);
2301 case Instruction::FMul:
2302 I =
B.CreateFMul(InputA, InputB);
2304 case Instruction::Mul:
2305 I =
B.CreateMul(InputA, InputB);
2315Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2316 CompositeNode *Node) {
2317 if (
Node->ReplacementNode)
2318 return Node->ReplacementNode;
2320 auto ReplaceOperandIfExist = [&](CompositeNode *
Node,
2321 unsigned Idx) ->
Value * {
2322 return Node->Operands.size() > Idx
2323 ? replaceNode(Builder,
Node->Operands[Idx])
2327 Value *ReplacementNode =
nullptr;
2328 switch (
Node->Operation) {
2329 case ComplexDeinterleavingOperation::CDot: {
2330 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2331 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2334 "Node inputs need to be of the same type"));
2339 case ComplexDeinterleavingOperation::CAdd:
2340 case ComplexDeinterleavingOperation::CMulPartial:
2341 case ComplexDeinterleavingOperation::Symmetric: {
2342 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2343 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2346 "Node inputs need to be of the same type"));
2349 "Accumulator and input need to be of the same type"));
2350 if (
Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2355 Builder,
Node->Operation,
Node->Rotation, Input0, Input1,
2359 case ComplexDeinterleavingOperation::Deinterleave:
2362 case ComplexDeinterleavingOperation::Splat: {
2364 for (
auto &V :
Node->Vals) {
2365 Ops.push_back(
V.Real);
2366 Ops.push_back(
V.Imag);
2373 for (
auto V :
Node->Vals) {
2381 ReplacementNode = IRB.CreateVectorInterleave(
Ops);
2387 case ComplexDeinterleavingOperation::ReductionPHI: {
2392 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2394 OldToNewPHI[OldPHI] = NewPHI;
2395 ReplacementNode = NewPHI;
2398 case ComplexDeinterleavingOperation::ReductionSingle:
2399 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2400 processReductionSingle(ReplacementNode, Node);
2402 case ComplexDeinterleavingOperation::ReductionOperation:
2403 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2404 processReductionOperation(ReplacementNode, Node);
2406 case ComplexDeinterleavingOperation::ReductionSelect: {
2409 auto *
A = replaceNode(Builder,
Node->Operands[0]);
2410 auto *
B = replaceNode(Builder,
Node->Operands[1]);
2417 assert(ReplacementNode &&
"Target failed to create Intrinsic call.");
2418 NumComplexTransformations += 1;
2419 Node->ReplacementNode = ReplacementNode;
2420 return ReplacementNode;
2423void ComplexDeinterleavingGraph::processReductionSingle(
2424 Value *OperationReplacement, CompositeNode *Node) {
2426 auto *OldPHI = ReductionInfo[Real].first;
2427 auto *NewPHI = OldToNewPHI[OldPHI];
2429 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2431 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2435 Value *NewInit =
nullptr;
2437 if (
C->isZeroValue())
2445 NewPHI->addIncoming(NewInit, Incoming);
2446 NewPHI->addIncoming(OperationReplacement, BackEdge);
2448 auto *FinalReduction = ReductionInfo[Real].second;
2455void ComplexDeinterleavingGraph::processReductionOperation(
2456 Value *OperationReplacement, CompositeNode *Node) {
2459 auto *OldPHIReal = ReductionInfo[Real].first;
2460 auto *OldPHIImag = ReductionInfo[Imag].first;
2461 auto *NewPHI = OldToNewPHI[OldPHIReal];
2464 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2465 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2470 NewPHI->addIncoming(NewInit, Incoming);
2471 NewPHI->addIncoming(OperationReplacement, BackEdge);
2475 auto *FinalReductionReal = ReductionInfo[Real].second;
2476 auto *FinalReductionImag = ReductionInfo[Imag].second;
2479 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2481 OperationReplacement->
getType(),
2482 OperationReplacement);
2485 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2489 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2492void ComplexDeinterleavingGraph::replaceNodes() {
2493 SmallVector<Instruction *, 16> DeadInstrRoots;
2494 for (
auto *RootInstruction : OrderedRoots) {
2497 if (!RootToNode.count(RootInstruction))
2501 auto RootNode = RootToNode[RootInstruction];
2502 Value *
R = replaceNode(Builder, RootNode);
2504 if (RootNode->Operation ==
2505 ComplexDeinterleavingOperation::ReductionOperation) {
2508 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2509 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2512 }
else if (RootNode->Operation ==
2513 ComplexDeinterleavingOperation::ReductionSingle) {
2515 auto &
Info = ReductionInfo[RootInst];
2516 Info.first->removeIncomingValue(BackEdge);
2519 assert(R &&
"Unable to find replacement for RootInstruction");
2520 DeadInstrRoots.
push_back(RootInstruction);
2521 RootInstruction->replaceAllUsesWith(R);
2525 for (
auto *
I : DeadInstrRoots)
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Analysis containing CSE Info
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
This file implements a map that provides insertion order iteration.
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
This file describes how to lower LLVM code to machine code.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
bool allowContract() const
FunctionPass class - This class is used to implement most global optimizations.
Common base class shared among various IRBuilders.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
An opaque object representing a hash code.
const ParentTy * getParent() const
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
@ BasicBlock
Various leaf nodes.
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock.
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
ComplexDeinterleavingOperation
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
LLVM_ABI void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
ComplexDeinterleavingRotation
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
ComplexDeinterleavingPass(TargetMachine *TM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static ComplexValue getEmptyKey()
static unsigned getHashValue(const ComplexValue &Val)
static ComplexValue getTombstoneKey()
An information struct used to provide DenseMap with the various necessary components for a given valu...