#include "llvm/IR/IntrinsicsARM.h"

using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");
static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
             cl::desc("Limit the number of loads analysed"));
    bool Exchange = false;
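    // Both operands must come straight from loads for the candidate to be
    // eligible for load widening.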
    bool HasTwoLoadInputs() const {
      return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
    }
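      // Peer through an optional sext to find the mul feeding an add operand.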
      auto GetMulOperand = [](Value *V) -> Instruction* {
        if (auto *SExt = dyn_cast<SExtInst>(V)) {
          if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
            if (I->getOpcode() == Instruction::Mul)
              return I;
        } else if (auto *I = dyn_cast<Instruction>(V)) {
          if (I->getOpcode() == Instruction::Mul)
            return I;
        }
        return nullptr;
      };
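      // Record a mul candidate, storing the narrow values that feed the
      // multiply rather than the sexts themselves.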
      auto InsertMul = [this](Instruction *I) {
        Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
        Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
        Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
      };
      for (auto *Add : Adds) {
        if (auto *Mul = GetMulOperand(Add->getOperand(0)))
          InsertMul(Mul);
        if (auto *Mul = GetMulOperand(Add->getOperand(1)))
          InsertMul(Mul);
      }
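    // Record the incoming accumulator value; returns false if a value has
    // already been inserted, signalling that the search has failed.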
    bool InsertAcc(Value *V) {
    void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                    bool Exchange = false) {
      LLVM_DEBUG(dbgs() << "Pairing:\n"
                 << *Mul0->Root << "\n"
                 << *Mul1->Root << "\n");
      Mul0->Paired = true;
      Mul1->Paired = true;
      if (Exchange)
        Mul1->Exchange = true;
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }
    bool CreateParallelPairs();
    Value *getAccumulator() { return Acc; }

    MulCandList &getMuls() { return Muls; }

    MulPairList &getMulPairs() { return MulPairs; }
      for (auto *Add : Adds)
        LLVM_DEBUG(dbgs() << *Add << "\n");
      for (auto &Mul : Muls)
        LLVM_DEBUG(dbgs() << *Mul->Root << "\n"
                   << "  " << *Mul->LHS << "\n"
                   << "  " << *Mul->RHS << "\n");
    std::map<LoadInst*, LoadInst*> LoadPairs;
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V);
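    // Gather the analyses and subtarget information the pass depends on.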
    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
    TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &TPC = getAnalysis<TargetPassConfig>();

    M = F.getParent();
    DL = &M->getDataLayout();
    if (!ST->allowsUnalignedMem()) {
      LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                           "running pass ARMParallelDSP\n");
      return false;
    }

    if (!ST->hasDSP()) {
      LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                           "ARMParallelDSP\n");
      return false;
    }

    if (!ST->isLittle()) {
      LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                        << "ARMParallelDSP\n");
      return false;
    }
    bool Changes = MatchSMLAD(F);
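// Two loads are "sequential" when RecordMemoryOps has already proved that Ld1
// directly follows Ld0 in memory; record them in program order in VecMem.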
bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
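// A "narrow sequence" is a value sign extended from MaxBitWidth bits and
// produced by a load that is part of a recorded pair, either as the base or
// as the offset load.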
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
    }
  }
  return false;
}
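  // Collect the memory-writing instructions and the simple, narrow loads
  // whose only user is a sign extend; these are the pairing candidates.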
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }
  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
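  // Record which writes precede each load; a paired load must not be hoisted
  // over one of these writes.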
  for (auto *Write : Writes) {
    for (auto *Read : Loads) {
      if (Write->comesBefore(Read))
        RAWDeps[Read].insert(Write);
    bool BaseFirst = Base->comesBefore(Offset);
    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto *Before : WritesBefore) {
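  // Try to pair each load with a sequential neighbour: Base is the lower
  // address, Offset the load that follows it.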
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset || OffsetLoads.count(Offset))
        continue;

      if (isConsecutiveAccess(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
  LLVM_DEBUG(if (LoadPairs.size() > 0) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs) {
                 LLVM_DEBUG(dbgs() << *MapIt.first << ", "
                            << *MapIt.second << "\n");
               }
             });
  return LoadPairs.size() > 1;
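// Search walks back from a reduction root through adds, muls and sexts,
// recording the adds and accepting at most one non-instruction or phi
// operand as the incoming accumulator value.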
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    break;
  case Instruction::PHI:
    // Could be the incoming accumulator value.
    return R.InsertAcc(V);
  case Instruction::Add: {
    R.InsertAdd(I);
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    // Ensure we don't add the root as the incoming accumulator.
    if (R.getRoot() == I)
      return false;

    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    return Search(I->getOperand(0), BB, R);
  }
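// The motivating pattern is a dot product of two 16-bit arrays accumulated
// into a 32- or 64-bit result, e.g.:
//
//   int32_t acc = 0;
//   for (int i = 0; i < N; i += 2)
//     acc += (int32_t)a[i] * b[i] + (int32_t)a[i + 1] * b[i + 1];
//
// The two narrow multiply-accumulates become a single smlad operating on the
// two widened 32-bit loads.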
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 1> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
    }
  }

  return Changed;
}
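// Pair up the mul candidates so that each pair reads two pairs of
// consecutive 16-bit values, which a single smlad can consume.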
bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // The reduction needs multiple muls that can be paired.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }
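  // A pair is legal when each mul's operands form sequential loads; if one
  // pair of loads is sequential only in reverse order, the pair is still
  // usable via the operand-exchanging intrinsic form.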
  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    // Check that each mul is operating on two different loads.
    if (Ld0 == Ld2 || Ld1 == Ld3)
      return false;
    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        // The second pair is sequential in reverse order, so exchange.
        R.AddMulPair(PMul0, PMul1, /*Exchange=*/true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      R.AddMulPair(PMul1, PMul0, /*Exchange=*/true);
      return true;
    }
    return false;
  };
  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    if (PMul0->Paired)
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      if (PMul1->Paired)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }

  return !R.getMulPairs().empty();
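// Lower the reduction: widen the paired loads, then chain smlad/smlald
// intrinsics (or their exchanging 'x' forms) through the accumulator.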
void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  auto CreateSMLAD = [&](LoadInst *WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    Value* Args[] = { WideLd0, WideLd1, Acc };
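  // Return whichever of A and B comes later in program order, so new
  // instructions are inserted after both of their inputs are available.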
  auto GetInsertPoint = [this](Value *A, Value *B) {
    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
           "expected at least one instruction");

    Value *V = nullptr;
    if (!isa<Instruction>(A))
      V = B;
    else if (!isa<Instruction>(B))
      V = A;
    else
      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

    return cast<Instruction>(V);
  };
  Value *Acc = R.getAccumulator();

  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
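    // Muls that could not be paired are accumulated individually, sign
    // extended first when the reduction produces a 64-bit result.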
    Instruction *Mul = cast<Instruction>(MulCand->Root);
    LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");

    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
    }

    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
    Acc = Builder.CreateAdd(Mul, Acc);
702 }
else if (Acc->
getType() !=
R.getType()) {
703 Builder.SetInsertPoint(
R.getRoot());
704 Acc =
Builder.CreateSExt(Acc,
R.getType());
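  // Visit the pairs in program order of their roots so each intrinsic is
  // created after the one feeding its accumulator.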
  llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) {
    const Instruction *A = PairA.first->Root;
    const Instruction *B = PairB.first->Root;
    return A->comesBefore(B);
  });
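  // Widen the base loads (reusing any load already widened) and replace each
  // pair of muls with a single intrinsic call.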
  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
    InsertAfter = GetInsertPoint(InsertAfter, Acc);
    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
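// Replace two consecutive narrow loads with one load of twice the width,
// then recreate the two original values as the bottom and top halves.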
LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt) &&
         "Loads should have a single, extending, user");
    if (!isa<Instruction>(A) || !isa<Instruction>(B))
      return;

    auto *Source = cast<Instruction>(A);
    auto *Sink = cast<Instruction>(B);
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);
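  // Recreate the two original narrow values: the bottom half by truncation,
  // the top half by a logical shift right followed by truncation, each then
  // sign extended as the original loads were.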
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());

  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  LLVM_DEBUG(dbgs() << "From Base and Offset:\n"
             << *Base << "\n" << *Offset << "\n"
             << "Created Wide Load:\n"
             << *NewBaseSExt << "\n"
             << *NewOffsetSExt << "\n");

  WideLoads.emplace(std::make_pair(Base,
                    std::make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}
Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

INITIALIZE_PASS(ARMParallelDSP, "arm-parallel-dsp",
                "Transform functions to use DSP intrinsics",
                false, false)