#define DEBUG_TYPE "scalarizer"
    cl::desc("Allow the scalarizer pass to scalarize "
             "insertelement/extractelement with variable index"));
    cl::desc("Allow the scalarizer pass to scalarize loads and stores"));
using ScatterMap = std::map<Value *, ValueVector>;
  Scatterer() = default;

  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
            ValueVector *cachePtr = nullptr);

  Value *operator[](unsigned I);

  unsigned size() const { return Size; }

  ValueVector *CachePtr;
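// The *Splitter functors below are the callbacks handed to splitUnary and
// splitBinary: each one wraps a single original vector instruction and
// recreates it one element at a time through the supplied IRBuilder.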
struct FCmpSplitter {
  FCmpSplitter(FCmpInst &fci) : FCI(fci) {}
    return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
struct ICmpSplitter {
  ICmpSplitter(ICmpInst &ici) : ICI(ici) {}
    return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
struct UnarySplitter {
struct BinarySplitter {
    return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
  Align getElemAlign(unsigned I) {

  Type *ElemTy = nullptr;

  uint64_t ElemSize = 0;
class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
  ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT)
      : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT) {

  bool visitPHINode(PHINode &PHI);

  bool canTransferMetadata(unsigned Kind);
  void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);

  template<typename T> bool splitUnary(Instruction &, const T &);
  template<typename T> bool splitBinary(Instruction &, const T &);

  ScatterMap Scattered;

  unsigned ParallelLoopAccessMDKind;
253 "Scalarize vector operations",
false,
false)
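// Scatterer constructor: remember the value V being scattered together with
// its insertion point, work out the element count Size from V's vector (or
// pointer-to-vector) type, and size the Tmp / CachePtr component caches.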
                     ValueVector *cachePtr)
    : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) {
  Type *Ty = V->getType();
  PtrTy = dyn_cast<PointerType>(Ty);
  if (PtrTy)
    Ty = PtrTy->getElementType();
  Size = cast<FixedVectorType>(Ty)->getNumElements();
  if (!CachePtr)
    Tmp.resize(Size, nullptr);
  else if (CachePtr->empty())
    CachePtr->resize(Size, nullptr);
  else
    assert(Size == CachePtr->size() && "Inconsistent vector sizes");
}
Value *Scatterer::operator[](unsigned I) {
  ValueVector &CV = (CachePtr ? *CachePtr : Tmp);

    Type *ElTy = cast<VectorType>(PtrTy->getElementType())->getElementType();
      CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0");
      CV[I] = Builder.CreateConstGEP1_32(ElTy, CV[0], I,
                                         V->getName() + ".i" + Twine(I));

      V = Insert->getOperand(0);
        CV[J] = Insert->getOperand(1);
        CV[J] = Insert->getOperand(1);
                                       V->getName() + ".i" + Twine(I));
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT);
  return Impl.visit(F);

  return new ScalarizerLegacyPass();
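// Run the per-instruction visitors over F; fully scalarized instructions
// that produce no value are erased on the spot.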
bool ScalarizerVisitor::visit(Function &F) {
  assert(Gathered.empty() && Scattered.empty());
      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
  if (Argument *VArg = dyn_cast<Argument>(V)) {
    return Scatterer(BB, BB->begin(), V, &Scattered[V]);
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
  transferMetadataAndIRFlags(Op, CV);

  ValueVector &SV = Scattered[Op];
  for (unsigned I = 0, E = SV.size(); I != E; ++I) {
    if (V == nullptr || SV[I] == CV[I])
    if (isa<Instruction>(CV[I]))
      CV[I]->takeName(Old);
    PotentiallyDeadInstrs.emplace_back(Old);

  Gathered.push_back(GatherList::value_type(Op, &SV));
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
  return (Tag == LLVMContext::MD_tbaa
          || Tag == LLVMContext::MD_fpmath
          || Tag == LLVMContext::MD_tbaa_struct
          || Tag == LLVMContext::MD_invariant_load
          || Tag == LLVMContext::MD_alias_scope
          || Tag == LLVMContext::MD_noalias
          || Tag == ParallelLoopAccessMDKind
          || Tag == LLVMContext::MD_access_group);
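// Transfer metadata from Op to the instructions in CV if it is known to be
// safe to do so, and copy Op's IR flags and debug location as well.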
void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
                                                   const ValueVector &CV) {
  Op->getAllMetadataOtherThanDebugLoc(MDs);
  for (unsigned I = 0, E = CV.size(); I != E; ++I) {
      for (const auto &MD : MDs)
        if (canTransferMetadata(MD.first))
          New->setMetadata(MD.first, MD.second);
      New->copyIRFlags(Op);
      if (Op->getDebugLoc() && !New->getDebugLoc())
        New->setDebugLoc(Op->getDebugLoc());
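// Compute the layout of the vector type Ty for the load/store scalarization
// below: element type, per-element store size and the vector alignment.
// Layout computation is rejected when an element's size differs from its
// store size.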
ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
  Layout.VecTy = dyn_cast<VectorType>(Ty);
  Layout.ElemTy = Layout.VecTy->getElementType();
  if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy))
  Layout.VecAlign = Alignment;
  Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy);
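// Scalarize one-operand instruction I, using Split(Builder, X, Name)
// to build the split instruction for each element.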
template<typename Splitter>
bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
  VectorType *VT = dyn_cast<VectorType>(I.getType());

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  Scatterer Op = scatter(&I, I.getOperand(0));
  assert(Op.size() == NumElems && "Mismatched unary operation");
  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem)
    Res[Elem] = Split(Builder, Op[Elem], I.getName() + ".i" + Twine(Elem));
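// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
// to build the split instruction for each element.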
template<typename Splitter>
bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
  VectorType *VT = dyn_cast<VectorType>(I.getType());

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  Scatterer VOp0 = scatter(&I, I.getOperand(0));
  Scatterer VOp1 = scatter(&I, I.getOperand(1));
  assert(VOp0.size() == NumElems && "Mismatched binary operation");
  assert(VOp1.size() == NumElems && "Mismatched binary operation");
  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    Value *Op0 = VOp0[Elem];
    Value *Op1 = VOp1[Elem];
    Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem));
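// Scalarize an intrinsic call: vector operands are scattered per lane while
// uniform operands are passed through unchanged, and one scalar call to the
// intrinsic is emitted for every element of the result.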
bool ScalarizerVisitor::splitCall(CallInst &CI) {
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  ValueVector ScalarOperands(NumArgs);
  Scattered.resize(NumArgs);

  for (unsigned I = 0; I != NumArgs; ++I) {
      Scattered[I] = scatter(&CI, OpI);
      assert(Scattered[I].size() == NumElems && "mismatched call operands");
      ScalarOperands[I] = OpI;

  ValueVector Res(NumElems);
  ValueVector ScalarCallOps(NumArgs);

  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    ScalarCallOps.clear();

    for (unsigned J = 0; J != NumArgs; ++J) {
        ScalarCallOps.push_back(ScalarOperands[J]);
        ScalarCallOps.push_back(Scattered[J][Elem]);

    Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps,
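// Selects are split lane by lane; a vector condition is scattered along with
// the two value operands, while a scalar condition is reused for every lane.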
bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  Scatterer VOp1 = scatter(&SI, SI.getOperand(1));
  Scatterer VOp2 = scatter(&SI, SI.getOperand(2));
  assert(VOp1.size() == NumElems && "Mismatched select");
  assert(VOp2.size() == NumElems && "Mismatched select");
  Res.resize(NumElems);

  if (SI.getOperand(0)->getType()->isVectorTy()) {
    Scatterer VOp0 = scatter(&SI, SI.getOperand(0));
    assert(VOp0.size() == NumElems && "Mismatched select");
    for (unsigned I = 0; I < NumElems; ++I) {
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
    for (unsigned I = 0; I < NumElems; ++I) {
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
  return splitBinary(ICI, ICmpSplitter(ICI));
}

bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
  return splitBinary(FCI, FCmpSplitter(FCI));
}

bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
  return splitUnary(UO, UnarySplitter(UO));
}

bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
  return splitBinary(BO, BinarySplitter(BO));
}
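// GEP scalarization: splat a scalar base pointer across the lanes if needed,
// scatter every vector index, and emit one scalar GEP per lane, re-applying
// the inbounds flag where present.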
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

    Op0 = Builder.CreateVectorSplat(NumElems, Op0);
  Scatterer Base = scatter(&GEPI, Op0);

  for (unsigned I = 0; I < NumIndices; ++I) {
    if (!Op->getType()->isVectorTy())
    Ops[I] = scatter(&GEPI, Op);

  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    Indices.resize(NumIndices);
    for (unsigned J = 0; J < NumIndices; ++J)
      Indices[J] = Ops[J][I];

      NewGEPI->setIsInBounds();
bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();

  Scatterer Op0 = scatter(&CI, CI.getOperand(0));
  assert(Op0.size() == NumElems && "Mismatched cast");
  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I)
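// Vector-to-vector bitcasts map source components onto result components:
// one-to-one when the element counts match, fanning each source element out
// into several results when the destination has more elements, and gathering
// several source elements into one result (via insertelement) otherwise.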
bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
  if (!DstVT || !SrcVT)

  unsigned DstNumElems = cast<FixedVectorType>(DstVT)->getNumElements();
  unsigned SrcNumElems = cast<FixedVectorType>(SrcVT)->getNumElements();
  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
  Res.resize(DstNumElems);

  if (DstNumElems == SrcNumElems) {
    for (unsigned I = 0; I < DstNumElems; ++I)
  } else if (DstNumElems > SrcNumElems) {
    unsigned FanOut = DstNumElems / SrcNumElems;
    for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
      Value *V = Op0[Op0I];
      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);
      Scatterer Mid = scatter(&BCI, V);
      for (unsigned MidI = 0; MidI < FanOut; ++MidI)
        Res[ResI++] = Mid[MidI];

    unsigned FanIn = SrcNumElems / DstNumElems;
    for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {
      for (unsigned MidI = 0; MidI < FanIn; ++MidI)
        V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI),
                                        + ".upto" + Twine(MidI));
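// insertelement handling: with a constant index only the targeted lane is
// replaced; with a variable index every lane becomes a select between the
// new element and the old one.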
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0));
  Res.resize(NumElems);

  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
    for (unsigned I = 0; I < NumElems; ++I)
      Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I];

    for (unsigned I = 0; I < NumElems; ++I) {
      Value *ShouldReplace =
      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
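// extractelement mirrors insertelement: a constant index simply forwards the
// corresponding scattered component, while a variable index builds a chain
// of selects across all source lanes.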
  unsigned NumSrcElems = cast<FixedVectorType>(VT)->getNumElements();
  Scatterer Op0 = scatter(&EEI, EEI.getOperand(0));

  if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
    Value *Res = Op0[CI->getValue().getZExtValue()];

  for (unsigned I = 0; I < NumSrcElems; ++I) {
    Value *ShouldExtract =
    Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
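// shufflevector: each result lane picks the scattered component named by its
// mask selector, taken from Op0 or Op1 depending on the selector's range.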
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
  Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I) {
    else if (unsigned(Selector) < Op0.size())
      Res[I] = Op0[Selector];
      Res[I] = Op1[Selector - Op0.size()];
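// PHIs are scalarized by creating one scalar PHI per lane and then adding
// the scattered components of every incoming value to the new PHIs.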
bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I)

  for (unsigned I = 0; I < NumOps; ++I) {
    for (unsigned J = 0; J < NumElems; ++J)
      cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
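// Vector loads are split into one aligned scalar load per element, using the
// per-element alignment computed from the vector layout.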
bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreateAlignedLoad(Layout->VecTy->getElementType(), Ptr[I],
                                       Align(Layout->getElemAlign(I)),
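// Vector stores are handled the same way: the pointer and value operands are
// scattered and one aligned scalar store is emitted per element, after which
// metadata and IR flags are transferred to the new stores.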
bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  Value *FullValue = SI.getValueOperand();
      FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout());

  unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
  Scatterer VPtr = scatter(&SI, SI.getPointerOperand());
  Scatterer VVal = scatter(&SI, FullValue);

  Stores.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I));

  transferMetadataAndIRFlags(&SI, Stores);
bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
  return splitCall(CI);
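// Delete the instructions that we scalarized.  If a full vector result
// is still needed, recreate it using a series of insertelements.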
bool ScalarizerVisitor::finish() {
  if (Gathered.empty() && Scattered.empty())

  for (const auto &GMI : Gathered) {
    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {
      if (auto *Ty = dyn_cast<VectorType>(Op->getType())) {
        unsigned Count = cast<FixedVectorType>(Ty)->getNumElements();
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
        for (unsigned I = 0; I < Count; ++I)
                                      Op->getName() + ".upto" + Twine(I));
        assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
      Op->replaceAllUsesWith(Res);
    PotentiallyDeadInstrs.emplace_back(Op);
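// A second entry point (presumably the new-pass-manager ScalarizerPass::run)
// repeats the same setup: look up the parallel-loop-access metadata kind,
// run the visitor over F, and record whether anything changed.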
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT);
  bool Changed = Impl.visit(F);