//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
/// intrinsics to arm.mve.gather and arm.mve.scatter intrinsics, optimising
/// the code to produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

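// A minimal sketch of the transformation (illustrative IR only; the exact
// intrinsic signatures are those defined in IntrinsicsARM.td):
//
//   %g = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(
//            <4 x i32*> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> undef)
//
// becomes, for an all-true mask,
//
//   %g = call <4 x i32> @llvm.arm.mve.vldr.gather.base.v4i32.v4p0i32(
//            <4 x i32*> %ptrs, i32 0)
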
#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *checkGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP,
                  IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Value *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                     Instruction *&Root, IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                   IRBuilder<> &Builder, int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                     IRBuilder<> &Builder,
                                     int64_t Increment = 0);

  Value *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                      IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                    IRBuilder<> &Builder,
                                    int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                      IRBuilder<> &Builder,
                                      int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (digit)
  Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
                                      Value *Ptr, GetElementPtrInst *GEP,
                                      IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (digit) - this creates a writeback QI
  // gather/scatter
  Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                        Value *Ptr, unsigned TypeScale,
                                        IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets,
                 IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}

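// For example (illustrative): a v4i32 gather needs at least 4-byte alignment
// and a v8i16 one 2-byte alignment, while a v16i8 access is always
// sufficiently aligned. The element-count/element-size pairs accepted above
// are exactly those the MVE gather/scatter instructions can load or store
// (extending/truncating forms included).
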
static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller or equal the type of the gather we're
  // looking at, or consist of constants that we can check are small enough
  // to fit into the gather type.
  // Thus we check that 0 < value < 2^TargetElemSize.
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}

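// For example (illustrative): for a v4i32 gather TargetElemSize is 32, so any
// <4 x i32> offsets pass; <4 x i16> offsets only pass if they are constants
// proved non-negative and in range, because a variable i16 offset would be
// sign-extended by the gep while the hardware treats it as unsigned.
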
Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty,
                                          GetElementPtrInst *GEP,
                                          IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(
        dbgs() << "masked gathers/scatters: no getelementpointer found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(
          dbgs() << "masked gathers/scatters: looking through bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by
  // 2, or an 8bit, 16bit or 32bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

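// For example (illustrative): a gep on an i32* base feeding a 32-bit
// gather/scatter yields scale 2 (offsets are shifted left by 2, i.e.
// multiplied by 4); an i16* base with 16-bit elements yields scale 1; an i8*
// base is always unscaled (0); any other combination cannot be encoded and
// returns -1.
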
Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C != nullptr)
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add ||
      I->getOpcode() == Instruction::Mul) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  }
  return Optional<int64_t>{};
}

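// For example (illustrative): getIfConst returns 20 for
// mul (add (i32 2, i32 3), i32 4), and None for add (i32 %x, i32 3), since
// %x cannot be evaluated at compile time.
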
std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or we
  // bail out
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
    return ReturnFalse;

  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand the value that is increased is
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = Const.getValue() << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

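// For example (illustrative): with TypeScale == 2, an offset of
// add <4 x i32> %iv, (splat 4) yields {%iv, 16}, since 4 << 2 == 16 fits the
// +-512, multiple-of-4 immediate range; a splat summand of 200 would be
// rejected because 200 << 2 == 800 is out of range.
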
Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;
  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = Builder.CreateSelect(Mask, Load, PassThru);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
  return Load;
}

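// Illustrative shape of the passthru handling above (names hypothetical):
//   %g = call <4 x i32> @llvm.arm.mve.vldr.gather.base.predicated...(...)
//   %r = select <4 x i1> %mask, <4 x i32> %g, <4 x i32> %passthru
// Inactive lanes of a predicated MVE load come back as zero, so the select
// is only needed for passthru values that are neither undef nor zero.
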
Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
                                                           Value *Ptr,
                                                           IRBuilder<> &Builder,
                                                           int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked gathers: loading from vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *OriginalTy = I->getType();
  Type *ResultTy = OriginalTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
    // Only transform gathers with exactly one use
    if (!I->hasOneUse())
      return nullptr;

    // The correct root to replace is not the CallInst itself, but the
    // instruction which extends it
    Extend = cast<Instruction>(*I->users().begin());
    if (isa<SExtInst>(Extend)) {
      Unsigned = 0;
    } else if (!isa<ZExtInst>(Extend)) {
      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
                        << "Expanding\n");
      return nullptr;
    }
    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
    ResultTy = Extend->getType();
    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
                        << "Expanding\n");
      return nullptr;
    }
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI gather
  Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Load)
    return Load;

  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      OriginalTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;
  Root = Extend;

  Value *Mask = I->getArgOperand(2);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
}

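// Illustrative example of the extending-gather case handled above:
//   %g = call <4 x i16> @llvm.masked.gather.v4i16.v4p0i16(...)
//   %z = zext <4 x i16> %g to <4 x i32>
// is turned into a single 16-bit gather that zero-extends into 32-bit lanes
// (Unsigned == 1), and the zext - not the call - becomes the Root that is
// replaced.
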
Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
  I->eraseFromParent();
  return Store;
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked scatters: storing to a vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;
  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(
        dbgs() << "masked scatters: cannot create scatters for non-standard"
               << " input types. Expanding.\n");
    return nullptr;
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI scatter
  Value *Store =
      tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Store)
    return Store;
  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      MemoryTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;

  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

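// Illustrative example of the truncating-scatter case handled above:
//   %t = trunc <4 x i32> %val to <4 x i16>
//   call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %t, ...)
// is absorbed into the scatter: %val is stored directly with a 16-bit
// memory element size, so the separate trunc becomes dead.
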
Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
    IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    // Incrementing gathers are not beneficial outside of a loop
    return nullptr;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
  int TypeScale =
      computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DT.getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the wb will
    // change the phi which does affect other users of the gep (which will
    // still be using the phi in the old way)
    Value *Load =
        tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder);
    if (Load != nullptr)
      return Load;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return cast<IntrinsicInst>(
        tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
  else
    return cast<IntrinsicInst>(
        tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
}

Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Value *EndResult;
  Value *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
  } else {
    // Build the incrementing scatter
    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
    EndResult = NewInduction;
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

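// Illustrative shape of the loop this matches (offsets phi incremented by a
// constant splat in the latch; values hypothetical, assuming TypeScale 2):
//   loop:
//     %offs = phi <4 x i32> [ %start, %entry ], [ %offs.next, %loop ]
//     %gep  = getelementptr i32, i32* %base, <4 x i32> %offs
//     %g    = call <4 x i32> @llvm.masked.gather...(<4 x i32*> %gep, ...)
//     %offs.next = add <4 x i32> %offs, <i32 4, i32 4, i32 4, i32 4>
// The phi's start value is rewritten to base + (%start << 2) - 16 and the
// gather becomes a pre-incrementing writeback gather whose second result
// replaces %offs.next.
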
void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which the loop variable can be incremented
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetic only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed) {
      return false;
    } else {
      if (isa<PHINode>(Offs->getOperand(0))) {
        Phi = cast<PHINode>(Offs->getOperand(0));
        OffsSecondOp = 1;
      } else if (isa<PHINode>(Offs->getOperand(1))) {
        Phi = cast<PHINode>(Offs->getOperand(1));
        OffsSecondOp = 0;
      } else {
        return false;
      }
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header, and shouldn't have more than 2 incoming values
  if (Phi->getParent() != L->getHeader() ||
      Phi->getNumIncomingValues() != 2)
    return false;

  // The phi must be an induction variable
  int IncrementingBlock = -1;

  for (int i = 0; i < 2; i++)
    if (auto *Op = dyn_cast<Instruction>(Phi->getIncomingValue(i)))
      if (Op->getOpcode() == Instruction::Add &&
          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
        IncrementingBlock = i;
  if (IncrementingBlock == -1)
    return false;

  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType())
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    std::vector<Value *> Increases;
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

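// Illustrative effect of optimiseOffsets: for loop offsets computed as
// mul(%iv.phi, %invariant), the mul is pushed out of the loop by starting a
// phi at mul(start, %invariant) and stepping it by mul(step, %invariant)
// each iteration, so the loop body hands the gather/scatter a ready-made
// offset vector.
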
static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *Add = Builder.CreateAdd(X, Y);

  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}

Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  // We only merge geps with constant offsets, because only for those
  // we can make sure that we do not cause an overflow
  if (!isa<Constant>(Offsets))
    return nullptr;
  GetElementPtrInst *BaseGEP;
  if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
    // Merge the two geps into one
    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
    if (!BaseBasePtr)
      return nullptr;
    Offsets =
        CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
    if (Offsets == nullptr)
      return nullptr;
    return BaseBasePtr;
  }
  return GEPPtr;
}

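// Illustrative (values hypothetical): foldGEP turns
//   %a = getelementptr i8, i8* %base, <4 x i32> <i32 0, i32 16, i32 32, i32 48>
//   %b = getelementptr i8, <4 x i8*> %a, <4 x i32> <i32 4, i32 4, i32 4, i32 4>
// into a single gep of %base with offsets <4, 20, 36, 52>, provided the
// summed constants pass the overflow checks in CheckAndCreateOffsetAdd.
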
bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() &&
      dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    Value *Base = foldGEP(GEP, Offsets, Builder);
    // We only want to merge the geps if there is a real chance that they can be
    // used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      PointerType *BaseType = cast<PointerType>(Base->getType());
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          BaseType->getPointerElementType(), Base, Offsets, "gep.merged", GEP);
      GEP->replaceAllUsesWith(NewAddress);
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Value *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Value *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
    Changed = true;
  }
  return Changed;
}
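
// A minimal sketch of running this pass standalone through the legacy pass
// manager (assumes M is a Module and TM an ARM LLVMTargetMachine with MVE
// enabled; for illustration only):
//
//   legacy::PassManager PM;
//   PM.add(TM->createPassConfig(PM)); // provides the required TargetPassConfig
//   PM.add(createMVEGatherScatterLoweringPass());
//   PM.run(*M);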