// NOTE(review): this chunk is an elided extraction — the embedded original
// line numbers jump, so statements are missing between the lines below.
// Debug name used by LLVM_DEBUG / -debug-only filtering.
39 #define DEBUG_TYPE "expandmemcmp"
// Pass statistics: counters for total memcmp calls seen, calls skipped for
// non-constant size, calls skipped for exceeding the max size, and calls
// actually inlined. (The STATISTIC for the "greater than max size" counter's
// name is elided here — only its description string survives.)
41 STATISTIC(NumMemCmpCalls,
"Number of memcmp calls");
42 STATISTIC(NumMemCmpNotConstant,
"Number of memcmp calls without constant size");
44 "Number of memcmp calls with size greater than max size");
45 STATISTIC(NumMemCmpInlined,
"Number of inlined memcmp calls");
// Command-line knobs (cl::opt declarations elided; descriptions visible):
// loads-per-block for equality-only expansion, and max-loads caps for the
// default and -Os/-Oz cases.
49 cl::desc(
"The number of loads per basic block for inline expansion of "
50 "memcmp that is only being compared against zero."));
54 cl::desc(
"Set maximum number of loads used in expanded memcmp"));
58 cl::desc(
"Set maximum number of loads used in expanded memcmp for -Os/Oz"));
// Helper class that expands a memcmp/bcmp call into an explicit sequence of
// loads and compares. NOTE(review): this declaration is a fragment — many
// members and the closing brace are elided from this view; confirm against
// the full file before editing.
65 class MemCmpExpansion {
71 ResultBlock() =
default;
// Largest load size (in bytes) usable for this expansion; 0 until computed.
77 unsigned MaxLoadSize = 0;
// How many loads each basic block may hold when the result is only compared
// against zero (equality-only expansion).
79 const uint64_t NumLoadsPerBlockForZeroCmp;
// One basic block per comparison step in the multi-block expansion.
80 std::vector<BasicBlock *> LoadCmpBlocks;
83 const bool IsUsedForZeroCmp;
// The planned sequence of (load size, offset) entries.
101 LoadEntryVector LoadSequence;
// CFG construction helpers.
103 void createLoadCmpBlocks();
104 void createResultBlock();
105 void setupResultBlockPHINodes();
106 void setupEndBlockPHINodes();
// Emits the loads/xor-or tree for one block of the equality-only expansion.
107 Value *getCompareLoadPairs(
unsigned BlockIndex,
unsigned &LoadIndex);
108 void emitLoadCompareBlock(
unsigned BlockIndex);
109 void emitLoadCompareBlockMultipleLoads(
unsigned BlockIndex,
110 unsigned &LoadIndex);
111 void emitLoadCompareByteBlock(
unsigned BlockIndex,
unsigned OffsetBytes);
112 void emitMemCmpResultBlock();
113 Value *getMemCmpExpansionZeroCase();
114 Value *getMemCmpEqZeroOneBlock();
115 Value *getMemCmpOneBlock();
// LoadPair: the two loaded values (lhs/rhs) produced by getLoadPair().
117 Value *Lhs =
nullptr;
118 Value *Rhs =
nullptr;
// Loads both sources at OffsetBytes, optionally byte-swapping and
// zero-extending to CmpSizeType (see definition below).
120 LoadPair getLoadPair(
Type *LoadSizeType,
bool NeedsBSwap,
Type *CmpSizeType,
121 unsigned OffsetBytes);
// Static planners for the load sequence (greedy vs. overlapping).
123 static LoadEntryVector
125 unsigned MaxNumLoads,
unsigned &NumLoadsNonOneByte);
126 static LoadEntryVector
127 computeOverlappingLoadSequence(
uint64_t Size,
unsigned MaxLoadSize,
128 unsigned MaxNumLoads,
129 unsigned &NumLoadsNonOneByte);
134 const bool IsUsedForZeroCmp,
const DataLayout &TheDataLayout,
137 unsigned getNumBlocks();
// Number of planned loads; 0 means the expansion was abandoned.
138 uint64_t getNumLoads()
const {
return LoadSequence.size(); }
140 Value *getMemCmpExpansion();
// computeGreedyLoadSequence (fragment — signature start and several interior
// lines are elided). Visible logic: repeatedly takes the largest remaining
// load size from the front of LoadSizes, emits as many loads of that size as
// fit in Size, and bails out if the total would exceed MaxNumLoads.
145 const unsigned MaxNumLoads,
unsigned &NumLoadsNonOneByte) {
146 NumLoadsNonOneByte = 0;
147 LoadEntryVector LoadSequence;
149 while (Size && !LoadSizes.
empty()) {
// Candidate load size for this round: the front (largest) entry.
150 const unsigned LoadSize = LoadSizes.
front();
// Abort if this round would push us past the global load budget.
152 if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
159 if (NumLoadsForThisSize > 0) {
// Append NumLoadsForThisSize entries of {LoadSize, Offset}; the Offset
// update between iterations is elided in this view.
160 for (
uint64_t I = 0;
I < NumLoadsForThisSize; ++
I) {
161 LoadSequence.push_back({LoadSize,
Offset});
// Track how many loads are wider than one byte (presumably guarded by a
// LoadSize != 1 check that is elided here — TODO confirm).
165 ++NumLoadsNonOneByte;
// computeOverlappingLoadSequence (fragment). Plans Size bytes as
// floor(Size / MaxLoadSize) back-to-back loads of MaxLoadSize plus one final
// MaxLoadSize load shifted back so it overlaps the previous one and still
// ends exactly at Size. Returns empty (elided return) when not applicable.
174 MemCmpExpansion::computeOverlappingLoadSequence(
uint64_t Size,
175 const unsigned MaxLoadSize,
176 const unsigned MaxNumLoads,
177 unsigned &NumLoadsNonOneByte) {
// Overlapping loads only make sense for sizes and load widths of >= 2.
179 if (Size < 2 || MaxLoadSize < 2)
184 const uint64_t NumNonOverlappingLoads =
Size / MaxLoadSize;
185 assert(NumNonOverlappingLoads &&
"there must be at least one load");
// Size now holds the residual tail not covered by the full-width loads.
188 Size =
Size - NumNonOverlappingLoads * MaxLoadSize;
// One extra (overlapping) load is needed; respect the load budget.
195 if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
199 LoadEntryVector LoadSequence;
// Emit the non-overlapping full-width loads (Offset advance elided).
201 for (
uint64_t I = 0;
I < NumNonOverlappingLoads; ++
I) {
202 LoadSequence.push_back({MaxLoadSize,
Offset});
// The final load is shifted left by (MaxLoadSize - Size) so it covers the
// tail while re-reading some already-compared bytes.
207 assert(Size > 0 && Size < MaxLoadSize &&
"broken invariant");
208 LoadSequence.push_back({MaxLoadSize,
Offset - (MaxLoadSize -
Size)});
209 NumLoadsNonOneByte = 1;
// Constructor (fragment — parameter list and several initializers elided).
// Plans the load sequence: first greedily, then optionally replaced by an
// overlapping-load plan when the target allows it and it uses fewer loads.
221 MemCmpExpansion::MemCmpExpansion(
224 const bool IsUsedForZeroCmp,
const DataLayout &TheDataLayout,
226 : CI(CI),
Size(
Size), NumLoadsPerBlockForZeroCmp(
Options.NumLoadsPerBlock),
227 IsUsedForZeroCmp(IsUsedForZeroCmp),
DL(TheDataLayout), DTU(DTU),
// LoadSizes (presumably from Options — elided) must be non-empty; the
// largest size is kept at the front.
235 assert(!LoadSizes.
empty() &&
"cannot load Size bytes");
236 MaxLoadSize = LoadSizes.
front();
// Plan 1: greedy sequence of decreasing load sizes.
238 unsigned GreedyNumLoadsNonOneByte = 0;
239 LoadSequence = computeGreedyLoadSequence(
Size, LoadSizes,
Options.MaxNumLoads,
240 GreedyNumLoadsNonOneByte);
241 NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
242 assert(LoadSequence.size() <=
Options.MaxNumLoads &&
"broken invariant");
// Plan 2: overlapping loads — only tried when allowed and when the greedy
// plan is empty or long enough (> 2 loads) for overlap to possibly win.
245 if (
Options.AllowOverlappingLoads &&
246 (LoadSequence.empty() || LoadSequence.size() > 2)) {
247 unsigned OverlappingNumLoadsNonOneByte = 0;
248 auto OverlappingLoads = computeOverlappingLoadSequence(
249 Size, MaxLoadSize,
Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
// Prefer the overlapping plan only when it strictly uses fewer loads.
250 if (!OverlappingLoads.empty() &&
251 (LoadSequence.empty() ||
252 OverlappingLoads.size() < LoadSequence.size())) {
253 LoadSequence = OverlappingLoads;
254 NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
257 assert(LoadSequence.size() <=
Options.MaxNumLoads &&
"broken invariant");
// Number of load/compare basic blocks needed. For the equality-only case
// loads are packed NumLoadsPerBlockForZeroCmp per block (ceiling division);
// otherwise each load gets its own block. (Closing brace elided in this view.)
260 unsigned MemCmpExpansion::getNumBlocks() {
261 if (IsUsedForZeroCmp)
262 return getNumLoads() / NumLoadsPerBlockForZeroCmp +
263 (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
264 return getNumLoads();
// Creates one basic block per comparison step (the BasicBlock::Create call
// that produces BB is elided between the loop header and the push_back).
267 void MemCmpExpansion::createLoadCmpBlocks() {
268 for (
unsigned i = 0;
i < getNumBlocks();
i++) {
271 LoadCmpBlocks.push_back(
BB);
// Creates the shared "res_block" that computes the final non-zero memcmp
// result (trailing Create() arguments and the body's remainder are elided).
275 void MemCmpExpansion::createResultBlock() {
276 ResBlock.BB = BasicBlock::Create(CI->
getContext(),
"res_block",
// getLoadPair (fragment): loads one value from each source at OffsetBytes.
// Visible steps: byte-wise GEP when OffsetBytes > 0, aligned loads (with a
// constant-folding fast path via dyn_cast<Constant>), optional bswap of both
// sides, and optional zext to CmpSizeType when it differs from LoadSizeType.
280 MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(
Type *LoadSizeType,
283 unsigned OffsetBytes) {
// Offset both source pointers by OffsetBytes via an i8* GEP.
289 if (OffsetBytes > 0) {
290 auto *ByteType = Type::getInt8Ty(CI->
getContext());
291 LhsSource =
Builder.CreateConstGEP1_64(
292 ByteType,
Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()),
294 RhsSource =
Builder.CreateConstGEP1_64(
295 ByteType,
Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()),
// Load LHS; a constant source presumably takes a folding path here
// (the branch body between the dyn_cast and line 308 is elided).
304 Value *Lhs =
nullptr;
305 if (
auto *
C = dyn_cast<Constant>(LhsSource))
308 Lhs =
Builder.CreateAlignedLoad(LoadSizeType, LhsSource, LhsAlign);
// Same for RHS.
310 Value *Rhs =
nullptr;
311 if (
auto *
C = dyn_cast<Constant>(RhsSource))
314 Rhs =
Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);
// Byte-swap both values so multi-byte compares match memcmp's big-endian
// byte order (the guarding NeedsBSwap condition is elided in this view).
319 Intrinsic::bswap, LoadSizeType);
320 Lhs =
Builder.CreateCall(Bswap, Lhs);
321 Rhs =
Builder.CreateCall(Bswap, Rhs);
// Widen to the comparison type when requested.
325 if (CmpSizeType !=
nullptr && CmpSizeType != LoadSizeType) {
326 Lhs =
Builder.CreateZExt(Lhs, CmpSizeType);
327 Rhs =
Builder.CreateZExt(Rhs, CmpSizeType);
// Emits a single-byte compare block: loads an i8 pair (zero-extended to i32
// for the subtraction result — the sub itself is elided), then branches to
// the next compare block on equality or falls through to EndBlock.
336 void MemCmpExpansion::emitLoadCompareByteBlock(
unsigned BlockIndex,
337 unsigned OffsetBytes) {
340 const LoadPair Loads =
341 getLoadPair(Type::getInt8Ty(CI->
getContext()),
false,
342 Type::getInt32Ty(CI->
getContext()), OffsetBytes);
// Non-final block: conditional branch — equal bytes continue to the next
// compare block, a difference exits to EndBlock with the byte difference.
347 if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
353 BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
// Final block: unconditional branch to EndBlock.
361 BranchInst *CmpBr = BranchInst::Create(EndBlock);
// Equality-only comparison for one block: XORs up to NumLoadsPerBlockForZeroCmp
// load pairs, then OR-reduces the XOR results pairwise into a single value
// (a direct ICmpNE is used in the single-load case). The returned comparison
// value itself is elided at the end of this fragment.
371 Value *MemCmpExpansion::getCompareLoadPairs(
unsigned BlockIndex,
372 unsigned &LoadIndex) {
373 assert(LoadIndex < getNumLoads() &&
374 "getCompareLoadPairs() called with no remaining loads");
375 std::vector<Value *> XorList, OrList;
376 Value *Diff =
nullptr;
// How many of the remaining planned loads fit into this block.
378 const unsigned NumLoads =
379 std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);
// Single-block expansion has no LoadCmpBlocks; otherwise insert into ours.
382 if (LoadCmpBlocks.empty())
385 Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
// With one load the XOR/OR tree is skipped (ternary's other arm elided).
392 NumLoads == 1 ? nullptr
// Emit each load pair, consuming entries from LoadSequence.
394 for (
unsigned i = 0;
i < NumLoads; ++
i, ++LoadIndex) {
395 const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
396 const LoadPair Loads = getLoadPair(
398 false, MaxLoadType, CurLoadEntry.Offset);
// Multi-load path: record zext(xor(lhs, rhs)) for the OR-reduction.
403 Diff =
Builder.CreateXor(Loads.Lhs, Loads.Rhs);
404 Diff =
Builder.CreateZExt(Diff, MaxLoadType);
405 XorList.push_back(Diff);
// Single-load path: compare the pair directly.
408 Cmp =
Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs);
// Pairwise-OR one level of the reduction tree; an odd trailing element is
// carried over unchanged.
412 auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
413 std::vector<Value *> OutList;
414 for (
unsigned i = 0;
i < InList.size() - 1;
i =
i + 2) {
416 OutList.push_back(Or);
418 if (InList.size() % 2 != 0)
419 OutList.push_back(InList.back());
// Reduce until a single OR value remains.
425 OrList = pairWiseOr(XorList);
428 while (OrList.size() != 1) {
429 OrList = pairWiseOr(OrList);
432 assert(Diff &&
"Failed to find comparison diff");
// Emits one equality-only block: builds the XOR/OR comparison via
// getCompareLoadPairs, then branches to res_block on inequality or to the
// next compare block (EndBlock-bound arm elided for the last block).
439 void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(
unsigned BlockIndex,
440 unsigned &LoadIndex) {
441 Value *
Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);
443 BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
445 : LoadCmpBlocks[BlockIndex + 1];
449 BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
// The last block feeds 0 ("equal") into the end-block result PHI.
458 if (BlockIndex == LoadCmpBlocks.size() - 1) {
460 PhiRes->
addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
// Emits one general (ordering-aware) compare block. One-byte entries are
// delegated to emitLoadCompareByteBlock; wider entries load with a bswap on
// little-endian targets, feed the loaded values into res_block's PHIs, and
// branch to res_block on inequality.
473 void MemCmpExpansion::emitLoadCompareBlock(
unsigned BlockIndex) {
475 const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];
// Byte-sized loads take the dedicated single-byte path.
477 if (CurLoadEntry.LoadSize == 1) {
478 MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
485 assert(CurLoadEntry.LoadSize <= MaxLoadSize &&
"Unexpected load type");
487 Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
// bswap on little-endian so a wide unsigned compare matches memcmp order.
489 const LoadPair Loads =
490 getLoadPair(LoadSizeType,
DL.isLittleEndian(), MaxLoadType,
491 CurLoadEntry.Offset);
// res_block needs the differing values to compute the ordered result.
495 if (!IsUsedForZeroCmp) {
496 ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]);
497 ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]);
// Equal → next compare block (or EndBlock arm, elided); unequal → res_block.
500 Value *
Cmp =
Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs);
501 BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
503 : LoadCmpBlocks[BlockIndex + 1];
507 BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
// The last block feeds 0 ("equal") into the end-block result PHI.
516 if (BlockIndex == LoadCmpBlocks.size() - 1) {
518 PhiRes->
addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
// Populates res_block and branches to EndBlock. The actual result
// computation (the zero-cmp constant and the ordered-result select/sub on
// the PHI values) is elided between the SetInsertPoint and branch lines.
525 void MemCmpExpansion::emitMemCmpResultBlock() {
528 if (IsUsedForZeroCmp) {
530 Builder.SetInsertPoint(ResBlock.BB, InsertPt);
533 BranchInst *NewBr = BranchInst::Create(EndBlock);
540 Builder.SetInsertPoint(ResBlock.BB, InsertPt);
550 BranchInst *NewBr = BranchInst::Create(EndBlock);
// Creates the two res_block PHIs ("phi.src1"/"phi.src2") that collect the
// differing loaded values from each multi-byte compare block; they get one
// incoming edge per non-one-byte load. setupEndBlockPHINodes' body is elided.
556 void MemCmpExpansion::setupResultBlockPHINodes() {
558 Builder.SetInsertPoint(ResBlock.BB);
561 Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte,
"phi.src1");
563 Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte,
"phi.src2");
566 void MemCmpExpansion::setupEndBlockPHINodes() {
// Multi-block equality-only expansion: emit one multiple-loads compare block
// per planned block, then the result block. The final returned Value
// (presumably the end-block PHI) is elided in this view.
571 Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
572 unsigned LoadIndex = 0;
575 for (
unsigned I = 0;
I < getNumBlocks(); ++
I) {
576 emitLoadCompareBlockMultipleLoads(
I, LoadIndex);
579 emitMemCmpResultBlock();
// Single-block equality-only expansion: one getCompareLoadPairs call must
// consume every planned load; the zext of Cmp to the return type is elided.
586 Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
587 unsigned LoadIndex = 0;
588 Value *
Cmp = getCompareLoadPairs(0, LoadIndex);
589 assert(LoadIndex == getNumLoads() &&
"some entries were not consumed");
// Single-load ordered expansion. bswap is needed on little-endian targets
// for multi-byte loads so unsigned compares match memcmp byte order. A small
// size (guard elided) returns lhs - rhs after zext to i32; otherwise the
// result is zext(lhs > rhs) - zext(lhs < rhs), i.e. -1/0/+1.
595 Value *MemCmpExpansion::getMemCmpOneBlock() {
597 bool NeedsBSwap =
DL.isLittleEndian() &&
Size != 1;
// Narrow path: difference of the i32-extended values is a valid result.
602 const LoadPair Loads =
603 getLoadPair(LoadSizeType, NeedsBSwap,
Builder.getInt32Ty(),
605 return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
// Wide path: compute the sign via two unsigned compares.
608 const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
616 Value *CmpUGT =
Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs);
617 Value *CmpULT =
Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs);
620 return Builder.CreateSub(ZextUGT, ZextULT);
// Top-level expansion driver. Multi-block plans split the block at the call
// site into StartBlock/EndBlock, build the PHIs and compare blocks, and
// notify the DomTreeUpdater; then dispatch to the zero-cmp, one-block, or
// general multi-block emitters.
625 Value *MemCmpExpansion::getMemCmpExpansion() {
// Multi-block case: carve out an "endblock" after the memcmp call.
627 if (getNumBlocks() != 1) {
629 EndBlock =
SplitBlock(StartBlock, CI, DTU,
nullptr,
630 nullptr,
"endblock");
631 setupEndBlockPHINodes();
// Ordered comparisons need the res_block PHIs; equality-only does not.
638 if (!IsUsedForZeroCmp) setupResultBlockPHINodes();
641 createLoadCmpBlocks();
// DomTree update: the direct StartBlock→EndBlock edge no longer exists
// (the surrounding applyUpdates call is elided in this view).
648 {DominatorTree::Delete, StartBlock, EndBlock}});
// Dispatch to the appropriate emitter.
653 if (IsUsedForZeroCmp)
654 return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
655 : getMemCmpExpansionZeroCase();
657 if (getNumBlocks() == 1)
658 return getMemCmpOneBlock();
660 for (
unsigned I = 0;
I < getNumBlocks(); ++
I) {
661 emitLoadCompareBlock(
I);
664 emitMemCmpResultBlock();
// expandMemCmp driver (fragment — the function header and most guards are
// elided). Visible: bump the not-constant-size statistic, construct a
// MemCmpExpansion, bail (counting NumMemCmpGreaterThanMax) when no loads
// were planned, and otherwise materialize the expansion.
754 NumMemCmpNotConstant++;
764 const bool IsUsedForZeroCmp =
782 MemCmpExpansion Expansion(CI, SizeVal,
Options, IsUsedForZeroCmp, *
DL, DTU);
// Zero planned loads means the size exceeded the allowed expansion budget.
785 if (Expansion.getNumLoads() == 0) {
786 NumMemCmpGreaterThanMax++;
792 Value *Res = Expansion.getMemCmpExpansion();
// Legacy-pass analysis gathering (fragment): fetches TargetPassConfig, the
// target lowering for F, TLI, TTI, PSI, an optional lazy BFI, and an
// optional DominatorTree, then reports change as !PA.areAllPreserved().
812 auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
817 TPC->getTM<
TargetMachine>().getSubtargetImpl(
F)->getTargetLowering();
820 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
822 &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
823 auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
// BFI is only computed lazily (condition and else-arm elided).
825 &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() :
// DominatorTree is optional; updates are routed through a DTU when present.
828 if (
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
829 DT = &DTWP->getDomTree();
831 return !PA.areAllPreserved();
// runOnBlock/runImpl fragments plus pass boilerplate. Visible: a call-site
// match on memcmp/bcmp dispatched to expandMemCmp (bcmp flagged via
// Func == LibFunc_bcmp); runImpl lazily creates a DomTreeUpdater when a
// DominatorTree exists and iterates the function's blocks, re-fetching the
// iterator so block splits during expansion are handled; finally the
// INITIALIZE_PASS tail and the pass factory function.
868 (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
869 expandMemCmp(CI,
TTI, TL, &
DL, PSI,
BFI, DTU, Func == LibFunc_bcmp)) {
// Lazy DTU: only constructed when a DominatorTree is available.
881 std::optional<DomTreeUpdater> DTU;
883 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
886 bool MadeChanges =
false;
887 for (
auto BBIt =
F.begin(); BBIt !=
F.end();) {
888 if (runOnBlock(*BBIt, TLI,
TTI, TL,
DL, PSI,
BFI, DTU ? &*DTU :
nullptr)) {
// Legacy pass registration tail (INITIALIZE_PASS macro arguments).
911 "Expand memcmp() to load/stores",
false,
false)
// Factory for the legacy pass object.
921 return new ExpandMemCmpPass();