/// Tag used by LLVM's debug-output and statistics machinery for this file
/// (LLVM convention; the STATISTIC below is registered under this name).
49 #define DEBUG_TYPE "interleaved-load-combine"
/// Counter for combined loads; it is incremented via
/// `NumInterleavedLoadCombine++` later in this listing.
54 STATISTIC(NumInterleavedLoadCombine,
"Number of combined loads");
59 cl::desc(
"Disable combining of interleaved loads"));
63 struct InterleavedLoadCombineImpl {
67 :
F(
F), DT(DT), MSSA(MSSA),
68 TLI(*
TM.getSubtargetImpl(
F)->getTargetLowering()),
69 TTI(
TM.getTargetTransformInfo(
F)) {}
93 LoadInst *findFirstLoad(
const std::set<LoadInst *> &LIs);
98 bool combine(std::list<VectorInfo> &InterleavedLoad,
103 bool findPattern(std::list<VectorInfo> &Candidates,
104 std::list<VectorInfo> &InterleavedLoad,
unsigned Factor,
186 Polynomial(
Value *V) : ErrorMSBs((
unsigned)-1), V(V) {
/// Construct a polynomial from a constant: stores \p A in the member A and
/// \p ErrorMSBs (default 0) in the member ErrorMSBs; V is set to nullptr, so
/// isFirstOrder() reports false for the result.
195 Polynomial(
const APInt &A,
unsigned ErrorMSBs = 0)
196 : ErrorMSBs(ErrorMSBs), V(
nullptr), A(A) {}
199 : ErrorMSBs(ErrorMSBs), V(
nullptr), A(
BitWidth, A) {}
/// Default constructor: V is nullptr and ErrorMSBs is (unsigned)-1.
/// NOTE(review): (unsigned)-1 looks like a sentinel for an invalid/undefined
/// polynomial -- it is what the bit-width-mismatch paths assign and what
/// incErrorMSBs/decErrorMSBs test for -- confirm against the full class.
201 Polynomial() : ErrorMSBs((
unsigned)-1), V(
nullptr) {}
204 void incErrorMSBs(
unsigned amt) {
205 if (ErrorMSBs == (
unsigned)-1)
209 if (ErrorMSBs > A.getBitWidth())
210 ErrorMSBs = A.getBitWidth();
214 void decErrorMSBs(
unsigned amt) {
215 if (ErrorMSBs == (
unsigned)-1)
242 if (
C.getBitWidth() != A.getBitWidth()) {
243 ErrorMSBs = (unsigned)-1;
252 Polynomial &mul(
const APInt &
C) {
303 if (
C.getBitWidth() != A.getBitWidth()) {
304 ErrorMSBs = (unsigned)-1;
321 decErrorMSBs(
C.countTrailingZeros());
324 pushBOperation(
Mul,
C);
460 if (
C.getBitWidth() != A.getBitWidth()) {
461 ErrorMSBs = (unsigned)-1;
469 unsigned shiftAmt =
C.getZExtValue();
470 if (shiftAmt >=
C.getBitWidth())
471 return mul(
APInt(
C.getBitWidth(), 0));
478 if (A.countTrailingZeros() < shiftAmt)
479 ErrorMSBs = A.getBitWidth();
481 incErrorMSBs(shiftAmt);
484 pushBOperation(LShr,
C);
485 A = A.lshr(shiftAmt);
491 Polynomial &sextOrTrunc(
unsigned n) {
492 if (
n < A.getBitWidth()) {
495 decErrorMSBs(A.getBitWidth() -
n);
497 pushBOperation(Trunc,
APInt(
sizeof(
n) * 8,
n));
499 if (
n > A.getBitWidth()) {
502 incErrorMSBs(
n - A.getBitWidth());
504 pushBOperation(SExt,
APInt(
sizeof(
n) * 8,
n));
/// Returns true iff this polynomial has a variable attached, i.e. the member
/// V is non-null. Polynomials built from a constant only (V == nullptr)
/// return false.
511 bool isFirstOrder()
const {
return V !=
nullptr; }
514 bool isCompatibleTo(
const Polynomial &o)
const {
516 if (A.getBitWidth() != o.A.getBitWidth())
520 if (!isFirstOrder() && !o.isFirstOrder())
528 if (
B.size() != o.B.size())
531 auto ob = o.B.begin();
543 Polynomial
operator-(
const Polynomial &o)
const {
545 if (!isCompatibleTo(o))
551 return Polynomial(A - o.A,
std::max(ErrorMSBs, o.ErrorMSBs));
556 Polynomial Result(*
this);
563 Polynomial Result(*
this);
569 bool isProvenEqualTo(
const Polynomial &o) {
571 Polynomial r = *
this - o;
572 return (r.ErrorMSBs == 0) && (!r.isFirstOrder()) && (r.A.isZero());
577 OS <<
"[{#ErrBits:" << ErrorMSBs <<
"} ";
582 OS <<
"(" << *V <<
") ";
600 OS <<
b.second <<
") ";
604 OS <<
"+ " << A <<
"]";
613 void pushBOperation(
const BOps
Op,
const APInt &
C) {
614 if (isFirstOrder()) {
615 B.push_back(std::make_pair(
Op,
C));
637 VectorInfo(
const VectorInfo &
c) : VTy(
c.VTy) {
639 "Copying VectorInfo is neither implemented nor necessary,");
/// Construct the per-element record.
///
/// \param Offset copied into the member Ofs; defaults to a default-constructed
///               Polynomial.
/// \param LI     stored in the member LI; defaults to nullptr. The compute*
///               helpers in this listing pass a load only for selected
///               elements and nullptr for the others.
652 ElementInfo(Polynomial Offset = Polynomial(),
LoadInst *LI =
nullptr)
653 : Ofs(Offset), LI(LI) {}
/// Load instructions collected for this vector. computeFromLI inserts the
/// load itself; computeFromBCI/computeFromSVI merge in the sets of their
/// operands.
663 std::set<LoadInst *> LIs;
/// Instructions collected for this vector: the load, bitcast, or shuffle
/// being analysed, plus the sets merged in from its operands.
666 std::set<Instruction *> Is;
/// Destructor: releases the EI array with delete[].
/// NOTE(review): presumably EI is allocated with new[] in a constructor that
/// is not visible in this listing -- confirm.
681 virtual ~VectorInfo() {
delete[] EI; }
692 bool isInterleaved(
unsigned Factor,
const DataLayout &
DL)
const {
694 for (
unsigned i = 1;
i < getDimension();
i++) {
695 if (!EI[
i].Ofs.isProvenEqualTo(EI[0].Ofs +
i * Factor * Size)) {
713 return computeFromSVI(SVI, Result,
DL);
714 LoadInst *LI = dyn_cast<LoadInst>(V);
716 return computeFromLI(LI, Result,
DL);
719 return computeFromBCI(BCI, Result,
DL);
729 static bool computeFromBCI(
BitCastInst *BCI, VectorInfo &Result,
744 unsigned Factor = Result.VTy->getNumElements() / VTy->
getNumElements();
745 unsigned NewSize =
DL.getTypeAllocSize(Result.VTy->getElementType());
748 if (NewSize * Factor != OldSize)
752 if (!compute(
Op, Old,
DL))
755 for (
unsigned i = 0;
i < Result.VTy->getNumElements();
i += Factor) {
756 for (
unsigned j = 0;
j < Factor;
j++) {
758 ElementInfo(Old.EI[
i / Factor].Ofs +
j * NewSize,
759 j == 0 ? Old.EI[
i / Factor].LI :
nullptr);
765 Result.LIs.insert(Old.LIs.begin(), Old.LIs.end());
766 Result.Is.insert(Old.Is.begin(), Old.Is.end());
767 Result.Is.insert(BCI);
768 Result.SVI =
nullptr;
790 VectorInfo
LHS(ArgTy);
795 VectorInfo
RHS(ArgTy);
824 Result.LIs.insert(
LHS.LIs.begin(),
LHS.LIs.end());
825 Result.Is.insert(
LHS.Is.begin(),
LHS.Is.end());
828 Result.LIs.insert(
RHS.LIs.begin(),
RHS.LIs.end());
829 Result.Is.insert(
RHS.Is.begin(),
RHS.Is.end());
831 Result.Is.insert(SVI);
837 "Invalid ShuffleVectorInst (index out of bounds)");
840 Result.EI[
j] = ElementInfo();
843 Result.EI[
j] =
LHS.EI[
i];
845 Result.EI[
j] = ElementInfo();
850 Result.EI[
j] = ElementInfo();
866 static bool computeFromLI(
LoadInst *LI, VectorInfo &Result,
882 Result.LIs.insert(LI);
883 Result.Is.insert(LI);
885 for (
unsigned i = 0;
i < Result.getDimension();
i++) {
890 int64_t Ofs =
DL.getIndexedOffsetInType(Result.VTy,
makeArrayRef(Idx, 2));
891 Result.EI[
i] = ElementInfo(Offset + Ofs,
i == 0 ? LI :
nullptr);
901 static void computePolynomialBinOp(
BinaryOperator &BO, Polynomial &Result) {
908 C = dyn_cast<ConstantInt>(
LHS);
918 computePolynomial(*
LHS, Result);
919 Result.add(
C->getValue());
922 case Instruction::LShr:
926 computePolynomial(*
LHS, Result);
927 Result.lshr(
C->getValue());
934 Result = Polynomial(&BO);
941 static void computePolynomial(
Value &V, Polynomial &Result) {
942 if (
auto *BO = dyn_cast<BinaryOperator>(&V))
943 computePolynomialBinOp(*BO, Result);
945 Result = Polynomial(&V);
954 static void computePolynomialFromPointer(
Value &Ptr, Polynomial &Result,
960 Result = Polynomial();
964 unsigned PointerBits =
968 if (isa<CastInst>(&Ptr)) {
969 CastInst &CI = *cast<CastInst>(&Ptr);
971 case Instruction::BitCast:
972 computePolynomialFromPointer(*CI.
getOperand(0), Result, BasePtr,
DL);
976 Polynomial(PointerBits, 0);
981 else if (isa<GetElementPtrInst>(&Ptr)) {
984 APInt BaseOffset(PointerBits, 0);
987 if (
GEP.accumulateConstantOffset(
DL, BaseOffset)) {
988 Result = Polynomial(BaseOffset);
989 BasePtr =
GEP.getPointerOperand();
994 unsigned idxOperand,
e;
996 for (idxOperand = 1,
e =
GEP.getNumOperands(); idxOperand <
e;
998 ConstantInt *IDX = dyn_cast<ConstantInt>(
GEP.getOperand(idxOperand));
1001 Indices.push_back(IDX);
1005 if (idxOperand + 1 !=
e) {
1006 Result = Polynomial();
1012 computePolynomial(*
GEP.getOperand(idxOperand), Result);
1017 DL.getIndexedOffsetInType(
GEP.getSourceElementType(), Indices);
1020 unsigned ResultSize =
DL.getTypeAllocSize(
GEP.getResultElementType());
1021 Result.sextOrTrunc(PointerBits);
1022 Result.mul(
APInt(PointerBits, ResultSize));
1023 Result.add(BaseOffset);
1024 BasePtr =
GEP.getPointerOperand();
1042 for (
unsigned i = 0;
i < getDimension();
i++)
1043 OS << ((
i == 0) ?
"[" :
", ") << EI[
i].Ofs;
1051 bool InterleavedLoadCombineImpl::findPattern(
1052 std::list<VectorInfo> &Candidates, std::list<VectorInfo> &InterleavedLoad,
1054 for (
auto C0 = Candidates.begin(), E0 = Candidates.end(); C0 != E0; ++C0) {
1057 unsigned Size =
DL.getTypeAllocSize(C0->VTy->getElementType());
1060 std::vector<std::list<VectorInfo>::iterator> Res(Factor, Candidates.end());
1062 for (
auto C = Candidates.begin(),
E = Candidates.end();
C !=
E;
C++) {
1063 if (
C->VTy != C0->VTy)
1065 if (
C->BB != C0->BB)
1067 if (
C->PV != C0->PV)
1071 for (
i = 1;
i < Factor;
i++) {
1072 if (
C->EI[0].Ofs.isProvenEqualTo(C0->EI[0].Ofs +
i * Size)) {
1077 for (
i = 1;
i < Factor;
i++) {
1078 if (Res[
i] == Candidates.end())
1087 if (Res[0] != Candidates.end()) {
1089 for (
unsigned i = 0;
i < Factor;
i++) {
1090 InterleavedLoad.splice(InterleavedLoad.end(), Candidates, Res[
i]);
1100 InterleavedLoadCombineImpl::findFirstLoad(
const std::set<LoadInst *> &LIs) {
1101 assert(!LIs.empty() &&
"No load instructions given.");
1109 return cast<LoadInst>(FLI);
1119 LoadInst *InsertionPoint = InterleavedLoad.front().EI[0].LI;
1122 if (!InsertionPoint)
1125 std::set<LoadInst *> LIs;
1126 std::set<Instruction *> Is;
1127 std::set<Instruction *> SVIs;
1134 unsigned Factor = InterleavedLoad.size();
1137 for (
auto &
VI : InterleavedLoad) {
1139 LIs.insert(
VI.LIs.begin(),
VI.LIs.end());
1144 Is.insert(
VI.Is.begin(),
VI.Is.end());
1147 SVIs.insert(
VI.SVI);
1157 for (
auto &
I : Is) {
1162 if (SVIs.find(
I) != SVIs.end())
1167 for (
auto *U :
I->users()) {
1168 if (Is.find(dyn_cast<Instruction>(U)) == Is.end())
1184 auto FMA = MSSA.getMemoryAccess(First);
1185 for (
auto LI : LIs) {
1186 auto MADef = MSSA.getMemoryAccess(LI)->getDefiningAccess();
1187 if (!MSSA.dominates(MADef,
FMA))
1190 assert(!LIs.empty() &&
"There are no LoadInst to combine");
1193 for (
auto &
VI : InterleavedLoad) {
1194 if (!DT.dominates(InsertionPoint,
VI.SVI))
1201 Type *ETy = InterleavedLoad.front().SVI->getType()->getElementType();
1202 unsigned ElementsPerSVI =
1203 cast<FixedVectorType>(InterleavedLoad.front().SVI->getType())
1207 auto Indices = llvm::to_vector<4>(llvm::seq<unsigned>(0, Factor));
1219 "interleaved.wide.ptrcast");
1222 auto LI =
Builder.CreateAlignedLoad(ILTy, CI, InsertionPoint->
getAlign(),
1223 "interleaved.wide.load");
1225 MemoryUse *MSSALoad = cast<MemoryUse>(MSSAU.createMemoryAccessBefore(
1226 LI,
nullptr, MSSA.getMemoryAccess(InsertionPoint)));
1227 MSSAU.insertUse(MSSALoad);
1231 for (
auto &
VI : InterleavedLoad) {
1233 for (
unsigned j = 0;
j < ElementsPerSVI;
j++)
1234 Mask.push_back(
i +
j * Factor);
1237 auto SVI =
Builder.CreateShuffleVector(LI,
Mask,
"interleaved.shuffle");
1238 VI.SVI->replaceAllUsesWith(SVI);
1242 NumInterleavedLoadCombine++;
1245 <<
"Load interleaved combined with factor "
1254 bool changed =
false;
1255 unsigned MaxFactor = TLI.getMaxSupportedInterleaveFactor();
1257 auto &
DL =
F.getParent()->getDataLayout();
1260 for (
unsigned Factor = MaxFactor; Factor >= 2; Factor--) {
1261 std::list<VectorInfo> Candidates;
1265 if (
auto SVI = dyn_cast<ShuffleVectorInst>(&
I)) {
1267 if (isa<ScalableVectorType>(SVI->getType()))
1270 Candidates.emplace_back(cast<FixedVectorType>(SVI->getType()));
1272 if (!VectorInfo::computeFromSVI(SVI, Candidates.back(),
DL)) {
1273 Candidates.pop_back();
1277 if (!Candidates.back().isInterleaved(Factor,
DL)) {
1278 Candidates.pop_back();
1284 std::list<VectorInfo> InterleavedLoad;
1285 while (findPattern(Candidates, InterleavedLoad, Factor,
DL)) {
1286 if (
combine(InterleavedLoad, ORE)) {
1291 Candidates.splice(Candidates.begin(), InterleavedLoad,
1292 std::next(InterleavedLoad.begin()),
1293 InterleavedLoad.end());
1295 InterleavedLoad.clear();
1312 StringRef getPassName()
const override {
1313 return "Interleaved Load Combine Pass";
1317 if (DisableInterleavedLoadCombine)
1320 auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
1327 return InterleavedLoadCombineImpl(
1328 F, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
1329 getAnalysis<MemorySSAWrapperPass>().getMSSA(),
1348 "Combine interleaved loads into wide loads and shufflevector instructions",
1359 auto P =
new InterleavedLoadCombine();