#define DEBUG_TYPE "scalarizer"
    cl::desc("Allow the scalarizer pass to scalarize "
             "insertelement/extractelement with variable index"));
    cl::desc("Allow the scalarizer pass to scalarize loads and stores"));
  if (isa<PHINode>(Itr))
    Itr = BB->getFirstInsertionPt();
using ScatterMap = std::map<Value *, ValueVector>;
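
// Scatterer holds the per-element pieces of a vector (or pointer-to-vector)
// value and materializes each element lazily the first time it is requested
// through operator[]. Results are cached in CachePtr when one is supplied.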
  Scatterer() = default;

            ValueVector *cachePtr = nullptr);

  Value *operator[](unsigned I);

  unsigned size() const { return Size; }

  ValueVector *CachePtr;
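
// Splitter functors: each one re-creates a single scalar lane of the original
// vector instruction (fcmp, icmp, unary op, binary op) and is handed to
// splitUnary/splitBinary below.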
struct FCmpSplitter {
  FCmpSplitter(FCmpInst &fci) : FCI(fci) {}

    return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
struct ICmpSplitter {
  ICmpSplitter(ICmpInst &ici) : ICI(ici) {}

    return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
struct UnarySplitter {
struct BinarySplitter {

    return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
  Align getElemAlign(unsigned I) {

  Type *ElemTy = nullptr;
template <typename T>
T getWithDefaultOverride(const cl::opt<T> &ClOption,
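
// ScalarizerVisitor does the real work: it walks every instruction in a
// function, scatters vector operands into scalar pieces, re-emits the
// operation once per element, and gathers the pieces back wherever a vector
// value is still needed.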
class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
  ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT,
      : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT),
        ScalarizeVariableInsertExtract(
                Options.ScalarizeVariableInsertExtract)),
  bool visitPHINode(PHINode &PHI);

  bool canTransferMetadata(unsigned Kind);
  void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
  template<typename T> bool splitUnary(Instruction &, const T &);
  template<typename T> bool splitBinary(Instruction &, const T &);
  ScatterMap Scattered;

  unsigned ParallelLoopAccessMDKind;

  const bool ScalarizeVariableInsertExtract;
  const bool ScalarizeLoadStore;
276 "Scalarize vector operations",
false,
false)
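
// Scatterer constructor: record where the scalar pieces should be inserted
// and how many elements the vector (or pointed-to vector) value has, reusing
// the caller-provided cache vector when there is one.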
                     Type *PtrElemTy, ValueVector *cachePtr)
    : BB(bb), BBI(bbi), V(v), PtrElemTy(PtrElemTy), CachePtr(cachePtr) {
  Type *Ty = V->getType();

  assert(cast<PointerType>(Ty)->isOpaqueOrPointeeTypeMatches(PtrElemTy) &&
         "Pointer element type mismatch");

  Size = cast<FixedVectorType>(Ty)->getNumElements();

    Tmp.resize(Size, nullptr);
  else if (CachePtr->empty())
    CachePtr->resize(Size, nullptr);

    assert(Size == CachePtr->size() && "Inconsistent vector sizes");
Value *Scatterer::operator[](unsigned I) {
  ValueVector &CV = (CachePtr ? *CachePtr : Tmp);

    Type *VectorElemTy = cast<VectorType>(PtrElemTy)->getElementType();

        VectorElemTy, V->getType()->getPointerAddressSpace());
    CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0");

      CV[I] = Builder.CreateConstGEP1_32(VectorElemTy, CV[0], I,
                                         V->getName() + ".i" + Twine(I));
      V = Insert->getOperand(0);

        CV[J] = Insert->getOperand(1);

        CV[J] = Insert->getOperand(1);

                                         V->getName() + ".i" + Twine(I));
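
// Legacy pass driver: look up the metadata kind used for parallel-loop
// annotations, fetch the dominator tree, and run the visitor over the
// function.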
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  return Impl.visit(F);
  return new ScalarizerLegacyPass();
bool ScalarizerVisitor::visit(Function &F) {
  assert(Gathered.empty() && Scattered.empty());

      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
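
// scatter() returns (and caches) the scalarized form of value V, placing the
// extraction code at a point appropriate to V's kind: the entry block for
// arguments, otherwise near the point of use.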
  if (Argument *VArg = dyn_cast<Argument>(V)) {

    return Scatterer(BB, BB->begin(), V, PtrElemTy, &Scattered[V]);

                   PtrElemTy, &Scattered[V]);
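
// gather() records the scalar pieces CV that now represent Op so finish()
// can later rebuild a vector for any remaining users, replaces any previously
// scattered pieces of Op with the new values, and marks superseded
// instructions as potentially dead.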
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
  transferMetadataAndIRFlags(Op, CV);

  ValueVector &SV = Scattered[Op];

    for (unsigned I = 0, E = SV.size(); I != E; ++I) {

      if (V == nullptr || SV[I] == CV[I])

      if (isa<Instruction>(CV[I]))
        CV[I]->takeName(Old);

      PotentiallyDeadInstrs.emplace_back(Old);

  Gathered.push_back(GatherList::value_type(Op, &SV));
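
// Only a whitelist of metadata kinds is safe to copy from a vector
// instruction onto its scalar replacements.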
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
  return (Tag == LLVMContext::MD_tbaa
          || Tag == LLVMContext::MD_fpmath
          || Tag == LLVMContext::MD_tbaa_struct
          || Tag == LLVMContext::MD_invariant_load
          || Tag == LLVMContext::MD_alias_scope
          || Tag == LLVMContext::MD_noalias
          || Tag == ParallelLoopAccessMDKind
          || Tag == LLVMContext::MD_access_group);
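
// Copy transferable metadata, IR flags (nsw/nuw, fast-math, etc.) and the
// debug location from Op onto every scalar instruction in CV.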
void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
                                                   const ValueVector &CV) {
  Op->getAllMetadataOtherThanDebugLoc(MDs);
  for (unsigned I = 0, E = CV.size(); I != E; ++I) {

      for (const auto &MD : MDs)
        if (canTransferMetadata(MD.first))
          New->setMetadata(MD.first, MD.second);
      New->copyIRFlags(Op);
      if (Op->getDebugLoc() && !New->getDebugLoc())
        New->setDebugLoc(Op->getDebugLoc());
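
// Compute the in-memory layout of a vector type: element type and store
// size. Gives up (returns no layout) when the element's bit size differs
// from its store size, since the scalarized accesses would then not match
// the vector's memory layout.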
ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,

  Layout.VecTy = dyn_cast<VectorType>(Ty);

  Layout.ElemTy = Layout.VecTy->getElementType();
  if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy))

  Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy);
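
// Scalarize a one-operand instruction: scatter the single operand and call
// the Splitter once per element to build the scalar replacements.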
template<typename Splitter>
bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
  VectorType *VT = dyn_cast<VectorType>(I.getType());

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  Scatterer Op = scatter(&I, I.getOperand(0));
  assert(Op.size() == NumElems && "Mismatched unary operation");

  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem)
    Res[Elem] = Split(Builder, Op[Elem], I.getName() + ".i" + Twine(Elem));
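
// Scalarize a two-operand instruction: scatter both operands and call the
// Splitter once per element pair.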
template<typename Splitter>
bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
  VectorType *VT = dyn_cast<VectorType>(I.getType());

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  Scatterer VOp0 = scatter(&I, I.getOperand(0));
  Scatterer VOp1 = scatter(&I, I.getOperand(1));
  assert(VOp0.size() == NumElems && "Mismatched binary operation");
  assert(VOp1.size() == NumElems && "Mismatched binary operation");

  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    Value *Op0 = VOp0[Elem];
    Value *Op1 = VOp1[Elem];
    Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem));
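
// Scalarize an intrinsic call: vector operands are scattered, operands that
// must stay scalar are passed through unchanged, and the scalar form of the
// intrinsic is called once per element.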
bool ScalarizerVisitor::splitCall(CallInst &CI) {

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  ValueVector ScalarOperands(NumArgs);

  Scattered.resize(NumArgs);

  for (unsigned I = 0; I != NumArgs; ++I) {

      Scattered[I] = scatter(&CI, OpI);
      assert(Scattered[I].size() == NumElems && "mismatched call operands");

      ScalarOperands[I] = OpI;

  ValueVector Res(NumElems);
  ValueVector ScalarCallOps(NumArgs);

  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    ScalarCallOps.clear();

    for (unsigned J = 0; J != NumArgs; ++J) {
        ScalarCallOps.push_back(ScalarOperands[J]);
        ScalarCallOps.push_back(Scattered[J][Elem]);

    Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps,
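
// Selects are split per element; the condition is scattered too when it is
// a vector, otherwise the same scalar condition is reused for every lane.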
bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  Scatterer VOp1 = scatter(&SI, SI.getOperand(1));
  Scatterer VOp2 = scatter(&SI, SI.getOperand(2));
  assert(VOp1.size() == NumElems && "Mismatched select");
  assert(VOp2.size() == NumElems && "Mismatched select");

  Res.resize(NumElems);

  if (SI.getOperand(0)->getType()->isVectorTy()) {
    Scatterer VOp0 = scatter(&SI, SI.getOperand(0));
    assert(VOp0.size() == NumElems && "Mismatched select");
    for (unsigned I = 0; I < NumElems; ++I) {

      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,

    for (unsigned I = 0; I < NumElems; ++I) {

      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
  return splitBinary(ICI, ICmpSplitter(ICI));

bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
  return splitBinary(FCI, FCmpSplitter(FCI));

bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
  return splitUnary(UO, UnarySplitter(UO));

  return splitBinary(BO, BinarySplitter(BO));
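
// Vector GEPs: splat any scalar operands to the vector width, scatter the
// pointer and index operands, and emit one scalar GEP per element,
// preserving the inbounds flag.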
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

    Op0 = Builder.CreateVectorSplat(NumElems, Op0);
  Scatterer Base = scatter(&GEPI, Op0);

  for (unsigned I = 0; I < NumIndices; ++I) {

    if (!Op->getType()->isVectorTy())

    Ops[I] = scatter(&GEPI, Op);

  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {

    Indices.resize(NumIndices);
    for (unsigned J = 0; J < NumIndices; ++J)
      Indices[J] = Ops[J][I];

      NewGEPI->setIsInBounds();
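
// Casts between vector types become one scalar cast per element.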
bool ScalarizerVisitor::visitCastInst(CastInst &CI) {

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  Scatterer Op0 = scatter(&CI, CI.getOperand(0));
  assert(Op0.size() == NumElems && "Mismatched cast");

  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I)
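
// Bitcasts between vector types of different element counts either fan out
// (one source element expands into several destination elements) or fan in
// (several source elements are packed back into a temporary vector with
// insertelement before being recast).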
bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {

  if (!DstVT || !SrcVT)

  unsigned DstNumElems = cast<FixedVectorType>(DstVT)->getNumElements();
  unsigned SrcNumElems = cast<FixedVectorType>(SrcVT)->getNumElements();

  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));

  Res.resize(DstNumElems);

  if (DstNumElems == SrcNumElems) {
    for (unsigned I = 0; I < DstNumElems; ++I)

  } else if (DstNumElems > SrcNumElems) {

    unsigned FanOut = DstNumElems / SrcNumElems;

    for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
      Value *V = Op0[Op0I];

      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);

      Scatterer Mid = scatter(&BCI, V);
      for (unsigned MidI = 0; MidI < FanOut; ++MidI)
        Res[ResI++] = Mid[MidI];

    unsigned FanIn = SrcNumElems / DstNumElems;

    for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {

      for (unsigned MidI = 0; MidI < FanIn; ++MidI)
        V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI),

                                            + ".upto" + Twine(MidI));
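
// insertelement: with a constant index the new scalar simply replaces that
// lane; with a variable index (and the option enabled) every lane becomes a
// select between the old and the new value.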
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0));

  Res.resize(NumElems);

  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
    for (unsigned I = 0; I < NumElems; ++I)
      Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I];

    if (!ScalarizeVariableInsertExtract)

    for (unsigned I = 0; I < NumElems; ++I) {
      Value *ShouldReplace =

      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
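
// extractelement: a constant index picks the scattered element directly; a
// variable index (when allowed) is lowered to a chain of selects over all
// source elements.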
  unsigned NumSrcElems = cast<FixedVectorType>(VT)->getNumElements();

  Scatterer Op0 = scatter(&EEI, EEI.getOperand(0));

  if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
    Value *Res = Op0[CI->getValue().getZExtValue()];

    if (!ScalarizeVariableInsertExtract)

    for (unsigned I = 0; I < NumSrcElems; ++I) {
      Value *ShouldExtract =

      Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
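
// shufflevector: each result element is taken straight from the scattered
// pieces of whichever input operand the shuffle mask selects.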
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
  Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));

  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I) {

    else if (unsigned(Selector) < Op0.size())
      Res[I] = Op0[Selector];

      Res[I] = Op1[Selector - Op0.size()];
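
// Vector PHIs become one scalar PHI per element, each receiving the
// corresponding scattered value from every incoming block.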
bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I)

  for (unsigned I = 0; I < NumOps; ++I) {

    for (unsigned J = 0; J < NumElems; ++J)
      cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
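
// Vector loads (when load/store scalarization is enabled and the element
// layout is known) become one aligned scalar load per element.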
bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  if (!ScalarizeLoadStore)

  unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();

  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreateAlignedLoad(Layout->VecTy->getElementType(), Ptr[I],
                                       Align(Layout->getElemAlign(I)),
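
// Vector stores are handled symmetrically: scatter the pointer and the value
// being stored, emit one aligned scalar store per element, then transfer
// metadata and IR flags from the original store.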
bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  if (!ScalarizeLoadStore)

  Value *FullValue = SI.getValueOperand();

      FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout());

  unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();

  Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), FullValue->getType());
  Scatterer VVal = scatter(&SI, FullValue);

  Stores.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {

    Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I));

  transferMetadataAndIRFlags(&SI, Stores);
bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
  return splitCall(CI);
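
// finish() runs after the visit pass: for every gathered instruction that
// still has users it rebuilds a vector with insertelement (or reuses the
// single scalar), replaces all uses, and queues the original for deletion.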
bool ScalarizerVisitor::finish() {

  if (Gathered.empty() && Scattered.empty())

  for (const auto &GMI : Gathered) {

    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {

      if (auto *Ty = dyn_cast<VectorType>(Op->getType())) {

        unsigned Count = cast<FixedVectorType>(Ty)->getNumElements();

        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
        for (unsigned I = 0; I < Count; ++I)

                                            Op->getName() + ".upto" + Twine(I));

        assert(CV.size() == 1 && Op->getType() == CV[0]->getType());

      Op->replaceAllUsesWith(Res);

    PotentiallyDeadInstrs.emplace_back(Op);
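
// New pass manager entry point: same setup as the legacy pass, then report
// whether anything changed.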
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");

  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, Options);
  bool Changed = Impl.visit(F);