LLVM  10.0.0svn
ARMParallelDSP.cpp
Go to the documentation of this file.
1 //===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Armv6 introduced instructions to perform 32-bit SIMD operations. The
11 /// purpose of this pass is do some IR pattern matching to create ACLE
12 /// DSP intrinsics, which map on these 32-bit SIMD operations.
13 /// This pass runs only when unaligned accesses are supported/enabled.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Analysis/LoopPass.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/NoFolder.h"
25 #include "llvm/Transforms/Scalar.h"
28 #include "llvm/Pass.h"
29 #include "llvm/PassRegistry.h"
30 #include "llvm/PassSupport.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/IR/PatternMatch.h"
34 #include "ARM.h"
35 #include "ARMSubtarget.h"
36 
37 using namespace llvm;
38 using namespace PatternMatch;
39 
40 #define DEBUG_TYPE "arm-parallel-dsp"
41 
42 STATISTIC(NumSMLAD , "Number of smlad instructions generated");
43 
44 static cl::opt<bool>
45 DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
46  cl::desc("Disable the ARM Parallel DSP pass"));
47 
48 namespace {
49  struct OpChain;
50  struct BinOpChain;
51  class Reduction;
52 
53  using OpChainList = SmallVector<std::unique_ptr<OpChain>, 8>;
54  using ReductionList = SmallVector<Reduction, 8>;
55  using ValueList = SmallVector<Value*, 8>;
56  using MemInstList = SmallVector<LoadInst*, 8>;
57  using PMACPair = std::pair<BinOpChain*,BinOpChain*>;
58  using PMACPairList = SmallVector<PMACPair, 8>;
59  using Instructions = SmallVector<Instruction*,16>;
60  using MemLocList = SmallVector<MemoryLocation, 4>;
61 
// Base class describing a chain of operations rooted at a single
// instruction. AllValues holds every value collected while walking the
// chain; Loads is the subset of those values that are LoadInsts.
62  struct OpChain {
63  Instruction *Root;
64  ValueList AllValues;
65  MemInstList VecLd; // List of all load instructions.
66  MemInstList Loads;
// NOTE(review): ReadOnly is initialised but never written in the visible
// code of this file — presumably set/queried elsewhere; confirm.
67  bool ReadOnly = true;
68 
69  OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
70  virtual ~OpChain() = default;
71 
// Scan AllValues and collect every load into Loads.
72  void PopulateLoads() {
73  for (auto *V : AllValues) {
74  if (auto *Ld = dyn_cast<LoadInst>(V))
75  Loads.push_back(Ld);
76  }
77  }
78 
// Number of values collected for this chain.
79  unsigned size() const { return AllValues.size(); }
80  };
81 
82  // 'BinOpChain' holds the multiplication instructions that are candidates
83  // for parallel execution.
84  struct BinOpChain : public OpChain {
85  ValueList LHS; // List of all (narrow) left hand operands.
86  ValueList RHS; // List of all (narrow) right hand operands.
// When true, the two halfwords of the second operand are swapped before
// the multiply (maps onto the SMLADX/SMLALDX forms).
87  bool Exchange = false;
88 
// The base is constructed with the LHS values; the RHS values are then
// appended so AllValues covers both operand lists.
89  BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
90  OpChain(I, lhs), LHS(lhs), RHS(rhs) {
91  for (auto *V : RHS)
92  AllValues.push_back(V);
93  }
94 
// Returns true when Other's operand lists mirror this chain's
// (defined near the end of this file).
95  bool AreSymmetrical(BinOpChain *Other);
96  };
97 
98  /// Represent a sequence of multiply-accumulate operations with the aim to
99  /// perform the multiplications in parallel.
100  class Reduction {
101  Instruction *Root = nullptr;
102  Value *Acc = nullptr;
103  OpChainList Muls;
104  PMACPairList MulPairs;
// NOTE(review): the declaration of the 'Adds' member (used by InsertAdd
// and getAdds below, a SmallPtrSet of Instruction*) appears to have been
// lost in this extraction — restore it against the upstream source.
106 
107  public:
108  Reduction() = delete;
109 
110  Reduction (Instruction *Add) : Root(Add) { }
111 
112  /// Record an Add instruction that is a part of the this reduction.
113  void InsertAdd(Instruction *I) { Adds.insert(I); }
114 
115  /// Record a BinOpChain, rooted at a Mul instruction, that is a part of
116  /// this reduction.
117  void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) {
118  Muls.push_back(make_unique<BinOpChain>(I, LHS, RHS));
119  }
120 
121  /// Add the incoming accumulator value, returns true if a value had not
122  /// already been added. Returning false signals to the user that this
123  /// reduction already has a value to initialise the accumulator.
124  bool InsertAcc(Value *V) {
125  if (Acc)
126  return false;
127  Acc = V;
128  return true;
129  }
130 
131  /// Set two BinOpChains, rooted at muls, that can be executed as a single
132  /// parallel operation.
133  void AddMulPair(BinOpChain *Mul0, BinOpChain *Mul1) {
134  MulPairs.push_back(std::make_pair(Mul0, Mul1));
135  }
136 
137  /// Return true if enough mul operations are found that can be executed in
138  /// parallel.
139  bool CreateParallelPairs();
140 
141  /// Return the add instruction which is the root of the reduction.
142  Instruction *getRoot() { return Root; }
143 
144  /// Return the incoming value to be accumulated. This maybe null.
145  Value *getAccumulator() { return Acc; }
146 
147  /// Return the set of adds that comprise the reduction.
148  SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }
149 
150  /// Return the BinOpChain, rooted at mul instruction, that comprise the
151  /// the reduction.
152  OpChainList &getMuls() { return Muls; }
153 
154  /// Return the BinOpChain, rooted at mul instructions, that have been
155  /// paired for parallel execution.
156  PMACPairList &getMulPairs() { return MulPairs; }
157 
158  /// To finalise, replace the uses of the root with the intrinsic call.
159  void UpdateRoot(Instruction *SMLAD) {
160  Root->replaceAllUsesWith(SMLAD);
161  }
162  };
163 
// Records a widened load (NewLd) together with the original narrow loads
// it replaces, so later queries can map a narrow load to its wide form.
164  class WidenedLoad {
165  LoadInst *NewLd = nullptr;
// NOTE(review): the declaration of the 'Loads' member (the container
// filled in the constructor below) was lost in this extraction — restore
// it against the upstream source.
167 
168  public:
169  WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
170  : NewLd(Wide) {
171  for (auto *I : Lds)
172  Loads.push_back(I);
173  }
// Return the wide load that replaced the recorded narrow loads.
174  LoadInst *getLoad() {
175  return NewLd;
176  }
177  };
178 
// Loop pass that pattern-matches sequential 16-bit loads feeding
// sign-extended multiply-accumulate chains and rewrites them as ARM DSP
// intrinsic calls (SMLAD and friends). Runs per-loop; see runOnLoop.
179  class ARMParallelDSP : public LoopPass {
180  ScalarEvolution *SE;
181  AliasAnalysis *AA;
182  TargetLibraryInfo *TLI;
183  DominatorTree *DT;
184  LoopInfo *LI;
185  Loop *L;
186  const DataLayout *DL;
187  Module *M;
// Maps a base load to the load that accesses the adjacent (offset)
// location; populated by RecordMemoryOps.
188  std::map<LoadInst*, LoadInst*> LoadPairs;
// The set of loads that act as the 'offset' half of some pair.
189  SmallPtrSet<LoadInst*, 4> OffsetLoads;
// Maps a narrow base load to the WidenedLoad that replaced it.
190  std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;
191 
192  template<unsigned>
193  bool IsNarrowSequence(Value *V, ValueList &VL);
194 
195  bool RecordMemoryOps(BasicBlock *BB);
196  void InsertParallelMACs(Reduction &Reduction);
197  bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
198  LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
199  IntegerType *LoadTy);
200  bool CreateParallelPairs(Reduction &R);
201 
202  /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
203  /// Dual performs two signed 16x16-bit multiplications. It adds the
204  /// products to a 32-bit accumulate operand. Optionally, the instruction can
205  /// exchange the halfwords of the second operand before performing the
206  /// arithmetic.
207  bool MatchSMLAD(Loop *L);
208 
209  public:
210  static char ID;
211 
212  ARMParallelDSP() : LoopPass(ID) { }
213 
// Reset per-function state so pairs recorded for one loop nest do not
// leak into the next.
214  bool doInitialization(Loop *L, LPPassManager &LPM) override {
215  LoadPairs.clear();
216  WideLoads.clear();
217  return true;
218  }
219 
220  void getAnalysisUsage(AnalysisUsage &AU) const override {
// NOTE(review): the AU.addRequired<>/addPreserved<> calls that belong
// here (for the analyses fetched in runOnLoop below) were lost in this
// extraction — restore them against the upstream source.
230  AU.setPreservesCFG();
231  }
232 
// Gate on target support (unaligned memory, DSP extension, little
// endian), require header == latch, then record load pairs and run the
// SMLAD matcher. Returns true iff the IR was changed.
233  bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
234  if (DisableParallelDSP)
235  return false;
236  L = TheLoop;
237  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
238  AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
239  TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
240  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
241  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
242  auto &TPC = getAnalysis<TargetPassConfig>();
243 
244  BasicBlock *Header = TheLoop->getHeader();
245  if (!Header)
246  return false;
247 
248  // TODO: We assume the loop header and latch to be the same block.
249  // This is not a fundamental restriction, but lifting this would just
250  // require more work to do the transformation and then patch up the CFG.
251  if (Header != TheLoop->getLoopLatch()) {
252  LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
253  "running pass ARMParallelDSP\n");
254  return false;
255  }
256 
// A preheader is needed so new code can be hoisted before the loop.
257  if (!TheLoop->getLoopPreheader())
258  InsertPreheaderForLoop(L, DT, LI, nullptr, true);
259 
260  Function &F = *Header->getParent();
261  M = F.getParent();
262  DL = &M->getDataLayout();
263 
264  auto &TM = TPC.getTM<TargetMachine>();
265  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
266 
// The widened loads may be unaligned, so require unaligned access.
267  if (!ST->allowsUnalignedMem()) {
268  LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
269  "running pass ARMParallelDSP\n");
270  return false;
271  }
272 
273  if (!ST->hasDSP()) {
274  LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
275  "ARMParallelDSP\n");
276  return false;
277  }
278 
// The lshr/trunc halfword extraction below assumes little endian.
279  if (!ST->isLittle()) {
280  LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
281  << "ARMParallelDSP\n");
282  return false;
283  }
284 
285  LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
286 
287  LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
288  LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
289 
290  if (!RecordMemoryOps(Header)) {
291  LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
292  return false;
293  }
294 
295  bool Changes = MatchSMLAD(L);
296  return Changes;
297  }
298  };
299 }
300 
301 template<typename MemInst>
302 static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
303  const DataLayout &DL, ScalarEvolution &SE) {
304  if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
305  return true;
306  return false;
307 }
308 
309 bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
310  MemInstList &VecMem) {
311  if (!Ld0 || !Ld1)
312  return false;
313 
314  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
315  return false;
316 
317  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
318  dbgs() << "Ld0:"; Ld0->dump();
319  dbgs() << "Ld1:"; Ld1->dump();
320  );
321 
322  VecMem.clear();
323  VecMem.push_back(Ld0);
324  VecMem.push_back(Ld1);
325  return true;
326 }
327 
328 // MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
329 // instructions, which is set to 16. So here we should collect all i8 and i16
330 // narrow operations.
331 // TODO: we currently only collect i16, and will support i8 later, so that's
332 // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
// Recursively walk back through V looking for a narrow (MaxBitWidth-wide)
// value that is ultimately a sign/zero-extended, pairable load. On success
// the load and its extend are appended to VL and true is returned.
333 template<unsigned MaxBitWidth>
334 bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) {
335  ConstantInt *CInt;
336 
337  if (match(V, m_ConstantInt(CInt))) {
338  // TODO: if a constant is used, it needs to fit within the bit width.
339  return false;
340  }
341 
342  auto *I = dyn_cast<Instruction>(V);
343  if (!I)
344  return false;
345 
346  Value *Val, *LHS, *RHS;
// A trunc down to exactly MaxBitWidth is looked through; recurse on its
// operand.
347  if (match(V, m_Trunc(m_Value(Val)))) {
348  if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
349  return IsNarrowSequence<MaxBitWidth>(Val, VL);
350  } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
351  // TODO: we need to implement sadd16/sadd8 for this, which enables to
352  // also do the rewrite for smlad8.ll, but it is unsupported for now.
353  return false;
354  } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
// Only extensions from exactly MaxBitWidth are candidates (see the
// comment above this function: i8 support is TODO).
355  if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
356  return false;
357 
358  if (match(Val, m_Load(m_Value()))) {
359  auto *Ld = cast<LoadInst>(Val);
360 
361  // Check that these load could be paired.
362  if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
363  return false;
364 
// Record the load followed by its extend instruction.
365  VL.push_back(Val);
366  VL.push_back(I);
367  return true;
368  }
369  }
370  return false;
371 }
372 
373 /// Iterate through the block and record base, offset pairs of loads which can
374 /// be widened into a single load.
// Scan BB once, recording (base, offset) pairs of adjacent loads that are
// safe to merge into one wide load, into LoadPairs/OffsetLoads. Returns
// true when at least two pairs were found.
375 bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
// NOTE(review): the local declarations of the 'Loads' and 'Writes'
// vectors filled below (orig lines 376-377) were lost in this extraction
// — restore them against the upstream source.
378 
379  // Collect loads and instruction that may write to memory. For now we only
380  // record loads which are simple, sign-extended and have a single user.
381  // TODO: Allow zero-extended loads.
382  for (auto &I : *BB) {
383  if (I.mayWriteToMemory())
384  Writes.push_back(&I);
385  auto *Ld = dyn_cast<LoadInst>(&I);
386  if (!Ld || !Ld->isSimple() ||
387  !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
388  continue;
389  Loads.push_back(Ld);
390  }
391 
392  using InstSet = std::set<Instruction*>;
393  using DepMap = std::map<Instruction*, InstSet>;
// RAWDeps[load] = the writes that dominate that load and may alias it
// (read-after-write dependences).
394  DepMap RAWDeps;
395 
396  // Record any writes that may alias a load.
397  const auto Size = LocationSize::unknown();
398  for (auto Read : Loads) {
399  for (auto Write : Writes) {
400  MemoryLocation ReadLoc =
401  MemoryLocation(Read->getPointerOperand(), Size);
402 
// NOTE(review): the tail of this condition (the second argument of
// intersectModRef, orig line 404) was lost in this extraction —
// restore it against the upstream source.
403  if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
405  continue;
406  if (DT->dominates(Write, Read))
407  RAWDeps[Read].insert(Write);
408  }
409  }
410 
411  // Check whether there's not a write between the two loads which would
412  // prevent them from being safely merged.
413  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
414  LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
415  LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;
416 
417  if (RAWDeps.count(Dominated)) {
418  InstSet &WritesBefore = RAWDeps[Dominated];
419 
420  for (auto Before : WritesBefore) {
421 
422  // We can't move the second load backward, past a write, to merge
423  // with the first load.
424  if (DT->dominates(Dominator, Before))
425  return false;
426  }
427  }
428  return true;
429  };
430 
431  // Record base, offset load pairs.
432  for (auto *Base : Loads) {
433  for (auto *Offset : Loads) {
434  if (Base == Offset)
435  continue;
436 
// A base is paired with at most one offset load ('break' below).
437  if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
438  SafeToPair(Base, Offset)) {
439  LoadPairs[Base] = Offset;
440  OffsetLoads.insert(Offset);
441  break;
442  }
443  }
444  }
445 
446  LLVM_DEBUG(if (!LoadPairs.empty()) {
447  dbgs() << "Consecutive load pairs:\n";
448  for (auto &MapIt : LoadPairs) {
449  LLVM_DEBUG(dbgs() << *MapIt.first << ", "
450  << *MapIt.second << "\n");
451  }
452  });
// At least two pairs are needed before the rewrite is worthwhile.
453  return LoadPairs.size() > 1;
454 }
455 
456 // Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
457 // multiplications.
458 // To use SMLAD:
459 // 1) we first need to find integer add then look for this pattern:
460 //
461 // acc0 = ...
462 // ld0 = load i16
463 // sext0 = sext i16 %ld0 to i32
464 // ld1 = load i16
465 // sext1 = sext i16 %ld1 to i32
466 // mul0 = mul %sext0, %sext1
467 // ld2 = load i16
468 // sext2 = sext i16 %ld2 to i32
469 // ld3 = load i16
470 // sext3 = sext i16 %ld3 to i32
471 // mul1 = mul i32 %sext2, %sext3
472 // add0 = add i32 %mul0, %acc0
473 // acc1 = add i32 %add0, %mul1
474 //
475 // Which can be selected to:
476 //
477 // ldr r0
478 // ldr r1
479 // smlad r2, r0, r1, r2
480 //
481 // If constants are used instead of loads, these will need to be hoisted
482 // out and into a register.
483 //
484 // If loop invariants are used instead of loads, these need to be packed
485 // before the loop begins.
486 //
487 bool ARMParallelDSP::MatchSMLAD(Loop *L) {
488  // Search recursively back through the operands to find a tree of values that
489  // form a multiply-accumulate chain. The search records the Add and Mul
490  // instructions that form the reduction and allows us to find a single value
491  // to be used as the initial input to the accumlator.
492  std::function<bool(Value*, Reduction&)> Search = [&]
493  (Value *V, Reduction &R) -> bool {
494 
495  // If we find a non-instruction, try to use it as the initial accumulator
496  // value. This may have already been found during the search in which case
497  // this function will return false, signaling a search fail.
498  auto *I = dyn_cast<Instruction>(V);
499  if (!I)
500  return R.InsertAcc(V);
501 
502  switch (I->getOpcode()) {
503  default:
504  break;
505  case Instruction::PHI:
506  // Could be the accumulator value.
507  return R.InsertAcc(V);
508  case Instruction::Add: {
509  // Adds should be adding together two muls, or another add and a mul to
510  // be within the mac chain. One of the operands may also be the
511  // accumulator value at which point we should stop searching.
512  bool ValidLHS = Search(I->getOperand(0), R);
513  bool ValidRHS = Search(I->getOperand(1), R);
514  if (!ValidLHS && !ValidLHS)
515  return false;
516  else if (ValidLHS && ValidRHS) {
517  R.InsertAdd(I);
518  return true;
519  } else {
520  R.InsertAdd(I);
521  return R.InsertAcc(I);
522  }
523  }
524  case Instruction::Mul: {
525  Value *MulOp0 = I->getOperand(0);
526  Value *MulOp1 = I->getOperand(1);
527  if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
528  ValueList LHS;
529  ValueList RHS;
530  if (IsNarrowSequence<16>(MulOp0, LHS) &&
531  IsNarrowSequence<16>(MulOp1, RHS)) {
532  R.InsertMul(I, LHS, RHS);
533  return true;
534  }
535  }
536  return false;
537  }
538  case Instruction::SExt:
539  return Search(I->getOperand(0), R);
540  }
541  return false;
542  };
543 
544  bool Changed = false;
546  BasicBlock *Latch = L->getLoopLatch();
547 
548  for (Instruction &I : reverse(*Latch)) {
549  if (I.getOpcode() != Instruction::Add)
550  continue;
551 
552  if (AllAdds.count(&I))
553  continue;
554 
555  const auto *Ty = I.getType();
556  if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
557  continue;
558 
559  Reduction R(&I);
560  if (!Search(&I, R))
561  continue;
562 
563  if (!CreateParallelPairs(R))
564  continue;
565 
566  InsertParallelMACs(R);
567  Changed = true;
568  AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
569  }
570 
571  return Changed;
572 }
573 
// Try to pair up the reduction's mul chains so that each pair reads two
// pairs of adjacent loads; successful pairs are recorded via AddMulPair.
// Returns true when at least one pair was formed.
574 bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
575 
576  // Not enough mul operations to make a pair.
577  if (R.getMuls().size() < 2)
578  return false;
579 
580  // Check that the muls operate directly upon sign extended loads.
581  for (auto &MulChain : R.getMuls()) {
582  // A mul has 2 operands, and a narrow op consist of sext and a load; thus
583  // we expect at least 4 items in this operand value list.
584  if (MulChain->size() < 4) {
585  LLVM_DEBUG(dbgs() << "Operand list too short.\n");
586  return false;
587  }
588  MulChain->PopulateLoads();
589  ValueList &LHS = static_cast<BinOpChain*>(MulChain.get())->LHS;
590  ValueList &RHS = static_cast<BinOpChain*>(MulChain.get())->RHS;
591 
592  // Use +=2 to skip over the expected extend instructions.
593  for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
594  if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
595  return false;
596  }
597  }
598 
// Returns true (and records the pair) when PMul0/PMul1 read two pairs of
// consecutive loads, possibly after exchanging one mul's loads (SMLADX)
// or swapping the two muls.
599  auto CanPair = [&](Reduction &R, BinOpChain *PMul0, BinOpChain *PMul1) {
600  if (!PMul0->AreSymmetrical(PMul1))
601  return false;
602 
603  // The first elements of each vector should be loads with sexts. If we
604  // find that its two pairs of consecutive loads, then these can be
605  // transformed into two wider loads and the users can be replaced with
606  // DSP intrinsics.
607  for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
608  auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
609  auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
610  auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
611  auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);
612 
613  if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
614  return false;
615 
616  LLVM_DEBUG(dbgs() << "Loads:\n"
617  << " - " << *Ld0 << "\n"
618  << " - " << *Ld1 << "\n"
619  << " - " << *Ld2 << "\n"
620  << " - " << *Ld3 << "\n");
621 
622  if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
623  if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
624  LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
625  R.AddMulPair(PMul0, PMul1);
626  return true;
627  } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
628  LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
629  LLVM_DEBUG(dbgs() << " exchanging Ld2 and Ld3\n");
630  PMul1->Exchange = true;
631  R.AddMulPair(PMul0, PMul1);
632  return true;
633  }
634  } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
635  AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
636  LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
637  LLVM_DEBUG(dbgs() << " exchanging Ld0 and Ld1\n");
638  LLVM_DEBUG(dbgs() << " and swapping muls\n");
639  PMul0->Exchange = true;
640  // Only the second operand can be exchanged, so swap the muls.
641  R.AddMulPair(PMul1, PMul0);
642  return true;
643  }
644  }
645  return false;
646  };
647 
648  OpChainList &Muls = R.getMuls();
649  const unsigned Elems = Muls.size();
// NOTE(review): the declaration of 'Paired' (a set of the mul roots
// already consumed, orig line 650) was lost in this extraction — restore
// it against the upstream source.
651  for (unsigned i = 0; i < Elems; ++i) {
652  BinOpChain *PMul0 = static_cast<BinOpChain*>(Muls[i].get());
653  if (Paired.count(PMul0->Root))
654  continue;
655 
656  for (unsigned j = 0; j < Elems; ++j) {
657  if (i == j)
658  continue;
659 
660  BinOpChain *PMul1 = static_cast<BinOpChain*>(Muls[j].get());
661  if (Paired.count(PMul1->Root))
662  continue;
663 
664  const Instruction *Mul0 = PMul0->Root;
665  const Instruction *Mul1 = PMul1->Root;
666  if (Mul0 == Mul1)
667  continue;
668 
669  assert(PMul0 != PMul1 && "expected different chains");
670 
// Greedy: the first compatible partner wins; each mul joins at most
// one pair.
671  if (CanPair(R, PMul0, PMul1)) {
672  Paired.insert(Mul0);
673  Paired.insert(Mul1);
674  break;
675  }
676  }
677  }
678  return !R.getMulPairs().empty();
679 }
680 
681 
// Replace each paired set of muls in R with a call to the appropriate DSP
// intrinsic, chaining the accumulator through the calls, then RAUW the
// reduction root with the final call.
682 void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
683 
// Build one intrinsic call for a pair of (to-be-)widened load groups.
// NOTE(review): the second load-list parameter ('VecLd1', orig line 685)
// was lost in this extraction — restore it against the upstream source.
684  auto CreateSMLADCall = [&](SmallVectorImpl<LoadInst*> &VecLd0,
686  Value *Acc, bool Exchange,
687  Instruction *InsertAfter) {
688  // Replace the reduction chain with an intrinsic call
689  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
// Reuse an already-widened load when one exists for this base.
690  LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
691  WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
692  LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
693  WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);
694 
695  Value* Args[] = { WideLd0, WideLd1, Acc };
696  Function *SMLAD = nullptr;
// i32 accumulator selects SMLAD/SMLADX, i64 selects SMLALD/SMLALDX;
// the 'X' forms exchange the halfwords of the second operand.
697  if (Exchange)
698  SMLAD = Acc->getType()->isIntegerTy(32) ?
699  Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
700  Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
701  else
702  SMLAD = Acc->getType()->isIntegerTy(32) ?
703  Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
704  Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
705 
// Insert the call immediately after InsertAfter; NoFolder keeps the
// IR exactly as built.
706  IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
707  ++BasicBlock::iterator(InsertAfter));
708  Instruction *Call = Builder.CreateCall(SMLAD, Args);
709  NumSMLAD++;
710  return Call;
711  };
712 
713  Instruction *InsertAfter = R.getRoot();
714  Value *Acc = R.getAccumulator();
// With no incoming accumulator, start the chain from zero.
715  if (!Acc)
716  Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
717 
718  LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
719  << "Acc: " << *Acc << "\n");
// Each call's result becomes the accumulator (and insertion point) of
// the next, forming a serial chain of parallel MACs.
720  for (auto &Pair : R.getMulPairs()) {
721  BinOpChain *PMul0 = Pair.first;
722  BinOpChain *PMul1 = Pair.second;
723  LLVM_DEBUG(dbgs() << "Muls:\n"
724  << "- " << *PMul0->Root << "\n"
725  << "- " << *PMul1->Root << "\n");
726 
727  Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
728  InsertAfter);
729  InsertAfter = cast<Instruction>(Acc);
730  }
731  R.UpdateRoot(cast<Instruction>(Acc));
732 }
733 
// Merge the two narrow loads in Loads into one load of LoadTy, rewire
// their sign-extend users onto trunc/lshr+trunc of the wide value, record
// the result in WideLoads, and return the wide load. The original narrow
// loads are left in place (now dead-ish) rather than erased here.
734 LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
735  IntegerType *LoadTy) {
736  assert(Loads.size() == 2 && "currently only support widening two loads");
737 
738  LoadInst *Base = Loads[0];
739  LoadInst *Offset = Loads[1];
740 
// RecordMemoryOps only collected loads whose single user is a SExtInst,
// so user_back() is that extend.
741  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
742  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());
743 
744  assert((BaseSExt && OffsetSExt)
745  && "Loads should have a single, extending, user");
746 
// Recursively push A's in-block users below B so that every user of a
// moved instruction still follows its definition. PHIs and cross-block
// users are left alone; the dominance check stops the recursion once
// order is already correct.
747  std::function<void(Value*, Value*)> MoveBefore =
748  [&](Value *A, Value *B) -> void {
749  if (!isa<Instruction>(A) || !isa<Instruction>(B))
750  return;
751 
752  auto *Source = cast<Instruction>(A);
753  auto *Sink = cast<Instruction>(B);
754 
755  if (DT->dominates(Source, Sink) ||
756  Source->getParent() != Sink->getParent() ||
757  isa<PHINode>(Source) || isa<PHINode>(Sink))
758  return;
759 
760  Source->moveBefore(Sink);
761  for (auto &U : Source->uses())
762  MoveBefore(Source, U.getUser());
763  };
764 
765  // Insert the load at the point of the original dominating load.
766  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
767  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
768  ++BasicBlock::iterator(DomLoad));
769 
770  // Bitcast the pointer to a wider type and create the wide load, while making
771  // sure to maintain the original alignment as this prevents ldrd from being
772  // generated when it could be illegal due to memory alignment.
773  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
774  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
775  LoadTy->getPointerTo(AddrSpace));
776  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
777  Base->getAlignment());
778 
779  // Make sure everything is in the correct order in the basic block.
780  MoveBefore(Base->getPointerOperand(), VecPtr);
781  MoveBefore(VecPtr, WideLoad);
782 
783  // From the wide load, create two values that equal the original two loads.
784  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
785  // TODO: Support big-endian as well.
786  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
787  BaseSExt->setOperand(0, Bottom);
788 
// The high half: shift down by the offset load's width, then truncate.
789  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
790  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
791  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
792  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
793  OffsetSExt->setOperand(0, Trunc);
794 
// Keyed by the base load so InsertParallelMACs can reuse this widening.
795  WideLoads.emplace(std::make_pair(Base,
796  make_unique<WidenedLoad>(Loads, WideLoad)));
797  return WideLoad;
798 }
799 
800 // Compare the value lists in Other to this chain.
801 bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
802  // Element-by-element comparison of Value lists returning true if they are
803  // instructions with the same opcode or constants with the same value.
804  auto CompareValueList = [](const ValueList &VL0,
805  const ValueList &VL1) {
806  if (VL0.size() != VL1.size()) {
807  LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
808  << VL0.size() << " != " << VL1.size() << "\n");
809  return false;
810  }
811 
812  const unsigned Pairs = VL0.size();
813 
814  for (unsigned i = 0; i < Pairs; ++i) {
815  const Value *V0 = VL0[i];
816  const Value *V1 = VL1[i];
817  const auto *Inst0 = dyn_cast<Instruction>(V0);
818  const auto *Inst1 = dyn_cast<Instruction>(V1);
819 
// NOTE(review): a bare constant is not an Instruction, so this early
// return means the m_APInt comparison below appears unreachable for
// plain ConstantInt operands — confirm against upstream intent.
820  if (!Inst0 || !Inst1)
821  return false;
822 
823  if (Inst0->isSameOperationAs(Inst1))
824  continue;
825 
// NOTE(review): this compares the APInt *pointers*, not the values;
// it works only because equal same-typed ConstantInts are uniqued
// per context — confirm this is the intended check.
826  const APInt *C0, *C1;
827  if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
828  return false;
829  }
830 
831  return true;
832  };
833 
// Both operand lists must match pairwise for the chains to be pairable.
834  return CompareValueList(LHS, Other->LHS) &&
835  CompareValueList(RHS, Other->RHS);
836 }
837 
839  return new ARMParallelDSP();
840 }
841 
842 char ARMParallelDSP::ID = 0;
843 
844 INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
845  "Transform loops to use DSP intrinsics", false, false)
846 INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
847  "Transform loops to use DSP intrinsics", false, false)
The access may reference and may modify the value stored in memory.
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:80
A parsed version of the target data layout string in and methods for querying it. ...
Definition: DataLayout.h:110
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:70
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
BlockT * getLoopLatch() const
If there is a single latch block for this loop, return it.
Definition: LoopInfoImpl.h:211
This class represents lattice values for constants.
Definition: AllocatorList.h:23
static cl::opt< bool > DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false), cl::desc("Disable the ARM Parallel DSP pass"))
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:65
static constexpr LocationSize unknown()
arm parallel dsp
The main scalar evolution driver.
BlockT * getLoopPreheader() const
If there is a preheader for this loop, return it.
Definition: LoopInfoImpl.h:160
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1, const DataLayout &DL, ScalarEvolution &SE)
An immutable pass that tracks lazily created AssumptionCache objects.
bool mayWriteToMemory() const
Return true if this instruction may modify memory.
STATISTIC(NumFunctions, "Total number of functions")
F(f)
This class represents a sign extension of integer types.
An instruction for reading from memory.
Definition: Instructions.h:167
INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp", "Transform loops to use DSP intrinsics", false, false) INITIALIZE_PASS_END(ARMParallelDSP
void dump() const
Support for debugging, callable in GDB: V->dump()
Definition: AsmWriter.cpp:4424
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:47
AnalysisUsage & addRequired()
const DataLayout & getDataLayout() const
Get the data layout for the module's target platform.
Definition: Module.cpp:369
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:244
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
Definition: Type.cpp:654
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: APFloat.h:41
CastClass_match< OpTy, Instruction::Trunc > m_Trunc(const OpTy &Op)
Matches Trunc.
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:196
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:779
arm parallel Transform loops to use DSP intrinsics
Target-Independent Code Generator Pass Configuration Options.
ELFYAML::ELF_STO Other
Definition: ELFYAML.cpp:877
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
Definition: PatternMatch.h:681
BlockT * getHeader() const
Definition: LoopInfo.h:102
auto reverse(ContainerTy &&C, typename std::enable_if< has_rbegin< ContainerTy >::value >::type *=nullptr) -> decltype(make_range(C.rbegin(), C.rend()))
Definition: STLExtras.h:273
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:96
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:244
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:81
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:125
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:429
unsigned getBitWidth() const
Get the number of bits in this IntegerType.
Definition: DerivedTypes.h:66
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree...
Definition: Dominators.h:144
Function * getDeclaration(Module *M, ID id, ArrayRef< Type *> Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1043
Value * getOperand(unsigned i) const
Definition: User.h:169
bool isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL, ScalarEvolution &SE, bool CheckType=true)
Returns true if the memory operations A and B are consecutive.
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:432
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Pass * createARMParallelDSPPass()
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt...
Definition: PatternMatch.h:175
Move duplicate certain instructions close to their use
Definition: Localizer.cpp:27
LLVM Basic Block Representation.
Definition: BasicBlock.h:57
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:370
Represent the analysis usage information of a pass.
match_combine_or< CastClass_match< OpTy, Instruction::ZExt >, CastClass_match< OpTy, Instruction::SExt > > m_ZExtOrSExt(const OpTy &Op)
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:381
Class to represent integer types.
Definition: DerivedTypes.h:40
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
size_t size() const
Definition: SmallVector.h:52
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
Representation for a specific memory location.
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:239
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
auto size(R &&Range, typename std::enable_if< std::is_same< typename std::iterator_traits< decltype(Range.begin())>::iterator_category, std::random_access_iterator_tag >::value, void >::type *=nullptr) -> decltype(std::distance(Range.begin(), Range.end()))
Get the size of a range.
Definition: STLExtras.h:1173
This is a &#39;vector&#39; (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:837
bool dominates(const Instruction *Def, const Use &U) const
Return true if Def dominates a use in User.
Definition: Dominators.cpp:248
Instruction * user_back()
Specialize the methods defined in Value, as we know that an instruction can only be used by other ins...
Definition: Instruction.h:63
Provides information about what library functions are available for the current target.
Drive the analysis of memory accesses in the loop.
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:643
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:301
void setOperand(unsigned i, Value *Val)
Definition: User.h:174
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:132
Class for arbitrary precision integers.
Definition: APInt.h:69
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:89
loop Loop Strength Reduction
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:506
StringRef getName() const
Return a constant reference to the value&#39;s name.
Definition: Value.cpp:214
LLVM_NODISCARD ModRefInfo intersectModRef(const ModRefInfo MRI1, const ModRefInfo MRI2)
#define I(x, y, z)
Definition: MD5.cpp:58
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:332
uint32_t Size
Definition: Profile.cpp:46
unsigned getPointerAddressSpace() const
Returns the address space of the pointer operand.
Definition: Instructions.h:290
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:575
LLVM Value Representation.
Definition: Value.h:72
BasicBlock * InsertPreheaderForLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, MemorySSAUpdater *MSSAU, bool PreserveLCSSA)
InsertPreheaderForLoop - Once we discover that a loop doesn&#39;t have a preheader, this method is called...
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:65
The legacy pass manager&#39;s analysis pass to compute loop information.
Definition: LoopInfo.h:1177
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:259
LLVM_NODISCARD bool isModOrRefSet(const ModRefInfo MRI)
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object...
ModRefInfo getModRefInfo(const CallBase *Call, const MemoryLocation &Loc)
getModRefInfo (for call sites) - Return information about whether a particular call site modifies or ...
#define LLVM_DEBUG(X)
Definition: Debug.h:122
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:277
loops
Definition: LoopInfo.cpp:1019
const BasicBlock * getParent() const
Definition: Instruction.h:66