91using namespace PatternMatch;
94 std::numeric_limits<unsigned>::max();
98class StraightLineStrengthReduceLegacyPass :
public FunctionPass {
118 DL = &
M.getDataLayout();
125class StraightLineStrengthReduce {
141 Candidate() =
default;
155 Value *Stride =
nullptr;
175 Candidate *Basis =
nullptr;
183 bool isBasisFor(
const Candidate &Basis,
const Candidate &
C);
191 bool isSimplestForm(
const Candidate &
C);
198 void allocateCandidatesAndFindBasisForAdd(
Instruction *
I);
202 void allocateCandidatesAndFindBasisForAdd(
Value *LHS,
Value *RHS,
205 void allocateCandidatesAndFindBasisForMul(
Instruction *
I);
209 void allocateCandidatesAndFindBasisForMul(
Value *LHS,
Value *RHS,
223 void allocateCandidatesAndFindBasis(Candidate::Kind CT,
const SCEV *
B,
228 void rewriteCandidateWithBasis(
const Candidate &
C,
const Candidate &Basis);
240 static Value *emitBump(
const Candidate &Basis,
const Candidate &
C,
242 bool &BumpWithUglyGEP);
248 std::list<Candidate> Candidates;
253 std::vector<Instruction *> UnlinkedInstructions;
258char StraightLineStrengthReduceLegacyPass::ID = 0;
261 "Straight line strength reduction",
false,
false)
269 return new StraightLineStrengthReduceLegacyPass();
272bool StraightLineStrengthReduce::isBasisFor(
const Candidate &Basis,
273 const Candidate &
C) {
274 return (Basis.Ins !=
C.Ins &&
277 Basis.Ins->getType() ==
C.Ins->getType() &&
279 DT->dominates(Basis.Ins->getParent(),
C.Ins->getParent()) &&
281 Basis.Base ==
C.Base && Basis.Stride ==
C.Stride &&
282 Basis.CandidateKind ==
C.CandidateKind);
296 return Index->getBitWidth() <= 64 &&
301bool StraightLineStrengthReduce::isFoldable(
const Candidate &
C,
304 if (
C.CandidateKind == Candidate::Add)
306 if (
C.CandidateKind == Candidate::GEP)
313 unsigned NumNonZeroIndices = 0;
316 if (ConstIdx ==
nullptr || !ConstIdx->
isZero())
319 return NumNonZeroIndices <= 1;
322bool StraightLineStrengthReduce::isSimplestForm(
const Candidate &
C) {
323 if (
C.CandidateKind == Candidate::Add) {
325 return C.Index->isOne() ||
C.Index->isMinusOne();
327 if (
C.CandidateKind == Candidate::Mul) {
329 return C.Index->isZero();
331 if (
C.CandidateKind == Candidate::GEP) {
333 return ((
C.Index->isOne() ||
C.Index->isMinusOne()) &&
346void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
349 Candidate
C(CT,
B,
Idx, S,
I);
363 if (!isFoldable(
C,
TTI,
DL) && !isSimplestForm(
C)) {
365 unsigned NumIterations = 0;
367 static const unsigned MaxNumIterations = 50;
368 for (
auto Basis = Candidates.rbegin();
369 Basis != Candidates.rend() && NumIterations < MaxNumIterations;
370 ++Basis, ++NumIterations) {
371 if (isBasisFor(*Basis,
C)) {
379 Candidates.push_back(
C);
382void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
384 switch (
I->getOpcode()) {
385 case Instruction::Add:
386 allocateCandidatesAndFindBasisForAdd(
I);
388 case Instruction::Mul:
389 allocateCandidatesAndFindBasisForMul(
I);
391 case Instruction::GetElementPtr:
392 allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(
I));
397void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
400 if (!isa<IntegerType>(
I->getType()))
403 assert(
I->getNumOperands() == 2 &&
"isn't I an add?");
405 allocateCandidatesAndFindBasisForAdd(LHS, RHS,
I);
407 allocateCandidatesAndFindBasisForAdd(RHS, LHS,
I);
410void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
416 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS),
Idx, S,
I);
421 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS),
Idx, S,
I);
425 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS,
442void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
449 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(
B),
Idx, RHS,
I);
455 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(
B),
Idx, RHS,
I);
459 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
464void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
468 if (!isa<IntegerType>(
I->getType()))
471 assert(
I->getNumOperands() == 2 &&
"isn't I a mul?");
473 allocateCandidatesAndFindBasisForMul(LHS, RHS,
I);
476 allocateCandidatesAndFindBasisForMul(RHS, LHS,
I);
480void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
487 IntegerType *PtrIdxTy = cast<IntegerType>(
DL->getIndexType(
I->getType()));
489 PtrIdxTy,
Idx->getSExtValue() * (int64_t)ElementSize,
true);
490 allocateCandidatesAndFindBasis(Candidate::GEP,
B, ScaledIdx, S,
I);
493void StraightLineStrengthReduce::factorArrayIndex(
Value *ArrayIdx,
498 allocateCandidatesAndFindBasisForGEP(
500 ArrayIdx, ElementSize,
GEP);
517 allocateCandidatesAndFindBasisForGEP(
Base, RHS, LHS, ElementSize,
GEP);
524 allocateCandidatesAndFindBasisForGEP(
Base, PowerOf2, LHS, ElementSize,
GEP);
528void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
531 if (
GEP->getType()->isVectorTy())
539 for (
unsigned I = 1,
E =
GEP->getNumOperands();
I !=
E; ++
I, ++GTI) {
543 const SCEV *OrigIndexExpr = IndexExprs[
I - 1];
544 IndexExprs[
I - 1] = SE->getZero(OrigIndexExpr->
getType());
548 const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(
GEP), IndexExprs);
552 DL->getIndexSizeInBits(
GEP->getAddressSpace())) {
555 factorArrayIndex(ArrayIdx, BaseExpr, ElementSize,
GEP);
560 Value *TruncatedArrayIdx =
nullptr;
563 DL->getIndexSizeInBits(
GEP->getAddressSpace())) {
566 factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize,
GEP);
569 IndexExprs[
I - 1] = OrigIndexExpr;
575 if (
A.getBitWidth() <
B.getBitWidth())
576 A =
A.sext(
B.getBitWidth());
577 else if (
A.getBitWidth() >
B.getBitWidth())
578 B =
B.sext(
A.getBitWidth());
581Value *StraightLineStrengthReduce::emitBump(
const Candidate &Basis,
585 bool &BumpWithUglyGEP) {
586 APInt Idx =
C.Index->getValue(), BasisIdx = Basis.Index->getValue();
590 BumpWithUglyGEP =
false;
591 if (Basis.CandidateKind == Candidate::GEP) {
594 DL->getTypeAllocSize(
595 cast<GetElementPtrInst>(Basis.Ins)->getResultElementType()));
601 BumpWithUglyGEP =
true;
606 if (IndexOffset == 1)
629 return Builder.
CreateMul(ExtendedStride, Delta);
632void StraightLineStrengthReduce::rewriteCandidateWithBasis(
633 const Candidate &
C,
const Candidate &Basis) {
634 assert(
C.CandidateKind == Basis.CandidateKind &&
C.Base == Basis.Base &&
635 C.Stride == Basis.Stride);
638 assert(Basis.Ins->getParent() !=
nullptr &&
"the basis is unlinked");
644 if (!
C.Ins->getParent())
648 bool BumpWithUglyGEP;
649 Value *Bump = emitBump(Basis,
C, Builder,
DL, BumpWithUglyGEP);
650 Value *Reduced =
nullptr;
651 switch (
C.CandidateKind) {
653 case Candidate::Mul: {
658 Reduced = Builder.
CreateSub(Basis.Ins, NegBump);
672 Reduced = Builder.
CreateAdd(Basis.Ins, Bump);
678 Type *OffsetTy =
DL->getIndexType(
C.Ins->getType());
679 bool InBounds = cast<GetElementPtrInst>(
C.Ins)->isInBounds();
680 if (BumpWithUglyGEP) {
682 unsigned AS = Basis.Ins->getType()->getPointerAddressSpace();
683 Type *CharTy = PointerType::get(Basis.Ins->getContext(), AS);
693 cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(), Basis.Ins,
702 C.Ins->replaceAllUsesWith(Reduced);
705 C.Ins->removeFromParent();
706 UnlinkedInstructions.push_back(
C.Ins);
709bool StraightLineStrengthReduceLegacyPass::runOnFunction(
Function &
F) {
713 auto *
TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
714 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
715 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
716 return StraightLineStrengthReduce(
DL, DT, SE,
TTI).runOnFunction(
F);
719bool StraightLineStrengthReduce::runOnFunction(
Function &
F) {
723 for (
auto &
I : *(
Node->getBlock()))
724 allocateCandidatesAndFindBasis(&
I);
728 while (!Candidates.empty()) {
729 const Candidate &
C = Candidates.back();
730 if (
C.Basis !=
nullptr) {
731 rewriteCandidateWithBasis(
C, *
C.Basis);
733 Candidates.pop_back();
737 for (
auto *UnlinkedInst : UnlinkedInstructions) {
738 for (
unsigned I = 0,
E = UnlinkedInst->getNumOperands();
I !=
E; ++
I) {
739 Value *
Op = UnlinkedInst->getOperand(
I);
740 UnlinkedInst->setOperand(
I,
nullptr);
743 UnlinkedInst->deleteValue();
745 bool Ret = !UnlinkedInstructions.empty();
746 UnlinkedInstructions.clear();
759 if (!StraightLineStrengthReduce(
DL, DT, SE,
TTI).runOnFunction(
F))
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file implements a class to represent arbitrary precision integral constant values and operations...
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool runOnFunction(Function &F, bool PostInlining)
Module.h This file contains the declarations for the Module class.
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
static bool matchesOr(Value *A, Value *&B, ConstantInt *&C)
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI)
static bool hasOnlyOneNonZeroIndex(GetElementPtrInst *GEP)
static void unifyBitWidth(APInt &A, APInt &B)
static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C)
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
static const unsigned UnknownAddressSpace
Straight line strength reduction
Class for arbitrary precision integers.
bool isNegatedPowerOf2() const
Check if this APInt's negated value is a power of two greater than zero.
static void sdivrem(const APInt &LHS, const APInt &RHS, APInt &Quotient, APInt &Remainder)
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
unsigned getBitWidth() const
Return the number of bits in the APInt.
unsigned logBase2() const
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Represents analyses that only rely on functions' control flow.
This is the shared class of boolean and integer constants.
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
This is an important base class in LLVM.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Value * CreateNeg(Value *V, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", bool IsInBounds=false)
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
Value * CreateSExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a SExt or Trunc from the integer value V to DestTy.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Class to represent integer types.
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
A Module instance is used to store all the information related to an LLVM module.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
virtual bool doInitialization(Module &)
doInitialization - Virtual method overridden by subclasses to do any necessary initialization before ...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserveSet()
Mark an analysis set as preserved.
void preserve()
Mark an analysis as preserved.
This class represents an analyzed expression in the program.
Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
The main scalar evolution driver.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Analysis pass providing the TargetTransformInfo.
The instances of the Type class are immutable: once they are created, they are never changed.
unsigned getIntegerBitWidth() const
A Use represents the edge between a Value definition and its users.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVMContext & getContext() const
All values hold a context through their type.
void takeName(Value *V)
Transfer the name from V to this value.
Type * getIndexedType() const
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
@ Invalid
Invalid file type.
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
CastInst_match< OpTy, Instruction::SExt > m_SExt(const OpTy &Op)
Matches SExt.
BinaryOp_match< cst_pred_ty< is_zero_int >, ValTy, Instruction::Sub > m_Neg(const ValTy &V)
Matches a 'Neg' as 'sub 0, V'.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Shl, OverflowingBinaryOperator::NoSignedWrap > m_NSWShl(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Or > m_Or(const LHS &L, const RHS &R)
OverflowingBinaryOp_match< LHS, RHS, Instruction::Mul, OverflowingBinaryOperator::NoSignedWrap > m_NSWMul(const LHS &L, const RHS &R)
This is an optimization pass for GlobalISel generic memory operations.
bool haveNoCommonBitsSet(const WithCache< const Value * > &LHSCache, const WithCache< const Value * > &RHSCache, const SimplifyQuery &SQ)
Return true if LHS and RHS have no common bits set.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
void initializeStraightLineStrengthReduceLegacyPassPass(PassRegistry &)
gep_type_iterator gep_type_begin(const User *GEP)
iterator_range< df_iterator< T > > depth_first(const T &G)
FunctionPass * createStraightLineStrengthReducePass()