32#include "llvm/IR/IntrinsicsAArch64.h"
42#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
55 bool coalescePTrueIntrinsicCalls(
BasicBlock &BB,
69void SVEIntrinsicOpts::getAnalysisUsage(
AnalysisUsage &AU)
const {
74char SVEIntrinsicOpts::ID = 0;
75static const char *
name =
"SVE intrinsics optimizations";
81 return new SVEIntrinsicOpts();
101 if (
match(
User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
107 if (ConvertToUses.
empty())
113 const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->
getType());
116 auto *IntrUser = dyn_cast<IntrinsicInst>(
User);
117 if (IntrUser && IntrUser->getIntrinsicID() ==
118 Intrinsic::aarch64_sve_convert_from_svbool) {
119 const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
122 if (IntrUserVTy->getElementCount().getKnownMinValue() >
123 PTrueVTy->getElementCount().getKnownMinValue())
135bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
137 if (PTrues.
size() <= 1)
141 auto *MostEncompassingPTrue = *std::max_element(
142 PTrues.
begin(), PTrues.
end(), [](
auto *PTrue1,
auto *PTrue2) {
143 auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
144 auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
145 return PTrue1VTy->getElementCount().getKnownMinValue() <
146 PTrue2VTy->getElementCount().getKnownMinValue();
151 PTrues.
remove(MostEncompassingPTrue);
161 Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());
163 auto *MostEncompassingPTrueVTy =
164 cast<VectorType>(MostEncompassingPTrue->getType());
165 auto *ConvertToSVBool =
Builder.CreateIntrinsic(
166 Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
167 {MostEncompassingPTrue});
169 bool ConvertFromCreated =
false;
170 for (
auto *PTrue : PTrues) {
171 auto *PTrueVTy = cast<VectorType>(PTrue->getType());
175 if (MostEncompassingPTrueVTy != PTrueVTy) {
176 ConvertFromCreated =
true;
178 Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
179 auto *ConvertFromSVBool =
180 Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
181 {PTrueVTy}, {ConvertToSVBool});
182 PTrue->replaceAllUsesWith(ConvertFromSVBool);
184 PTrue->replaceAllUsesWith(MostEncompassingPTrue);
186 PTrue->eraseFromParent();
190 if (!ConvertFromCreated)
191 ConvertToSVBool->eraseFromParent();
244bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
246 bool Changed =
false;
248 for (
auto *
F : Functions) {
249 for (
auto &BB : *
F) {
258 auto *IntrI = dyn_cast<IntrinsicInst>(&
I);
259 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
262 const auto PTruePattern =
263 cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
265 if (PTruePattern == AArch64SVEPredPattern::all)
266 SVAllPTrues.
insert(IntrI);
267 if (PTruePattern == AArch64SVEPredPattern::pow2)
268 SVPow2PTrues.
insert(IntrI);
271 Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
272 Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
281bool SVEIntrinsicOpts::optimizePredicateStore(
Instruction *
I) {
282 auto *
F =
I->getFunction();
283 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
287 unsigned MinVScale = Attr.getVScaleRangeMin();
288 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
290 if (!MaxVScale || MinVScale != MaxVScale)
295 auto *FixedPredType =
299 auto *
Store = dyn_cast<StoreInst>(
I);
300 if (!Store || !
Store->isSimple())
304 if (
Store->getOperand(0)->getType() != FixedPredType)
308 auto *IntrI = dyn_cast<IntrinsicInst>(
Store->getOperand(0));
309 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
313 if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
317 auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
322 if (BitCast->getOperand(0)->getType() != PredType)
328 Builder.CreateStore(BitCast->getOperand(0),
Store->getPointerOperand());
330 Store->eraseFromParent();
331 if (IntrI->getNumUses() == 0)
332 IntrI->eraseFromParent();
333 if (BitCast->getNumUses() == 0)
334 BitCast->eraseFromParent();
341bool SVEIntrinsicOpts::optimizePredicateLoad(
Instruction *
I) {
342 auto *
F =
I->getFunction();
343 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
347 unsigned MinVScale = Attr.getVScaleRangeMin();
348 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
350 if (!MaxVScale || MinVScale != MaxVScale)
355 auto *FixedPredType =
359 auto *BitCast = dyn_cast<BitCastInst>(
I);
360 if (!BitCast || BitCast->getType() != PredType)
364 auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
365 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
369 if (!isa<UndefValue>(IntrI->getOperand(0)) ||
370 !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
374 auto *
Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
375 if (!Load || !
Load->isSimple())
379 if (
Load->getType() != FixedPredType)
385 auto *LoadPred =
Builder.CreateLoad(PredType,
Load->getPointerOperand());
387 BitCast->replaceAllUsesWith(LoadPred);
388 BitCast->eraseFromParent();
389 if (IntrI->getNumUses() == 0)
390 IntrI->eraseFromParent();
391 if (
Load->getNumUses() == 0)
392 Load->eraseFromParent();
397bool SVEIntrinsicOpts::optimizeInstructions(
399 bool Changed =
false;
401 for (
auto *
F : Functions) {
402 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
408 for (
auto *BB : RPOT) {
410 switch (
I.getOpcode()) {
411 case Instruction::Store:
412 Changed |= optimizePredicateStore(&
I);
414 case Instruction::BitCast:
415 Changed |= optimizePredicateLoad(&
I);
425bool SVEIntrinsicOpts::optimizeFunctions(
427 bool Changed =
false;
429 Changed |= optimizePTrueIntrinsicCalls(Functions);
430 Changed |= optimizeInstructions(Functions);
435bool SVEIntrinsicOpts::runOnModule(
Module &M) {
436 bool Changed =
false;
442 for (
auto &
F :
M.getFunctionList()) {
443 if (!
F.isDeclaration())
446 switch (
F.getIntrinsicID()) {
447 case Intrinsic::vector_extract:
448 case Intrinsic::vector_insert:
449 case Intrinsic::aarch64_sve_ptrue:
450 for (
User *U :
F.users())
458 if (!Functions.
empty())
459 Changed |= optimizeFunctions(Functions);
This file contains the declarations for the subclasses of Constant, which represent the different fla...
static Function * getFunction(Constant *C)
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static bool isPTruePromoted(IntrinsicInst *PTrue)
Checks if a ptrue intrinsic call is promoted.
This file implements a set that has insertion order iteration characteristics.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
LLVM Basic Block Representation.
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVMContext & getContext() const
Get the context in which this basic block lives.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
A wrapper class for inspecting calls to intrinsic functions.
This is an important class for using LLVM in a threaded context.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
static ScalableVectorType * get(Type *ElementType, unsigned MinNumElts)
bool remove(const value_type &X)
Remove an item from the set vector.
bool remove_if(UnaryPredicate P)
Remove items from the set vector based on a predicate function.
size_type size() const
Determine the number of elements in the SetVector.
iterator end()
Get an iterator to the end of the SetVector.
bool empty() const
Determine if the SetVector is empty or not.
iterator begin()
Get an iterator to the beginning of the SetVector.
bool insert(const value_type &X)
Insert a new element into the SetVector.
A SetVector that performs no allocations if smaller than a certain size.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
Type * getType() const
All values are typed, get the type of this value.
iterator_range< user_iterator > users()
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
bool match(Val *V, const Pattern &P)
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
ModulePass * createSVEIntrinsicOptsPass()
void initializeSVEIntrinsicOptsPass(PassRegistry &)