31#define DEBUG_TYPE "riscv-gather-scatter-lowering" 
   62    return "RISC-V gather/scatter lowering";
 
   68  std::pair<Value *, Value *> determineBaseAndStride(
Instruction *
Ptr,
 
   78char RISCVGatherScatterLowering::ID = 0;
 
   81                "RISC-V gather/scatter lowering pass", 
false, 
false)
 
   84  return new RISCVGatherScatterLowering();
 
 
   90    return std::make_pair(
nullptr, 
nullptr);
 
   98    return std::make_pair(
nullptr, 
nullptr);
 
   99  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
 
  101  for (
unsigned i = 1; i != NumElts; ++i) {
 
  104      return std::make_pair(
nullptr, 
nullptr);
 
  108      StrideVal = LocalStride;
 
  109    else if (StrideVal != LocalStride)
 
  110      return std::make_pair(
nullptr, 
nullptr);
 
  115  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
 
  117  return std::make_pair(StartVal, Stride);
 
 
  129    auto *Ty = Start->getType()->getScalarType();
 
  130    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
 
  136  if (!BO || (BO->getOpcode() != Instruction::Add &&
 
  137              BO->getOpcode() != Instruction::Or &&
 
  138              BO->getOpcode() != Instruction::Shl &&
 
  139              BO->getOpcode() != Instruction::Mul))
 
  140    return std::make_pair(
nullptr, 
nullptr);
 
  142  if (BO->getOpcode() == Instruction::Or &&
 
  144    return std::make_pair(
nullptr, 
nullptr);
 
  147  unsigned OtherIndex = 0;
 
  154    return std::make_pair(
nullptr, 
nullptr);
 
  160    return std::make_pair(
nullptr, 
nullptr);
 
  162  Builder.SetInsertPoint(BO);
 
  163  Builder.SetCurrentDebugLocation(
DebugLoc());
 
  166  switch (BO->getOpcode()) {
 
  169  case Instruction::Or:
 
  170    Start = Builder.CreateOr(Start, 
Splat, 
"", 
true);
 
  172  case Instruction::Add:
 
  173    Start = Builder.CreateAdd(Start, 
Splat);
 
  175  case Instruction::Mul:
 
  176    Start = Builder.CreateMul(Start, 
Splat);
 
  177    Stride = Builder.CreateMul(Stride, 
Splat);
 
  179  case Instruction::Shl:
 
  180    Start = Builder.CreateShl(Start, 
Splat);
 
  181    Stride = Builder.CreateShl(Stride, 
Splat);
 
  185  return std::make_pair(Start, Stride);
 
 
  192bool RISCVGatherScatterLowering::matchStridedRecurrence(
Value *Index, 
Loop *L,
 
  201    if (
Phi->getParent() != 
L->getHeader())
 
  208    assert(
Phi->getNumIncomingValues() == 2 && 
"Expected 2 operand phi.");
 
  209    unsigned IncrementingBlock = 
Phi->getIncomingValue(0) == Inc ? 0 : 1;
 
  210    assert(
Phi->getIncomingValue(IncrementingBlock) == Inc &&
 
  211           "Expected one operand of phi to be Inc");
 
  221    assert(Stride != 
nullptr);
 
  226    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->
getName() + 
".scalar",
 
  228    BasePtr->addIncoming(Start, 
Phi->getIncomingBlock(1 - IncrementingBlock));
 
  229    BasePtr->addIncoming(Inc, 
Phi->getIncomingBlock(IncrementingBlock));
 
  232    MaybeDeadPHIs.push_back(Phi);
 
  241  switch (BO->getOpcode()) {
 
  244  case Instruction::Or:
 
  249  case Instruction::Add:
 
  251  case Instruction::Shl:
 
  253  case Instruction::Mul:
 
  262    OtherOp = BO->getOperand(1);
 
  267    OtherOp = BO->getOperand(0);
 
  273  if (!
L->isLoopInvariant(OtherOp))
 
  282  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
 
  287  unsigned StartBlock = 
BasePtr->getOperand(0) == Inc ? 1 : 0;
 
  293      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
 
  297  switch (BO->getOpcode()) {
 
  300  case Instruction::Add:
 
  301  case Instruction::Or: {
 
  307  case Instruction::Mul: {
 
  309    Stride = Builder.
CreateMul(Stride, SplatOp, 
"stride");
 
  312  case Instruction::Shl: {
 
  314    Stride = Builder.
CreateShl(Stride, SplatOp, 
"stride");
 
  324  switch (BO->getOpcode()) {
 
  327  case Instruction::Mul:
 
  328    Step = Builder.
CreateMul(Step, SplatOp, 
"step");
 
  330  case Instruction::Shl:
 
  331    Step = Builder.
CreateShl(Step, SplatOp, 
"step");
 
  336  BasePtr->setIncomingValue(StartBlock, Start);
 
  340std::pair<Value *, Value *>
 
  341RISCVGatherScatterLowering::determineBaseAndStride(Instruction *
Ptr,
 
  342                                                   IRBuilderBase &Builder) {
 
  347    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
 
  352    return std::make_pair(
nullptr, 
nullptr);
 
  354  auto I = StridedAddrs.find(
GEP);
 
  355  if (
I != StridedAddrs.end())
 
  358  SmallVector<Value *, 2> 
Ops(
GEP->operands());
 
  363      BaseInst && BaseInst->getType()->isVectorTy()) {
 
  365    auto IsScalar = [](
Value *Idx) { 
return !Idx->getType()->isVectorTy(); };
 
  367      auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
 
  372            Builder.
CreateGEP(
GEP->getSourceElementType(), BaseBase, Indices,
 
  373                              GEP->getName() + 
"offset", 
GEP->isInBounds());
 
  374        return {OffsetBase, Stride};
 
  384      return std::make_pair(
nullptr, 
nullptr);
 
  387  std::optional<unsigned> VecOperand;
 
  388  unsigned TypeScale = 0;
 
  392  for (
unsigned i = 1, e = 
GEP->getNumOperands(); i != e; ++i, ++GTI) {
 
  397      return std::make_pair(
nullptr, 
nullptr);
 
  403      return std::make_pair(
nullptr, 
nullptr);
 
  410    return std::make_pair(
nullptr, 
nullptr);
 
  418  Type *VecIntPtrTy = 
DL->getIntPtrType(
GEP->getType());
 
  419  if (VecIndex->
getType() != VecIntPtrTy) {
 
  422      return std::make_pair(
nullptr, 
nullptr);
 
  438    Type *SourceTy = 
GEP->getSourceElementType();
 
  448      Stride = Builder.
CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
 
  450    auto P = std::make_pair(BasePtr, Stride);
 
  451    StridedAddrs[
GEP] = 
P;
 
  457  if (!L || !
L->getLoopPreheader() || !
L->getLoopLatch())
 
  458    return std::make_pair(
nullptr, 
nullptr);
 
  462  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
 
  463    return std::make_pair(
nullptr, 
nullptr);
 
  466  unsigned IncrementingBlock = BasePhi->
getOperand(0) == Inc ? 0 : 1;
 
  468         "Expected one operand of phi to be Inc");
 
  473  Ops[*VecOperand] = BasePhi;
 
  474  Type *SourceTy = 
GEP->getSourceElementType();
 
  488    Stride = Builder.
CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
 
  490  auto P = std::make_pair(BasePtr, Stride);
 
  491  StridedAddrs[
GEP] = 
P;
 
  495bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *
II) {
 
  499  switch (
II->getIntrinsicID()) {
 
  500  case Intrinsic::masked_gather:
 
  502    Ptr = 
II->getArgOperand(0);
 
  503    Alignment = 
II->getParamAlign(0).valueOrOne();
 
  504    Mask = 
II->getArgOperand(1);
 
  506  case Intrinsic::vp_gather:
 
  508    Ptr = 
II->getArgOperand(0);
 
  510    Alignment = 
II->getParamAlign(0).value_or(
 
  511        DL->getABITypeAlign(DataType->getElementType()));
 
  512    Mask = 
II->getArgOperand(1);
 
  513    EVL = 
II->getArgOperand(2);
 
  515  case Intrinsic::masked_scatter:
 
  517    StoreVal = 
II->getArgOperand(0);
 
  518    Ptr = 
II->getArgOperand(1);
 
  519    Alignment = 
II->getParamAlign(1).valueOrOne();
 
  520    Mask = 
II->getArgOperand(2);
 
  522  case Intrinsic::vp_scatter:
 
  524    StoreVal = 
II->getArgOperand(0);
 
  525    Ptr = 
II->getArgOperand(1);
 
  527    Alignment = 
II->getParamAlign(1).value_or(
 
  528        DL->getABITypeAlign(DataType->getElementType()));
 
  529    Mask = 
II->getArgOperand(2);
 
  530    EVL = 
II->getArgOperand(3);
 
  550  LLVMContext &Ctx = PtrI->getContext();
 
  555  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
 
  558  assert(Stride != 
nullptr);
 
  570        Intrinsic::experimental_vp_strided_load,
 
  575    if (
II->getIntrinsicID() == Intrinsic::masked_gather)
 
  579        Intrinsic::experimental_vp_strided_store,
 
  584  II->replaceAllUsesWith(
Call);
 
  585  II->eraseFromParent();
 
  587  if (PtrI->use_empty())
 
  593bool RISCVGatherScatterLowering::runOnFunction(Function &
F) {
 
  597  auto &TPC = getAnalysis<TargetPassConfig>();
 
  598  auto &
TM = TPC.getTM<RISCVTargetMachine>();
 
  599  ST = &
TM.getSubtarget<RISCVSubtarget>(
F);
 
  600  if (!
ST->hasVInstructions() || !
ST->useRVVForFixedLengthVectors())
 
  603  TLI = 
ST->getTargetLowering();
 
  604  DL = &
F.getDataLayout();
 
  605  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
 
  607  StridedAddrs.clear();
 
  613  for (BasicBlock &BB : 
F) {
 
  614    for (Instruction &
I : BB) {
 
  618      switch (
II->getIntrinsicID()) {
 
  619      case Intrinsic::masked_gather:
 
  620      case Intrinsic::masked_scatter:
 
  621      case Intrinsic::vp_gather:
 
  622      case Intrinsic::vp_scatter:
 
  632  for (
auto *
II : Worklist)
 
  633    Changed |= tryCreateStridedLoadStore(
II);
 
  636  while (!MaybeDeadPHIs.empty()) {
 
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
 
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
 
static bool runOnFunction(Function &F, bool PostInlining)
 
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
 
uint64_t IntrinsicInst * II
 
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
 
static std::pair< Value *, Value * > matchStridedStart(Value *Start, IRBuilderBase &Builder)
 
static std::pair< Value *, Value * > matchStridedConstant(Constant *StartC)
 
static SymbolRef::Type getType(const Symbol *Sym)
 
Target-Independent Code Generator Pass Configuration Options pass.
 
Class for arbitrary precision integers.
 
Represent the analysis usage information of a pass.
 
AnalysisUsage & addRequired()
 
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
 
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
 
BinaryOps getOpcode() const
 
This is the shared class of boolean and integer constants.
 
const APInt & getValue() const
Return the constant as an APInt value reference.
 
This is an important base class in LLVM.
 
LLVM_ABI Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
 
A parsed version of the target data layout string in and methods for querying it.
 
FunctionPass class - This class is used to implement most global optimizations.
 
Common base class shared among various IRBuilders.
 
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
 
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
 
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
 
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
 
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
 
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
 
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
 
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
 
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
 
LLVM_ABI Value * CreateElementCount(Type *Ty, ElementCount EC)
Create an expression which evaluates to the number of elements in EC at runtime.
 
LLVM_ABI bool isCommutative() const LLVM_READONLY
Return true if the instruction is commutative:
 
A wrapper class for inspecting calls to intrinsic functions.
 
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
 
The legacy pass manager's analysis pass to compute loop information.
 
Represents a single loop in the control flow graph.
 
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
 
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
 
unsigned getNumIncomingValues() const
Return the number of incoming edges.
 
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
 
bool isLegalStridedLoadStore(EVT DataType, Align Alignment) const
Return true if a stride load store of the given result type and alignment is legal.
 
void push_back(const T &Elt)
 
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
 
StringRef - Represent a constant reference to a string, i.e.
 
EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const
Return the EVT corresponding to this LLVM type.
 
bool isTypeLegal(EVT VT) const
Return true if the target has native support for the specified value type.
 
Target-Independent Code Generator Pass Configuration Options.
 
bool isVectorTy() const
True if this is an instance of VectorType.
 
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
 
void setOperand(unsigned i, Value *Val)
 
Value * getOperand(unsigned i) const
 
LLVM Value Representation.
 
Type * getType() const
All values are typed, get the type of this value.
 
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
 
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
 
constexpr ScalarTy getFixedValue() const
 
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
 
TypeSize getSequentialElementStride(const DataLayout &DL) const
 
self_iterator getIterator()
 
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
 
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
 
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
 
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
 
@ C
The default llvm calling convention, compatible with C.
 
bool match(Val *V, const Pattern &P)
 
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
 
NodeAddr< PhiNode * > Phi
 
This is an optimization pass for GlobalISel generic memory operations.
 
FunctionAddr VTableAddr Value
 
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
 
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
 
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
 
LLVM_ABI Value * getSplatValue(const Value *V)
Get splat value if the input is a splat vector or return nullptr.
 
FunctionPass * createRISCVGatherScatterLoweringPass()
 
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
 
auto dyn_cast_or_null(const Y &Val)
 
generic_gep_type_iterator<> gep_type_iterator
 
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
 
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
 
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
 
ArrayRef(const T &OneElt) -> ArrayRef< T >
 
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
 
gep_type_iterator gep_type_begin(const User *GEP)
 
LLVM_ABI bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...
 
LLVM_ABI Constant * ConstantFoldCastInstruction(unsigned opcode, Constant *V, Type *DestTy)