32#include "llvm/IR/IntrinsicsAArch64.h"
42#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
55 bool coalescePTrueIntrinsicCalls(
BasicBlock &BB,
69void SVEIntrinsicOpts::getAnalysisUsage(
AnalysisUsage &AU)
const {
74char SVEIntrinsicOpts::ID = 0;
75static const char *
name =
"SVE intrinsics optimizations";
81 return new SVEIntrinsicOpts();
101 if (
match(
User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
107 if (ConvertToUses.
empty())
113 const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->
getType());
116 auto *IntrUser = dyn_cast<IntrinsicInst>(
User);
117 if (IntrUser && IntrUser->getIntrinsicID() ==
118 Intrinsic::aarch64_sve_convert_from_svbool) {
119 const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
122 if (IntrUserVTy->getElementCount().getKnownMinValue() >
123 PTrueVTy->getElementCount().getKnownMinValue())
135bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
137 if (PTrues.
size() <= 1)
141 auto *MostEncompassingPTrue = *std::max_element(
142 PTrues.
begin(), PTrues.
end(), [](
auto *PTrue1,
auto *PTrue2) {
143 auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
144 auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
145 return PTrue1VTy->getElementCount().getKnownMinValue() <
146 PTrue2VTy->getElementCount().getKnownMinValue();
151 PTrues.
remove(MostEncompassingPTrue);
161 Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());
163 auto *MostEncompassingPTrueVTy =
164 cast<VectorType>(MostEncompassingPTrue->getType());
165 auto *ConvertToSVBool =
Builder.CreateIntrinsic(
166 Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
167 {MostEncompassingPTrue});
169 bool ConvertFromCreated =
false;
170 for (
auto *PTrue : PTrues) {
171 auto *PTrueVTy = cast<VectorType>(PTrue->getType());
175 if (MostEncompassingPTrueVTy != PTrueVTy) {
176 ConvertFromCreated =
true;
178 Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
179 auto *ConvertFromSVBool =
180 Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
181 {PTrueVTy}, {ConvertToSVBool});
182 PTrue->replaceAllUsesWith(ConvertFromSVBool);
184 PTrue->replaceAllUsesWith(MostEncompassingPTrue);
186 PTrue->eraseFromParent();
190 if (!ConvertFromCreated)
191 ConvertToSVBool->eraseFromParent();
244bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
246 bool Changed =
false;
248 for (
auto *
F : Functions) {
249 for (
auto &BB : *
F) {
258 auto *IntrI = dyn_cast<IntrinsicInst>(&
I);
259 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
262 const auto PTruePattern =
263 cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
265 if (PTruePattern == AArch64SVEPredPattern::all)
266 SVAllPTrues.
insert(IntrI);
267 if (PTruePattern == AArch64SVEPredPattern::pow2)
268 SVPow2PTrues.
insert(IntrI);
271 Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
272 Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
281bool SVEIntrinsicOpts::optimizePredicateStore(
Instruction *
I) {
282 auto *
F =
I->getFunction();
283 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
287 unsigned MinVScale = Attr.getVScaleRangeMin();
288 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
290 if (!MaxVScale || MinVScale != MaxVScale)
295 auto *FixedPredType =
299 auto *
Store = dyn_cast<StoreInst>(
I);
300 if (!Store || !
Store->isSimple())
304 if (
Store->getOperand(0)->getType() != FixedPredType)
308 auto *IntrI = dyn_cast<IntrinsicInst>(
Store->getOperand(0));
309 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
313 if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
317 auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
322 if (BitCast->getOperand(0)->getType() != PredType)
328 auto *PtrBitCast =
Builder.CreateBitCast(
329 Store->getPointerOperand(),
330 PredType->getPointerTo(
Store->getPointerAddressSpace()));
331 Builder.CreateStore(BitCast->getOperand(0), PtrBitCast);
333 Store->eraseFromParent();
334 if (IntrI->getNumUses() == 0)
335 IntrI->eraseFromParent();
336 if (BitCast->getNumUses() == 0)
337 BitCast->eraseFromParent();
344bool SVEIntrinsicOpts::optimizePredicateLoad(
Instruction *
I) {
345 auto *
F =
I->getFunction();
346 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
350 unsigned MinVScale = Attr.getVScaleRangeMin();
351 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
353 if (!MaxVScale || MinVScale != MaxVScale)
358 auto *FixedPredType =
362 auto *BitCast = dyn_cast<BitCastInst>(
I);
363 if (!BitCast || BitCast->getType() != PredType)
367 auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
368 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
372 if (!isa<UndefValue>(IntrI->getOperand(0)) ||
373 !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
377 auto *
Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
378 if (!Load || !
Load->isSimple())
382 if (
Load->getType() != FixedPredType)
388 auto *PtrBitCast =
Builder.CreateBitCast(
389 Load->getPointerOperand(),
390 PredType->getPointerTo(
Load->getPointerAddressSpace()));
391 auto *LoadPred =
Builder.CreateLoad(PredType, PtrBitCast);
393 BitCast->replaceAllUsesWith(LoadPred);
394 BitCast->eraseFromParent();
395 if (IntrI->getNumUses() == 0)
396 IntrI->eraseFromParent();
397 if (
Load->getNumUses() == 0)
398 Load->eraseFromParent();
403bool SVEIntrinsicOpts::optimizeInstructions(
405 bool Changed =
false;
407 for (
auto *
F : Functions) {
408 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
414 for (
auto *BB : RPOT) {
416 switch (
I.getOpcode()) {
417 case Instruction::Store:
418 Changed |= optimizePredicateStore(&
I);
420 case Instruction::BitCast:
421 Changed |= optimizePredicateLoad(&
I);
431bool SVEIntrinsicOpts::optimizeFunctions(
433 bool Changed =
false;
435 Changed |= optimizePTrueIntrinsicCalls(Functions);
436 Changed |= optimizeInstructions(Functions);
441bool SVEIntrinsicOpts::runOnModule(
Module &M) {
442 bool Changed =
false;
448 for (
auto &
F :
M.getFunctionList()) {
449 if (!
F.isDeclaration())
452 switch (
F.getIntrinsicID()) {
453 case Intrinsic::vector_extract:
454 case Intrinsic::vector_insert:
455 case Intrinsic::aarch64_sve_ptrue:
456 for (
User *U :
F.users())
464 if (!Functions.
empty())
465 Changed |= optimizeFunctions(Functions);
This file contains the declarations for the subclasses of Constant, which represent the different fla...
static Function * getFunction(Constant *C)
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static bool isPTruePromoted(IntrinsicInst *PTrue)
Checks if a ptrue intrinsic call is promoted.
This file implements a set that has insertion order iteration characteristics.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
LLVM Basic Block Representation.
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVMContext & getContext() const
Get the context in which this basic block lives.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
A wrapper class for inspecting calls to intrinsic functions.
This is an important class for using LLVM in a threaded context.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
static ScalableVectorType * get(Type *ElementType, unsigned MinNumElts)
size_type size() const
Determine the number of elements in the SetVector.
bool remove_if(UnaryPredicate P)
Remove items from the set vector based on a predicate function.
bool remove(const value_type &X)
Remove an item from the set vector.
bool insert(const value_type &X)
Insert a new element into the SetVector.
bool empty() const
Determine if the SetVector is empty or not.
iterator begin()
Get an iterator to the beginning of the SetVector.
iterator end()
Get an iterator to the end of the SetVector.
A SetVector that performs no allocations if smaller than a certain size.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
Type * getType() const
All values are typed, get the type of this value.
iterator_range< user_iterator > users()
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
bool match(Val *V, const Pattern &P)
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
ModulePass * createSVEIntrinsicOptsPass()
void initializeSVEIntrinsicOptsPass(PassRegistry &)