32 #include "llvm/IR/IntrinsicsAArch64.h"
41 #define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
50 bool runOnModule(
Module &
M)
override;
68 void SVEIntrinsicOpts::getAnalysisUsage(
AnalysisUsage &AU)
const {
74 static const char *
name =
"SVE intrinsics optimizations";
80 return new SVEIntrinsicOpts();
100 if (
match(
User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
101 ConvertToUses.push_back(cast<IntrinsicInst>(
User));
106 if (ConvertToUses.empty())
112 const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->
getType());
115 auto *IntrUser = dyn_cast<IntrinsicInst>(
User);
116 if (IntrUser && IntrUser->getIntrinsicID() ==
117 Intrinsic::aarch64_sve_convert_from_svbool) {
118 const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
121 if (IntrUserVTy->getElementCount().getKnownMinValue() >
122 PTrueVTy->getElementCount().getKnownMinValue())
134 bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
136 if (PTrues.
size() <= 1)
140 auto *MostEncompassingPTrue = *std::max_element(
141 PTrues.
begin(), PTrues.
end(), [](
auto *PTrue1,
auto *PTrue2) {
142 auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
143 auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
144 return PTrue1VTy->getElementCount().getKnownMinValue() <
145 PTrue2VTy->getElementCount().getKnownMinValue();
150 PTrues.
remove(MostEncompassingPTrue);
156 MostEncompassingPTrue->moveBefore(
BB,
BB.getFirstInsertionPt());
160 Builder.SetInsertPoint(&
BB, ++MostEncompassingPTrue->getIterator());
162 auto *MostEncompassingPTrueVTy =
163 cast<VectorType>(MostEncompassingPTrue->getType());
164 auto *ConvertToSVBool =
Builder.CreateIntrinsic(
165 Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
166 {MostEncompassingPTrue});
168 bool ConvertFromCreated =
false;
169 for (
auto *PTrue : PTrues) {
170 auto *PTrueVTy = cast<VectorType>(PTrue->getType());
174 if (MostEncompassingPTrueVTy != PTrueVTy) {
175 ConvertFromCreated =
true;
177 Builder.SetInsertPoint(&
BB, ++ConvertToSVBool->getIterator());
178 auto *ConvertFromSVBool =
179 Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
180 {PTrueVTy}, {ConvertToSVBool});
181 PTrue->replaceAllUsesWith(ConvertFromSVBool);
183 PTrue->replaceAllUsesWith(MostEncompassingPTrue);
185 PTrue->eraseFromParent();
189 if (!ConvertFromCreated)
190 ConvertToSVBool->eraseFromParent();
243 bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
245 bool Changed =
false;
247 for (
auto *
F : Functions) {
248 for (
auto &
BB : *
F) {
257 auto *IntrI = dyn_cast<IntrinsicInst>(&
I);
258 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
261 const auto PTruePattern =
262 cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
265 SVAllPTrues.
insert(IntrI);
266 if (PTruePattern == AArch64SVEPredPattern::pow2)
267 SVPow2PTrues.
insert(IntrI);
270 Changed |= coalescePTrueIntrinsicCalls(
BB, SVAllPTrues);
271 Changed |= coalescePTrueIntrinsicCalls(
BB, SVPow2PTrues);
280 bool SVEIntrinsicOpts::optimizePredicateStore(
Instruction *
I) {
281 auto *
F =
I->getFunction();
282 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
286 unsigned MinVScale = Attr.getVScaleRangeMin();
289 if (!MaxVScale || MinVScale != MaxVScale)
294 auto *FixedPredType =
298 auto *
Store = dyn_cast<StoreInst>(
I);
303 if (
Store->getOperand(0)->getType() != FixedPredType)
307 auto *IntrI = dyn_cast<IntrinsicInst>(
Store->getOperand(0));
309 IntrI->getIntrinsicID() != Intrinsic::experimental_vector_extract)
313 if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
317 auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
322 if (BitCast->getOperand(0)->getType() != PredType)
328 auto *PtrBitCast =
Builder.CreateBitCast(
329 Store->getPointerOperand(),
330 PredType->getPointerTo(
Store->getPointerAddressSpace()));
331 Builder.CreateStore(BitCast->getOperand(0), PtrBitCast);
333 Store->eraseFromParent();
334 if (IntrI->getNumUses() == 0)
335 IntrI->eraseFromParent();
336 if (BitCast->getNumUses() == 0)
337 BitCast->eraseFromParent();
344 bool SVEIntrinsicOpts::optimizePredicateLoad(
Instruction *
I) {
345 auto *
F =
I->getFunction();
346 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
350 unsigned MinVScale = Attr.getVScaleRangeMin();
353 if (!MaxVScale || MinVScale != MaxVScale)
358 auto *FixedPredType =
362 auto *BitCast = dyn_cast<BitCastInst>(
I);
363 if (!BitCast || BitCast->getType() != PredType)
367 auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
369 IntrI->getIntrinsicID() != Intrinsic::experimental_vector_insert)
373 if (!isa<UndefValue>(IntrI->getOperand(0)) ||
374 !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
378 auto *
Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
383 if (
Load->getType() != FixedPredType)
389 auto *PtrBitCast =
Builder.CreateBitCast(
390 Load->getPointerOperand(),
391 PredType->getPointerTo(
Load->getPointerAddressSpace()));
392 auto *LoadPred =
Builder.CreateLoad(PredType, PtrBitCast);
394 BitCast->replaceAllUsesWith(LoadPred);
395 BitCast->eraseFromParent();
396 if (IntrI->getNumUses() == 0)
397 IntrI->eraseFromParent();
398 if (
Load->getNumUses() == 0)
399 Load->eraseFromParent();
404 bool SVEIntrinsicOpts::optimizeInstructions(
406 bool Changed =
false;
408 for (
auto *
F : Functions) {
409 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
415 for (
auto *
BB : RPOT) {
417 switch (
I.getOpcode()) {
419 Changed |= optimizePredicateStore(&
I);
421 case Instruction::BitCast:
422 Changed |= optimizePredicateLoad(&
I);
432 bool SVEIntrinsicOpts::optimizeFunctions(
434 bool Changed =
false;
436 Changed |= optimizePTrueIntrinsicCalls(Functions);
437 Changed |= optimizeInstructions(Functions);
442 bool SVEIntrinsicOpts::runOnModule(
Module &M) {
443 bool Changed =
false;
449 for (
auto &
F :
M.getFunctionList()) {
450 if (!
F.isDeclaration())
453 switch (
F.getIntrinsicID()) {
454 case Intrinsic::experimental_vector_extract:
455 case Intrinsic::experimental_vector_insert:
456 case Intrinsic::aarch64_sve_ptrue:
457 for (
User *U :
F.users())
465 if (!Functions.
empty())
466 Changed |= optimizeFunctions(Functions);