// Tag for this file's debug output and statistics; presumably picked up by the
// LLVM_DEBUG and STATISTIC macros used throughout this file — TODO confirm
// against LLVM's Debug.h / Statistic.h.
90 #define DEBUG_TYPE "load-store-vectorizer"
// Incremented once each time a chain is replaced (see the ++ at the end of the
// store and load chain vectorizers).
92 STATISTIC(NumVectorInstructions,
"Number of vector accesses generated");
// Increased by Chain.size() each time a chain is replaced.
93 STATISTIC(NumScalarsVectorized,
"Number of scalar accesses vectorized");
// Key type for grouping candidate accesses: the map walked by vectorizeChains
// is iterated as std::pair<ChainID, InstrList> entries.
107 using ChainID =
const Value *;
// NOTE: excerpt of member declarations; the enclosing class header is not
// visible here. The out-of-line definitions further down are qualified with
// Vectorizer::.
// Address space of the pointer accessed by I (the definition's visible exits
// forward to getPointerAddressSpace() on L and on S).
130 unsigned getPointerAddressSpace(
Value *
I);
// Tail of a declaration whose opening lines are missing from this excerpt;
// its last parameter, Depth, defaults to 0.
136 unsigned Depth = 0)
const;
// GEP-based pointer comparison. Note PtrDelta is taken by value here, whereas
// lookThroughSelects takes it by const reference.
137 bool lookThroughComplexAddresses(
Value *PtrA,
Value *PtrB,
APInt PtrDelta,
138 unsigned Depth)
const;
// Select-based pointer comparison: the definition matches two selects on the
// same condition and checks both arms with areConsecutivePointers.
139 bool lookThroughSelects(
Value *PtrA,
Value *PtrB,
const APInt &PtrDelta,
140 unsigned Depth)
const;
// Return type of a declaration whose declarator line is missing from this
// excerpt.
147 std::pair<BasicBlock::iterator, BasicBlock::iterator>
// Collects candidate accesses of BB; the definition returns
// {LoadRefs, StoreRefs}, i.e. loads first, stores second.
169 std::pair<InstrListMap, InstrListMap> collectInstructions(
BasicBlock *
BB);
// Processes every chain in Map and accumulates a Changed flag (the return
// statement itself is not visible in this excerpt).
173 bool vectorizeChains(InstrListMap &Map);
// Leading parameters only; the rest of this declaration is not visible. The
// definition's visible return expression is `!Allows || !Fast`.
189 bool accessIsMisaligned(
unsigned SzInBytes,
unsigned AddressSpace,
// Wrapper pass deriving from FunctionPass; only the class head and a few
// later lines are visible in this excerpt.
193 class LoadStoreVectorizerLegacyPass :
public FunctionPass {
// Display string returned by a method whose signature is not visible here
// (presumably getPassName — TODO confirm).
204 return "GPU Load and Store Vectorizer";
// Description string plus two `false` flags: trailing arguments of an
// invocation whose opening line is missing (presumably pass registration —
// TODO confirm).
222 "Vectorize load and Store instructions",
false,
false)
// Factory body: hands back a freshly allocated legacy pass object.
233 return new LoadStoreVectorizerLegacyPass();
// Early test: true when skipFunction(F) holds or F carries the
// NoImplicitFloat attribute. The guarded statement is not visible here.
238 if (skipFunction(
F) ||
F.hasFnAttribute(Attribute::NoImplicitFloat))
// Pull the analyses the Vectorizer is built from out of their wrapper passes:
// alias analysis, dominator tree, scalar evolution ...
241 AliasAnalysis &
AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
242 DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
243 ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
// ... target transform info for F (the declaration this initializes is on a
// line missing from this excerpt) ...
245 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
// ... and the assumption cache for F (left-hand side likewise not visible).
248 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
F);
// Build the worker object from the function and the five analyses.
250 Vectorizer V(
F,
AA, AC, DT, SE,
TTI);
// Tests only the NoImplicitFloat attribute (no skipFunction() call on this
// path); the guarded statement is not visible here.
256 if (
F.hasFnAttribute(Attribute::NoImplicitFloat))
// Construct the worker from the function and analyses, then run it; Changed
// holds whatever V.run() reports.
265 Vectorizer V(
F,
AA, AC, DT, SE,
TTI);
266 bool Changed = V.run();
// Accumulates the results of the vectorizeChains calls below.
281 bool Changed =
false;
// For block BB: split its candidate accesses into a load map and a store map,
// then process each map, OR-ing the outcomes into Changed. (Whatever supplies
// BB is not visible in this excerpt.)
285 InstrListMap LoadRefs, StoreRefs;
286 std::tie(LoadRefs, StoreRefs) = collectInstructions(
BB);
287 Changed |= vectorizeChains(LoadRefs);
288 Changed |= vectorizeChains(StoreRefs);
// Returns the address space of the pointer accessed by I. Two exits are
// visible, one through L and one through S (presumably the load and store
// views of I — the casts defining them, and any fallback return, are on lines
// missing from this excerpt).
294 unsigned Vectorizer::getPointerAddressSpace(
Value *
I) {
296 return L->getPointerAddressSpace();
298 return S->getPointerAddressSpace();
// Address spaces of the two accesses being compared.
306 unsigned ASA = getPointerAddressSpace(A);
307 unsigned ASB = getPointerAddressSpace(
B);
// Tests for a missing pointer or for accesses in different address spaces.
// The guarded statement is not visible here.
310 if (!PtrA || !PtrB || (ASA != ASB))
// Fragment of a larger condition: the store sizes of PtrATy and PtrBTy differ.
318 DL.getTypeStoreSize(PtrATy) !=
DL.getTypeStoreSize(PtrBTy) ||
// Size = store size of PtrATy as an APInt whose width is the pointer size for
// address space ASA. It is the delta handed to areConsecutivePointers below.
323 unsigned PtrBitWidth =
DL.getPointerSizeInBits(ASA);
324 APInt Size(PtrBitWidth,
DL.getTypeStoreSize(PtrATy));
326 return areConsecutivePointers(PtrA, PtrB, Size);
// Tests whether PtrB lies exactly PtrDelta away from PtrA. Only part of the
// signature and body is visible in this excerpt (PtrDelta and Depth are used
// below but their parameter lines are missing).
329 bool Vectorizer::areConsecutivePointers(
Value *PtrA,
Value *PtrB,
// Both offsets start at zero, at the width the data layout reports for PtrA's
// pointer type.
331 unsigned PtrBitWidth =
DL.getPointerTypeSizeInBits(PtrA->
getType());
332 APInt OffsetA(PtrBitWidth, 0);
333 APInt OffsetB(PtrBitWidth, 0);
// A width is queried again (store size in bits of PtrA's type) and compared
// with the same quantity for PtrB; the guarded statement is not visible.
337 unsigned NewPtrBitWidth =
DL.getTypeStoreSizeInBits(PtrA->
getType());
339 if (NewPtrBitWidth !=
DL.getTypeStoreSizeInBits(PtrB->
getType()))
// Both offsets must be representable in the new width before they are
// sign-extended or truncated to it.
346 assert(OffsetA.getMinSignedBits() <= NewPtrBitWidth &&
347 OffsetB.getMinSignedBits() <= NewPtrBitWidth);
349 OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth);
350 OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth);
// Difference between the two offsets.
353 APInt OffsetDelta = OffsetB - OffsetA;
// One visible exit: the offset difference alone must equal PtrDelta (the
// condition selecting this exit is not visible).
358 return OffsetDelta == PtrDelta;
// The part of PtrDelta that the offsets do not account for.
362 APInt BaseDelta = PtrDelta - OffsetDelta;
// Last visible exit: defer to the GEP-based comparison with the residual.
383 return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta,
Depth);
// Tail of a parameter list; the function name and its earlier parameters are
// on lines missing from this excerpt.
394 unsigned MatchingOpIdxB,
bool Signed) {
// For each of AddOpA / AddOpB, take the operand that is NOT the matching one
// (index 0 when the match sits at index 1, index 1 otherwise).
416 Value *OtherOperandA = AddOpA->
getOperand(MatchingOpIdxA == 1 ? 0 : 1);
417 Value *OtherOperandB = AddOpB->
getOperand(MatchingOpIdxB == 1 ? 0 : 1);
// Instruction views of those operands; the combined case further down
// null-checks both before use.
418 Instruction *OtherInstrA = dyn_cast<Instruction>(OtherOperandA);
419 Instruction *OtherInstrB = dyn_cast<Instruction>(OtherOperandB);
// Condition fragments: OtherInstrB has a ConstantInt second operand (read via
// getSExtValue()) and its first operand is compared with OtherOperandA ...
423 isa<ConstantInt>(OtherInstrB->
getOperand(1))) {
425 cast<ConstantInt>(OtherInstrB->
getOperand(1))->getSExtValue();
426 if (OtherInstrB->
getOperand(0) == OtherOperandA &&
// ... the mirrored fragments with the roles of A and B swapped ...
433 isa<ConstantInt>(OtherInstrA->
getOperand(1))) {
435 cast<ConstantInt>(OtherInstrA->
getOperand(1))->getSExtValue();
436 if (OtherInstrA->
getOperand(0) == OtherOperandB &&
// ... and the case where both are instructions with ConstantInt second
// operands, each constant being read with getSExtValue().
442 if (OtherInstrA && OtherInstrB &&
447 isa<ConstantInt>(OtherInstrA->
getOperand(1)) &&
448 isa<ConstantInt>(OtherInstrB->
getOperand(1))) {
450 cast<ConstantInt>(OtherInstrA->
getOperand(1))->getSExtValue();
452 cast<ConstantInt>(OtherInstrB->
getOperand(1))->getSExtValue();
// Compares two addresses by viewing them as GEP instructions. Only a subset
// of the signature and body is visible in this excerpt.
461 bool Vectorizer::lookThroughComplexAddresses(
Value *PtrA,
Value *PtrB,
463 unsigned Depth)
const {
464 auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
465 auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
// Hand-off to the select matcher (the guarding condition is not visible;
// presumably taken when either pointer is not a GEP — TODO confirm).
467 return lookThroughSelects(PtrA, PtrB, PtrDelta,
Depth);
// Tests for a differing operand count or a differing base pointer; the
// guarded statement is not visible.
471 if (GEPA->getNumOperands() != GEPB->getNumOperands() ||
472 GEPA->getPointerOperand() != GEPB->getPointerOperand())
// Visits every index position except the last (E is getNumIndices() - 1).
476 for (
unsigned I = 0,
E = GEPA->getNumIndices() - 1;
I <
E; ++
I) {
// Tests whether PtrDelta leaves a non-zero unsigned remainder modulo Stride.
496 if (PtrDelta.
urem(Stride) != 0)
// Tests for OpA being neither a sign- nor a zero-extension.
502 if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
// Signed is true when OpA is a sign-extension.
505 bool Signed = isa<SExtInst>(OpA);
// Replace OpB by the Instruction view of its first operand.
509 OpB = dyn_cast<Instruction>(OpB->
getOperand(0));
// Fragment of a larger condition: IdxDiff is signed-less-or-equal to OpB's
// constant second operand.
520 IdxDiff.
sle(cast<ConstantInt>(OpB->
getOperand(1))->getSExtValue()) &&
526 OpA = dyn_cast<Instruction>(ValA);
// Enumerates all four (MatchingOpIdxA, MatchingOpIdxB) combinations of {0, 1}.
534 for (
unsigned MatchingOpIdxA : {0, 1})
535 for (
unsigned MatchingOpIdxB : {0, 1})
// Tests whether BitsAllowedToBeSet is unsigned-less-than IdxDiff.
553 if (BitsAllowedToBeSet.
ult(IdxDiff))
// Last visible exit: equality of X and OffsetSCEVB.
561 return X == OffsetSCEVB;
// Handles the case where both pointers are selects: the visible exit is true
// when the selects share a condition and both the true arms and the false
// arms pass areConsecutivePointers with the same PtrDelta and Depth.
564 bool Vectorizer::lookThroughSelects(
Value *PtrA,
Value *PtrB,
565 const APInt &PtrDelta,
566 unsigned Depth)
const {
570 if (
auto *SelectA = dyn_cast<SelectInst>(PtrA)) {
571 if (
auto *SelectB = dyn_cast<SelectInst>(PtrB)) {
// && short-circuits: the false arms are only compared once the conditions
// and the true arms have matched.
572 return SelectA->getCondition() == SelectB->getCondition() &&
573 areConsecutivePointers(SelectA->getTrueValue(),
574 SelectB->getTrueValue(), PtrDelta,
Depth) &&
575 areConsecutivePointers(SelectA->getFalseValue(),
576 SelectB->getFalseValue(), PtrDelta,
Depth);
// Worklist walk seeded with I; only part of the body is visible.
586 Worklist.push_back(
I);
587 while (!Worklist.empty()) {
// Index loop over NumOperands operands.
590 for (
int i = 0;
i < NumOperands;
i++) {
// True when IM is null or is a PHI node; the guarded statement is not
// visible.
592 if (!IM || IM->
getOpcode() == Instruction::PHI)
// IM is added to InstructionsToMove and pushed onto the worklist.
601 InstructionsToMove.
insert(IM);
602 Worklist.push_back(IM);
// Second phase: scan from I to the end of its basic block, testing each
// instruction for absence from InstructionsToMove.
608 for (
auto BBI =
I->getIterator(),
E =
I->getParent()->end(); BBI !=
E;
610 if (!InstructionsToMove.
count(&*BBI))
// Return type; the declarator line is missing from this excerpt.
619 std::pair<BasicBlock::iterator, BasicBlock::iterator>
// Running count that is compared against Chain.size() below.
626 unsigned NumFound = 0;
// Position recorded as the first boundary (the condition selecting it is not
// visible).
633 FirstInstr =
I.getIterator();
// When the count reaches the chain length, the current instruction is
// recorded as the last boundary.
635 if (NumFound == Chain.
size()) {
636 LastInstr =
I.getIterator();
// LastInstr is advanced by one before being returned, so the second iterator
// points just past the recorded last instruction.
642 return std::make_pair(FirstInstr, ++LastInstr);
// Every instruction handled here is required to have a pointer operand.
649 assert(PtrOperand &&
"Instruction must have a pointer operand.");
// GEP is appended to Instrs (where GEP comes from is not visible here).
652 Instrs.push_back(
GEP);
// Remove I from its parent.
658 I->eraseFromParent();
// Last parameter of a signature whose opening lines are missing from this
// excerpt.
663 unsigned ElementSizeBits) {
// Element size in whole bytes (integer division).
664 unsigned ElementSizeBytes = ElementSizeBits / 8;
665 unsigned SizeBytes = ElementSizeBytes * Chain.
size();
// Number of elements covered by the largest multiple of 4 bytes that does not
// exceed SizeBytes.
// NOTE(review): this divides by ElementSizeBytes, which is 0 whenever
// ElementSizeBits < 8 — confirm callers never pass such a size.
666 unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes;
// When rounding removed nothing, the evenness of NumLeft is tested; otherwise
// NumLeft == 0 is tested. The statements chosen in each case are not visible.
667 if (NumLeft == Chain.
size()) {
668 if ((NumLeft & 1) == 0)
672 }
else if (NumLeft == 0)
// Split point: the first NumLeft elements, then everything after them.
674 return std::make_pair(Chain.
slice(0, NumLeft), Chain.
slice(NumLeft));
// The first element decides whether the chain is treated as a load chain.
683 bool IsLoadChain = isa<LoadInst>(Chain[0]);
// Message of two assertions whose conditions are on lines missing from this
// excerpt: a chain must not mix loads and stores.
688 "All elements of Chain must be loads, or all must be stores.");
691 "All elements of Chain must be loads, or all must be stores.");
// Loads/stores that are members of Chain are collected in ChainInstrs.
696 if ((isa<LoadInst>(
I) || isa<StoreInst>(
I)) &&
is_contained(Chain, &
I)) {
697 ChainInstrs.push_back(&
I);
// Debug trace whose guarding condition and remaining operands are not visible.
701 LLVM_DEBUG(
dbgs() <<
"LSV: Found instruction may not transfer execution: "
// Instructions that may read or write memory are collected in MemoryInstrs.
705 if (
I.mayReadOrWriteMemory())
706 MemoryInstrs.push_back(&
I);
// Declared outside the loop: the index reached is used afterwards to bound
// the range ChainInstrs.begin() .. ChainInstrs.begin() + ChainInstrIdx.
710 unsigned ChainInstrIdx = 0;
713 for (
unsigned E = ChainInstrs.size(); ChainInstrIdx <
E; ++ChainInstrIdx) {
714 Instruction *ChainInstr = ChainInstrs[ChainInstrIdx];
// Tests whether a barrier has been recorded and comes before this chain
// instruction (guarded statement not visible).
718 if (BarrierMemoryInstr && BarrierMemoryInstr->
comesBefore(ChainInstr))
// Same test against MemInstr.
725 if (BarrierMemoryInstr && BarrierMemoryInstr->
comesBefore(MemInstr))
// Load views of both instructions; the test below holds when both are loads
// (guarded statement not visible).
728 auto *MemLoad = dyn_cast<LoadInst>(MemInstr);
729 auto *ChainLoad = dyn_cast<LoadInst>(ChainInstr);
730 if (MemLoad && ChainLoad)
// Predicate: does the load carry MD_invariant_load metadata?
735 auto IsInvariantLoad = [](
const LoadInst *LI) ->
bool {
736 return LI->hasMetadata(LLVMContext::MD_invariant_load);
// Fragments of conditions combining instruction order (comesBefore) with the
// invariant-load predicate; their surrounding statements are not visible.
745 (ChainLoad && IsInvariantLoad(ChainLoad)))
749 if (MemInstr->comesBefore(ChainInstr) ||
750 (MemLoad && IsInvariantLoad(MemLoad)))
// Debug dump of the aliasing pair.
758 dbgs() <<
"LSV: Found alias:\n"
759 " Aliasing instruction:\n"
760 <<
" " << *MemInstr <<
'\n'
761 <<
" Aliased instruction and pointer:\n"
762 <<
" " << *ChainInstr <<
'\n'
// MemInstr is recorded as the barrier consulted by the tests above.
767 BarrierMemoryInstr = MemInstr;
// Test for a load chain with a recorded barrier (body not visible).
775 if (IsLoadChain && BarrierMemoryInstr) {
// Tail of a construction from the first ChainInstrIdx entries of ChainInstrs.
787 ChainInstrs.begin(), ChainInstrs.begin() + ChainInstrIdx);
// Scan Chain in order, testing each element for absence from
// VectorizableChainInstrs (guarded statement not visible), then return the
// leading ChainIdx elements of Chain.
788 unsigned ChainIdx = 0;
789 for (
unsigned ChainLen = Chain.size(); ChainIdx < ChainLen; ++ChainIdx) {
790 if (!VectorizableChainInstrs.count(Chain[ChainIdx]))
793 return Chain.slice(0, ChainIdx);
// When ObjPtr is a select, a visible exit returns the select's condition
// (intervening lines are missing from this excerpt).
798 if (
const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
805 return Sel->getCondition();
// Return type; the line naming the function is missing from this excerpt.
810 std::pair<InstrListMap, InstrListMap>
// Result maps: one for loads, one for stores.
812 InstrListMap LoadRefs;
813 InstrListMap StoreRefs;
// Tests for instructions that neither read nor write memory (guarded
// statement not visible).
816 if (!
I.mayReadOrWriteMemory())
// Load case.
819 if (
LoadInst *LI = dyn_cast<LoadInst>(&
I)) {
// Ty is the loaded type; TySize is its size in bits.
827 Type *Ty = LI->getType();
833 unsigned TySize =
DL.getTypeSizeInBits(Ty);
// Tests for sizes that are not a whole number of bytes (guarded statement
// not visible).
834 if ((TySize % 8) != 0)
844 Value *Ptr = LI->getPointerOperand();
// How many TySize-bit elements fit in VecRegSize (integer division).
848 unsigned VF = VecRegSize / TySize;
// Start of a condition (remainder not visible): the element is wider than
// half of VecRegSize.
852 if (TySize > VecRegSize / 2 ||
// The load is appended to the list stored under key ID.
858 LoadRefs[
ID].push_back(LI);
// Store case: same visible steps, using the type of the stored value.
867 Type *Ty =
SI->getValueOperand()->getType();
880 unsigned TySize =
DL.getTypeSizeInBits(Ty);
881 if ((TySize % 8) != 0)
884 Value *Ptr =
SI->getPointerOperand();
888 unsigned VF = VecRegSize / TySize;
892 if (TySize > VecRegSize / 2 ||
898 StoreRefs[
ID].push_back(
SI);
// Loads first, stores second.
902 return {LoadRefs, StoreRefs};
// Walks every (ChainID, instruction list) entry of Map and processes each
// list in pieces, OR-ing the per-piece results into Changed. The return
// statement is not visible in this excerpt.
905 bool Vectorizer::vectorizeChains(InstrListMap &Map) {
906 bool Changed =
false;
908 for (
const std::pair<ChainID, InstrList> &Chain : Map) {
909 unsigned Size = Chain.second.size();
913 LLVM_DEBUG(
dbgs() <<
"LSV: Analyzing a chain of length " << Size <<
".\n");
// Step through the list 64 entries at a time; Len is 64 except possibly for
// the final piece, which holds whatever remains.
916 for (
unsigned CI = 0, CE = Size; CI <
CE; CI += 64) {
917 unsigned Len = std::min<unsigned>(CE - CI, 64);
// Chunk is built on a line not visible here (presumably from CI and Len).
919 Changed |= vectorizeInstructions(Chunk);
// Tail of a debug statement whose opening line is missing from this excerpt.
928 <<
" instructions.\n");
// ConsecutiveChain[i] holds the index chosen to follow instruction i, or -1
// when none has been chosen. The fixed size 64 matches the step of 64 used
// by vectorizeChains when it slices its lists.
930 int ConsecutiveChain[64];
934 for (
int i = 0,
e = Instrs.
size();
i <
e; ++
i) {
935 ConsecutiveChain[
i] = -1;
// Candidates j are visited from the last index down to 0.
936 for (
int j =
e - 1;
j >= 0; --
j) {
// A candidate already exists for i: compare distances before replacing it
// (the statement guarded by the inner test is not visible).
// NOTE(review): NewDistance is measured from the current candidate to j, not
// from i to j — confirm this is intended.
941 if (ConsecutiveChain[
i] != -1) {
942 int CurDistance =
std::abs(ConsecutiveChain[
i] -
i);
943 int NewDistance =
std::abs(ConsecutiveChain[
i] -
j);
944 if (j < i || NewDistance > CurDistance)
950 ConsecutiveChain[
i] =
j;
955 bool Changed =
false;
958 for (
int Head : Heads) {
// Tests whether this head's instruction has already been processed (guarded
// statement not visible).
959 if (InstructionsProcessed.
count(Instrs[Head]))
// LongerChainExists becomes true when Head equals some Tails[TIt] whose
// paired Heads[TIt] instruction has not been processed.
961 bool LongerChainExists =
false;
962 for (
unsigned TIt = 0; TIt < Tails.size(); TIt++)
963 if (Head == Tails[TIt] &&
964 !InstructionsProcessed.
count(Instrs[Heads[TIt]])) {
965 LongerChainExists =
true;
968 if (LongerChainExists)
// Follow the ConsecutiveChain links from I; an already-processed instruction
// is tested for along the way (guarded statement not visible).
976 if (InstructionsProcessed.
count(Instrs[
I]))
980 I = ConsecutiveChain[
I];
// Dispatch on the first collected operand: loads go to vectorizeLoadChain;
// the other visible call is vectorizeStoreChain.
983 bool Vectorized =
false;
984 if (isa<LoadInst>(*
Operands.begin()))
985 Vectorized = vectorizeLoadChain(
Operands, &InstructionsProcessed);
987 Vectorized = vectorizeStoreChain(
Operands, &InstructionsProcessed);
989 Changed |= Vectorized;
// Vectorizes a chain of stores. Only the first line of the signature and a
// subset of the body are visible in this excerpt.
995 bool Vectorizer::vectorizeStoreChain(
// cast<> rather than dyn_cast<>: the first chain element is taken to be a
// store.
998 StoreInst *S0 = cast<StoreInst>(Chain[0]);
// StoreTy is taken from the stored value of a chain element I (the loop and
// any selection logic around this assignment are not visible).
1001 Type *StoreTy =
nullptr;
1003 StoreTy = cast<StoreInst>(
I)->getValueOperand()->getType();
1009 DL.getTypeSizeInBits(StoreTy));
1013 assert(StoreTy &&
"Failed to find store type");
// Sz: size of StoreTy in bits. VF: how many such elements fit in VecRegSize.
1015 unsigned Sz =
DL.getTypeSizeInBits(StoreTy);
1018 unsigned VF = VecRegSize / Sz;
1019 unsigned ChainSize = Chain.size();
// Recurring step: every element of Chain is added to *InstructionsProcessed.
1023 InstructionsProcessed->
insert(Chain.begin(), Chain.end());
// NewChain is computed on a line not visible here; an empty result marks the
// chain as processed, and a single-element result is tested separately.
1028 if (NewChain.
empty()) {
1030 InstructionsProcessed->
insert(Chain.begin(), Chain.end());
1033 if (NewChain.
size() == 1) {
// ChainSize is refreshed from Chain (presumably after Chain was reassigned on
// a line not visible — TODO confirm).
1041 ChainSize = Chain.
size();
// Byte sizes: per element, and for the whole chain.
1045 unsigned EltSzInBytes = Sz / 8;
1046 unsigned SzInBytes = EltSzInBytes * ChainSize;
// FixedVectorType view of StoreTy (dyn_cast); its uses below are presumably
// guarded on lines not visible here.
1049 auto *VecStoreTy = dyn_cast<FixedVectorType>(StoreTy);
1052 Chain.size() * VecStoreTy->getNumElements());
// Chain longer than VF, or TargetVF differs from VF and is below the chain
// length: recurse on the first TargetVF elements and on the remainder.
1059 if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
1060 LLVM_DEBUG(
dbgs() <<
"LSV: Chain doesn't match with the vector factor."
1061 " Creating two separate arrays.\n");
1062 bool Vectorized =
false;
1064 vectorizeStoreChain(Chain.slice(0, TargetVF), InstructionsProcessed);
1066 vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed);
// Debug listing of the stores being handled.
1071 dbgs() <<
"LSV: Stores to vectorize:\n";
1073 dbgs() <<
" " << *
I <<
"\n";
1078 InstructionsProcessed->
insert(Chain.begin(), Chain.end());
// When accessIsMisaligned() reports true for this size, address space and
// alignment, the chain is split with splitOddVectorElts and both parts are
// retried (an intervening line is not visible).
1081 if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
1083 auto Chains = splitOddVectorElts(Chain, Sz);
1084 bool Vectorized =
false;
1085 Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed);
1086 Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed);
// Tail of a call taking DL, S0, nullptr and &DT; NewAlign is then compared
// with Alignment (presumably NewAlign is that call's result — TODO confirm).
1092 DL, S0,
nullptr, &DT);
1093 if (NewAlign >= Alignment)
// Second split-and-retry site with the same shape as the one above.
1100 auto Chains = splitOddVectorElts(Chain, Sz);
1101 bool Vectorized =
false;
1102 Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed);
1103 Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed);
// Boundaries of the chain; the builder is positioned at Last (the load
// variant positions it at First).
1108 std::tie(First, Last) = getBoundaryInstrs(Chain);
1109 Builder.SetInsertPoint(&*Last);
// Vector-typed stores: element J of chain entry I is inserted at flat index
// J + I * VecWidth of the combined vector.
1114 unsigned VecWidth = VecStoreTy->getNumElements();
1115 for (
unsigned I = 0,
E = Chain.size();
I !=
E; ++
I) {
1117 for (
unsigned J = 0,
NE = VecStoreTy->getNumElements(); J !=
NE; ++J) {
1118 unsigned NewIdx = J +
I * VecWidth;
1125 Builder.CreateInsertElement(Vec, Extract,
Builder.getInt32(NewIdx));
// Second visible loop over the chain (body not visible).
1130 for (
unsigned I = 0,
E = Chain.size();
I !=
E; ++
I) {
// Erase the chain's instructions and update the statistics: one vector access
// generated, Chain.size() scalar accesses vectorized.
1149 eraseInstructions(Chain);
1150 ++NumVectorInstructions;
1151 NumScalarsVectorized += Chain.size();
// Vectorizes a chain of loads. Only the first line of the signature and a
// subset of the body are visible in this excerpt.
1155 bool Vectorizer::vectorizeLoadChain(
// cast<> rather than dyn_cast<>: the first chain element is taken to be a
// load.
1158 LoadInst *L0 = cast<LoadInst>(Chain[0]);
// LoadTy is taken from the chain's loads (any selection logic inside the loop
// is not visible).
1161 Type *LoadTy =
nullptr;
1162 for (
const auto &V : Chain) {
1163 LoadTy = cast<LoadInst>(V)->getType();
1169 DL.getTypeSizeInBits(LoadTy));
1173 assert(LoadTy &&
"Can't determine LoadInst type from chain");
// Sz: size of LoadTy in bits. VF: how many such elements fit in VecRegSize.
1175 unsigned Sz =
DL.getTypeSizeInBits(LoadTy);
1178 unsigned VF = VecRegSize / Sz;
1179 unsigned ChainSize = Chain.size();
// Recurring step: every element of Chain is added to *InstructionsProcessed.
1183 InstructionsProcessed->
insert(Chain.begin(), Chain.end());
// NewChain is computed on a line not visible here; an empty result marks the
// chain as processed, and a single-element result is tested separately.
1188 if (NewChain.
empty()) {
1190 InstructionsProcessed->
insert(Chain.begin(), Chain.end());
1193 if (NewChain.
size() == 1) {
// ChainSize is refreshed from Chain (presumably after Chain was reassigned on
// a line not visible — TODO confirm).
1201 ChainSize = Chain.
size();
// Byte sizes: per element, and for the whole chain.
1205 unsigned EltSzInBytes = Sz / 8;
1206 unsigned SzInBytes = EltSzInBytes * ChainSize;
// FixedVectorType view of LoadTy (dyn_cast); its uses below are presumably
// guarded on lines not visible here.
1208 auto *VecLoadTy = dyn_cast<FixedVectorType>(LoadTy);
1211 Chain.size() * VecLoadTy->getNumElements());
// Chain longer than VF, or TargetVF differs from VF and is below the chain
// length: recurse on the first TargetVF elements and on the remainder.
1218 if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
1219 LLVM_DEBUG(
dbgs() <<
"LSV: Chain doesn't match with the vector factor."
1220 " Creating two separate arrays.\n");
1221 bool Vectorized =
false;
1223 vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed);
1225 vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed);
1231 InstructionsProcessed->
insert(Chain.begin(), Chain.end());
// When accessIsMisaligned() reports true for this size, address space and
// alignment, the chain is split with splitOddVectorElts and both parts are
// retried (an intervening line is not visible).
1234 if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
1236 auto Chains = splitOddVectorElts(Chain, Sz);
1237 bool Vectorized =
false;
1238 Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed);
1239 Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed);
// Tail of a call taking DL, L0, nullptr and &DT; NewAlign is then compared
// with Alignment (presumably NewAlign is that call's result — TODO confirm).
1245 DL, L0,
nullptr, &DT);
1246 if (NewAlign >= Alignment)
// Second split-and-retry site with the same shape as the one above.
1253 auto Chains = splitOddVectorElts(Chain, Sz);
1254 bool Vectorized =
false;
1255 Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed);
1256 Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed);
// Start of the debug listing of the loads being handled.
1261 dbgs() <<
"LSV: Loads to vectorize:\n";
// Boundaries of the chain; the builder is positioned at First (the store
// variant positions it at Last).
1270 std::tie(First, Last) = getBoundaryInstrs(Chain);
1271 Builder.SetInsertPoint(&*First);
1279 for (
unsigned I = 0,
E = Chain.size();
I !=
E; ++
I) {
// Vector-typed loads: chain entry I is associated with the integer sequence
// built from I * VecWidth and (I + 1) * VecWidth.
1284 unsigned VecWidth = VecLoadTy->getNumElements();
1286 llvm::to_vector<8>(llvm::seq<int>(
I * VecWidth, (
I + 1) * VecWidth));
// Erase the chain's instructions and update the statistics: one vector access
// generated, Chain.size() scalar accesses vectorized.
1311 eraseInstructions(Chain);
1313 ++NumVectorInstructions;
1314 NumScalarsVectorized += Chain.size();
// Visible exit: returns true unless both Allows and Fast are true, i.e. an
// access counts as misaligned here when it is either not allowed or not fast.
// How Allows and Fast are obtained (and any earlier exits) is on lines missing
// from this excerpt; the debug text suggests they come from the target.
1318 bool Vectorizer::accessIsMisaligned(
unsigned SzInBytes,
unsigned AddressSpace,
1327 LLVM_DEBUG(
dbgs() <<
"LSV: Target said misaligned is allowed? " << Allows
1328 <<
" and fast? " << Fast <<
"\n";);
1329 return !Allows || !
Fast;