46#include "llvm/IR/IntrinsicsARM.h"
57#define DEBUG_TYPE "mve-tail-predication"
58#define DESC "Transform predicated vector loops to use MVE tail predication"
61 "tail-predication",
cl::desc(
"MVE tail-predication pass options"),
64 "Don't tail-predicate loops"),
66 "enabled-no-reductions",
67 "Enable tail-predication, but not for reduction loops"),
70 "Enable tail-predication, including reduction loops"),
72 "force-enabled-no-reductions",
73 "Enable tail-predication, but not for reduction loops, "
74 "and force this which might be unsafe"),
77 "Enable tail-predication, including reduction loops, "
78 "and force this which might be unsafe")));
83class MVETailPredication :
public LoopPass {
109 bool TryConvertActiveLaneMask(
Value *TripCount);
129 auto &TPC = getAnalysis<TargetPassConfig>();
132 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
133 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
138 if (!
ST->hasMVEIntegerOps() || !
ST->hasV8_1MMainlineOps()) {
148 for (
auto &
I : *BB) {
149 auto *
Call = dyn_cast<IntrinsicInst>(&
I);
154 if (
ID == Intrinsic::start_loop_iterations ||
155 ID == Intrinsic::test_start_loop_iterations)
156 return cast<IntrinsicInst>(&
I);
173 LLVM_DEBUG(
dbgs() <<
"ARM TP: Running on Loop: " << *L << *Setup <<
"\n");
175 bool Changed = TryConvertActiveLaneMask(
Setup->getArgOperand(0));
199 bool ForceTailPredication =
204 bool Changed =
false;
205 if (!
L->makeLoopInvariant(ElemCount, Changed))
208 auto *
EC= SE->getSCEV(ElemCount);
209 auto *TC = SE->getSCEV(TripCount);
211 cast<FixedVectorType>(ActiveLaneMask->
getType())->getNumElements();
212 if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
220 if (!SE->isLoopInvariant(EC, L)) {
221 LLVM_DEBUG(
dbgs() <<
"ARM TP: element count must be loop invariant.\n");
231 auto *IVExpr = SE->getSCEV(
IV);
232 auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);
235 LLVM_DEBUG(
dbgs() <<
"ARM TP: induction not an add expr: "; IVExpr->dump());
239 if (AddExpr->getLoop() != L) {
243 auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
245 LLVM_DEBUG(
dbgs() <<
"ARM TP: induction step is not a constant: ";
246 AddExpr->getOperand(1)->
dump());
249 auto StepValue = Step->getValue()->getSExtValue();
250 if (VectorWidth != StepValue) {
252 <<
" doesn't match vector width " << VectorWidth <<
"\n");
256 if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) {
257 ConstantInt *TC = dyn_cast<ConstantInt>(TripCount);
260 "set.loop.iterations\n");
270 (ConstElemCount->
getZExtValue() + VectorWidth - 1) / VectorWidth;
276 LLVM_DEBUG(
dbgs() <<
"ARM TP: inconsistent constant tripcount values: "
277 << TC1 <<
" from set.loop.iterations, and "
278 << TC2 <<
" from get.active.lane.mask\n");
281 }
else if (!ForceTailPredication) {
296 auto *Start = AddExpr->getStart();
297 auto *ECPlusVWMinus1 = SE->getAddExpr(EC,
301 auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW);
306 dbgs() <<
"ARM TP: Analysing overflow behaviour for:\n";
307 dbgs() <<
"ARM TP: - TripCount = " << *TC <<
"\n";
308 dbgs() <<
"ARM TP: - ElemCount = " << *
EC <<
"\n";
309 dbgs() <<
"ARM TP: - Start = " << *Start <<
"\n";
310 dbgs() <<
"ARM TP: - BETC = " << *SE->getBackedgeTakenCount(L) <<
"\n";
311 dbgs() <<
"ARM TP: - VecWidth = " << VectorWidth <<
"\n";
312 dbgs() <<
"ARM TP: - (ElemCount+VW-1) / VW = " << *Ceil <<
"\n";
327 const SCEV *Div = SE->getUDivExpr(
328 SE->getAddExpr(SE->getMulExpr(Ceil, VW), SE->getNegativeSCEV(VW),
329 SE->getNegativeSCEV(Start)),
331 const SCEV *Sub = SE->getMinusSCEV(SE->getBackedgeTakenCount(L), Div);
337 Sub = SE->applyLoopGuards(Sub, L);
341 LLVM_DEBUG(
dbgs() <<
"ARM TP: possible overflow in sub expression.\n");
350 if (
auto *BaseC = dyn_cast<SCEVConstant>(AddExpr->getStart())) {
351 if (BaseC->getAPInt().urem(VectorWidth) == 0)
352 return SE->getMinusSCEV(EC, BaseC);
353 }
else if (
auto *BaseV = dyn_cast<SCEVUnknown>(AddExpr->getStart())) {
354 Type *Ty = BaseV->getType();
358 L->getHeader()->getModule()->getDataLayout()))
359 return SE->getMinusSCEV(EC, BaseV);
360 }
else if (
auto *BaseMul = dyn_cast<SCEVMulExpr>(AddExpr->getStart())) {
361 if (
auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(0)))
362 if (BaseC->getAPInt().urem(VectorWidth) == 0)
363 return SE->getMinusSCEV(EC, BaseC);
364 if (
auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(1)))
365 if (BaseC->getAPInt().urem(VectorWidth) == 0)
366 return SE->getMinusSCEV(EC, BaseC);
370 dbgs() <<
"ARM TP: induction base is not know to be a multiple of VF: "
371 << *AddExpr->getOperand(0) <<
"\n");
375void MVETailPredication::InsertVCTPIntrinsic(
IntrinsicInst *ActiveLaneMask,
378 Module *
M =
L->getHeader()->getModule();
380 unsigned VectorWidth =
381 cast<FixedVectorType>(ActiveLaneMask->
getType())->getNumElements();
384 Builder.SetInsertPoint(
L->getHeader(),
L->getHeader()->getFirstNonPHIIt());
390 Builder.SetInsertPoint(ActiveLaneMask);
394 switch (VectorWidth) {
397 case 2: VCTPID = Intrinsic::arm_mve_vctp64;
break;
398 case 4: VCTPID = Intrinsic::arm_mve_vctp32;
break;
399 case 8: VCTPID = Intrinsic::arm_mve_vctp16;
break;
400 case 16: VCTPID = Intrinsic::arm_mve_vctp8;
break;
408 Value *Remaining =
Builder.CreateSub(Processed, Factor);
411 << *Processed <<
"\n"
412 <<
"ARM TP: Inserted VCTP: " << *VCTPCall <<
"\n");
415bool MVETailPredication::TryConvertActiveLaneMask(
Value *TripCount) {
417 for (
auto *BB :
L->getBlocks())
419 if (
auto *
Int = dyn_cast<IntrinsicInst>(&
I))
420 if (
Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
423 if (ActiveLaneMasks.
empty())
428 for (
auto *ActiveLaneMask : ActiveLaneMasks) {
430 << *ActiveLaneMask <<
"\n");
432 const SCEV *StartSCEV = IsSafeActiveMask(ActiveLaneMask, TripCount);
437 LLVM_DEBUG(
dbgs() <<
"ARM TP: Safe to insert VCTP. Start is " << *StartSCEV
439 SCEVExpander Expander(*SE,
L->getHeader()->getModule()->getDataLayout(),
442 Value *Start = Expander.expandCodeFor(StartSCEV, StartSCEV->
getType(), Ins);
443 LLVM_DEBUG(
dbgs() <<
"ARM TP: Created start value " << *Start <<
"\n");
444 InsertVCTPIntrinsic(ActiveLaneMask, Start);
448 for (
auto *II : ActiveLaneMasks)
450 for (
auto *
I :
L->blocks())
456 return new MVETailPredication();
459char MVETailPredication::ID = 0;
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
cl::opt< TailPredication::Mode > EnableTailPredication("tail-predication", cl::desc("MVE tail-predication pass options"), cl::init(TailPredication::Enabled), cl::values(clEnumValN(TailPredication::Disabled, "disabled", "Don't tail-predicate loops"), clEnumValN(TailPredication::EnabledNoReductions, "enabled-no-reductions", "Enable tail-predication, but not for reduction loops"), clEnumValN(TailPredication::Enabled, "enabled", "Enable tail-predication, including reduction loops"), clEnumValN(TailPredication::ForceEnabledNoReductions, "force-enabled-no-reductions", "Enable tail-predication, but not for reduction loops, " "and force this which might be unsafe"), clEnumValN(TailPredication::ForceEnabled, "force-enabled", "Enable tail-predication, including reduction loops, " "and force this which might be unsafe")))
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Target-Independent Code Generator Pass Configuration Options pass.
static const uint32_t IV[8]
Class for arbitrary precision integers.
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
LLVM Basic Block Representation.
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
This is the shared class of boolean and integer constants.
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
A wrapper class for inspecting calls to intrinsic functions.
The legacy pass manager's analysis pass to compute loop information.
virtual bool runOnLoop(Loop *L, LPPassManager &LPM)=0
Represents a single loop in the control flow graph.
A Module instance is used to store all the information related to an LLVM module.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Pass interface - Implemented by all 'passes'.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
This class uses information about analyzed scalars to rewrite expressions in canonical form.
This class represents an analyzed expression in the program.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
Type * getType() const
Return the LLVM type of this SCEV expression.
The main scalar evolution driver.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Primary interface to the complete machine description for the target machine.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
@ ForceEnabledNoReductions
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
unsigned Log2_64(uint64_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
bool DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
Examine each PHI in the given block and delete it if it is dead.
Pass * createMVETailPredicationPass()
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
bool MaskedValueIsZero(const Value *V, const APInt &Mask, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return true if 'V & Mask' is known to be zero.