LLVM 22.0.0git
StraightLineStrengthReduce.cpp
Go to the documentation of this file.
1//===- StraightLineStrengthReduce.cpp - -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements straight-line strength reduction (SLSR). Unlike loop
10// strength reduction, this algorithm is designed to reduce arithmetic
11// redundancy in straight-line code instead of loops. It has proven to be
12// effective in simplifying arithmetic statements derived from an unrolled loop.
13// It can also simplify the logic of SeparateConstOffsetFromGEP.
14//
15// There are many optimizations we can perform in the domain of SLSR.
16// We look for strength reduction candidates in the following forms:
17//
18// Form Add: B + i * S
19// Form Mul: (B + i) * S
20// Form GEP: &B[i * S]
21//
22// where S is an integer variable, and i is a constant integer. If we found two
23// candidates S1 and S2 in the same form and S1 dominates S2, we may rewrite S2
24// in a simpler way with respect to S1 (index delta). For example,
25//
26// S1: X = B + i * S
27// S2: Y = B + i' * S => X + (i' - i) * S
28//
29// S1: X = (B + i) * S
30// S2: Y = (B + i') * S => X + (i' - i) * S
31//
32// S1: X = &B[i * S]
33// S2: Y = &B[i' * S] => &X[(i' - i) * S]
34//
35// Note: (i' - i) * S is folded to the extent possible.
36//
37// For Add and GEP forms, we can also rewrite a candidate in a simpler way
38// with respect to other dominating candidates if their B or S are different
39// but other parts are the same. For example,
40//
41// Base Delta:
42// S1: X = B + i * S
43// S2: Y = B' + i * S => X + (B' - B)
44//
45// S1: X = &B [i * S]
46// S2: Y = &B'[i * S] => X + (B' - B)
47//
48// Stride Delta:
49// S1: X = B + i * S
50// S2: Y = B + i * S' => X + i * (S' - S)
51//
52// S1: X = &B[i * S]
53// S2: Y = &B[i * S'] => X + i * (S' - S)
54//
55// PS: Stride delta rewrite on Mul form is usually non-profitable, and Base
56// delta rewrite is only sometimes profitable, so we support neither on Mul.
57//
58// This rewriting is in general a good idea. The code patterns we focus on
59// usually come from loop unrolling, so the delta is likely the same
60// across iterations and can be reused. When that happens, the optimized form
61// takes only one add starting from the second iteration.
62//
63// When such rewriting is possible, we call S1 a "basis" of S2. When S2 has
64// multiple bases, we choose to rewrite S2 with respect to its "immediate"
65// basis, the basis that is the closest ancestor in the dominator tree.
66//
67// TODO:
68//
69// - Floating point arithmetics when fast math is enabled.
70
72#include "llvm/ADT/APInt.h"
74#include "llvm/ADT/SetVector.h"
80#include "llvm/IR/Constants.h"
81#include "llvm/IR/DataLayout.h"
83#include "llvm/IR/Dominators.h"
85#include "llvm/IR/IRBuilder.h"
86#include "llvm/IR/Instruction.h"
88#include "llvm/IR/Module.h"
89#include "llvm/IR/Operator.h"
91#include "llvm/IR/Type.h"
92#include "llvm/IR/Value.h"
94#include "llvm/Pass.h"
100#include <cassert>
101#include <cstdint>
102#include <limits>
103#include <list>
104#include <queue>
105#include <vector>
106
107using namespace llvm;
108using namespace PatternMatch;
109
110#define DEBUG_TYPE "slsr"
111
// Sentinel address-space number (the largest unsigned value). isAddFoldable
// passes it to TTI->isLegalAddressingMode as the address-space argument.
112static const unsigned UnknownAddressSpace =
113 std::numeric_limits<unsigned>::max();
114
// Debug counter registered under the name "slsr-counter"; per its description
// string it controls whether rewriteCandidate is executed.
115DEBUG_COUNTER(StraightLineStrengthReduceCounter, "slsr-counter",
116 "Controls whether rewriteCandidate is executed.");
117
118namespace {
119
// Legacy pass-manager wrapper for straight-line strength reduction. It caches
// the module's DataLayout in doInitialization and declares the analyses the
// transformation depends on.
120class StraightLineStrengthReduceLegacyPass : public FunctionPass {
// Set by doInitialization from the module being processed.
121 const DataLayout *DL = nullptr;
122
123public:
// Pass identification; defined out of line.
124 static char ID;
125
// NOTE(review): the embedded listing numbers skip 127-128, so statements in
// the constructor body may be missing from this copy -- confirm upstream.
126 StraightLineStrengthReduceLegacyPass() : FunctionPass(ID) {
129 }
130
// Requires the dominator tree, scalar evolution and target transform info,
// and reports that the CFG is preserved.
131 void getAnalysisUsage(AnalysisUsage &AU) const override {
132 AU.addRequired<DominatorTreeWrapperPass>();
133 AU.addRequired<ScalarEvolutionWrapperPass>();
134 AU.addRequired<TargetTransformInfoWrapperPass>();
135 // We do not modify the shape of the CFG.
136 AU.setPreservesCFG();
137 }
138
// Remembers the module's DataLayout. Returns false because the module itself
// is not modified here.
139 bool doInitialization(Module &M) override {
140 DL = &M.getDataLayout();
141 return false;
142 }
143
144 bool runOnFunction(Function &F) override;
145};
146
147class StraightLineStrengthReduce {
148public:
// Stores the data layout and the three analyses; all four members are set
// here and nothing else is done.
149 StraightLineStrengthReduce(const DataLayout *DL, DominatorTree *DT,
150 ScalarEvolution *SE, TargetTransformInfo *TTI)
151 : DL(DL), DT(DT), SE(SE), TTI(TTI) {}
152
153 // SLSR candidate. Such a candidate must be in one of the forms described in
154 // the header comments.
155 struct Candidate {
156 enum Kind {
157 Invalid, // reserved for the default constructor
158 Add, // B + i * S
159 Mul, // (B + i) * S
160 GEP, // &B[..][i * S][..]
161 };
162
163 enum DKind {
164 InvalidDelta, // reserved for the default constructor
165 IndexDelta, // Delta is a constant from Index
166 BaseDelta, // Delta is a constant or variable from Base
167 StrideDelta, // Delta is a constant or variable from Stride
168 };
169
170 Candidate() = default;
171 Candidate(Kind CT, const SCEV *B, ConstantInt *Idx, Value *S,
172 Instruction *I, const SCEV *StrideSCEV)
173 : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I),
174 StrideSCEV(StrideSCEV) {}
175
176 Kind CandidateKind = Invalid;
177
178 const SCEV *Base = nullptr;
179 // TODO: Swap Index and Stride's name.
180 // Note that Index and Stride of a GEP candidate do not necessarily have the
181 // same integer type. In that case, during rewriting, Stride will be
182 // sign-extended or truncated to Index's type.
183 ConstantInt *Index = nullptr;
184
185 Value *Stride = nullptr;
186
187 // The instruction this candidate corresponds to. It helps us to rewrite a
188 // candidate with respect to its immediate basis. Note that one instruction
189 // can correspond to multiple candidates depending on how you associate the
190 // expression. For instance,
191 //
192 // (a + 1) * (b + 2)
193 //
194 // can be treated as
195 //
196 // <Base: a, Index: 1, Stride: b + 2>
197 //
198 // or
199 //
200 // <Base: b, Index: 2, Stride: a + 1>
201 Instruction *Ins = nullptr;
202
203 // Points to the immediate basis of this candidate, or nullptr if we cannot
204 // find any basis for this candidate.
205 Candidate *Basis = nullptr;
206
207 DKind DeltaKind = InvalidDelta;
208
209 // Store SCEV of Stride to compute delta from different strides
210 const SCEV *StrideSCEV = nullptr;
211
212 // Points to (Y - X) that will be used to rewrite this candidate.
213 Value *Delta = nullptr;
214
215 /// Cost model: Evaluate the computational efficiency of the candidate.
216 ///
217 /// Efficiency levels (higher is better):
218 /// ZeroInst (5) - [Variable] or [Const]
219 /// OneInstOneVar (4) - [Variable + Const] or [Variable * Const]
220 /// OneInstTwoVar (3) - [Variable + Variable] or [Variable * Variable]
221 /// TwoInstOneVar (2) - [Const + Const * Variable]
222 /// TwoInstTwoVar (1) - [Variable + Const * Variable]
223 enum EfficiencyLevel : unsigned {
224 Unknown = 0,
225 TwoInstTwoVar = 1,
226 TwoInstOneVar = 2,
227 OneInstTwoVar = 3,
228 OneInstOneVar = 4,
229 ZeroInst = 5
230 };
231
232 static EfficiencyLevel
233 getComputationEfficiency(Kind CandidateKind, const ConstantInt *Index,
234 const Value *Stride, const SCEV *Base = nullptr) {
235 bool IsConstantBase = false;
236 bool IsZeroBase = false;
237 // When evaluating the efficiency of a rewrite, if the Base's SCEV is
238 // not available, conservatively assume the base is not constant.
239 if (auto *ConstBase = dyn_cast_or_null<SCEVConstant>(Base)) {
240 IsConstantBase = true;
241 IsZeroBase = ConstBase->getValue()->isZero();
242 }
243
244 bool IsConstantStride = isa<ConstantInt>(Stride);
245 bool IsZeroStride =
246 IsConstantStride && cast<ConstantInt>(Stride)->isZero();
247 // All constants
248 if (IsConstantBase && IsConstantStride)
249 return ZeroInst;
250
251 // (Base + Index) * Stride
252 if (CandidateKind == Mul) {
253 if (IsZeroStride)
254 return ZeroInst;
255 if (Index->isZero())
256 return (IsConstantStride || IsConstantBase) ? OneInstOneVar
257 : OneInstTwoVar;
258
259 if (IsConstantBase)
260 return IsZeroBase && (Index->isOne() || Index->isMinusOne())
261 ? ZeroInst
262 : OneInstOneVar;
263
264 if (IsConstantStride) {
265 auto *CI = cast<ConstantInt>(Stride);
266 return (CI->isOne() || CI->isMinusOne()) ? OneInstOneVar
267 : TwoInstOneVar;
268 }
269 return TwoInstTwoVar;
270 }
271
272 // Base + Index * Stride
273 assert(CandidateKind == Add || CandidateKind == GEP);
274 if (Index->isZero() || IsZeroStride)
275 return ZeroInst;
276
277 bool IsSimpleIndex = Index->isOne() || Index->isMinusOne();
278
279 if (IsConstantBase)
280 return IsZeroBase ? (IsSimpleIndex ? ZeroInst : OneInstOneVar)
281 : (IsSimpleIndex ? OneInstOneVar : TwoInstOneVar);
282
283 if (IsConstantStride)
284 return IsZeroStride ? ZeroInst : OneInstOneVar;
285
286 if (IsSimpleIndex)
287 return OneInstTwoVar;
288
289 return TwoInstTwoVar;
290 }
291
292 // Evaluate if the given delta is profitable to rewrite this candidate.
293 bool isProfitableRewrite(const Value *Delta, const DKind DeltaKind) const {
294 // This function cannot accurately evaluate the profit of whole expression
295 // with context. A candidate (B + I * S) cannot express whether this
296 // instruction needs to compute on its own (I * S), which may be shared
297 // with other candidates or may need instructions to compute.
298 // If the rewritten form has the same strength, still rewrite to
299 // (X + Delta) since it may expose more CSE opportunities on Delta, as
300 // unrolled loops usually have identical Delta for each unrolled body.
301 //
302 // Note, this function should only be used on Index Delta rewrite.
303 // Base and Stride delta need context info to evaluate the register
304 // pressure impact from variable delta.
305 return getComputationEfficiency(CandidateKind, Index, Stride, Base) <=
306 getRewriteEfficiency(Delta, DeltaKind);
307 }
308
309 // Evaluate the rewrite efficiency of this candidate with its Basis
310 EfficiencyLevel getRewriteEfficiency() const {
311 return Basis ? getRewriteEfficiency(Delta, DeltaKind) : Unknown;
312 }
313
314 // Evaluate the rewrite efficiency of this candidate with a given delta
315 EfficiencyLevel getRewriteEfficiency(const Value *Delta,
316 const DKind DeltaKind) const {
317 switch (DeltaKind) {
318 case BaseDelta: // [X + Delta]
319 return getComputationEfficiency(
320 CandidateKind,
321 ConstantInt::get(cast<IntegerType>(Delta->getType()), 1), Delta);
322 case StrideDelta: // [X + Index * Delta]
323 return getComputationEfficiency(CandidateKind, Index, Delta);
324 case IndexDelta: // [X + Delta * Stride]
325 return getComputationEfficiency(CandidateKind, cast<ConstantInt>(Delta),
326 Stride);
327 default:
328 return Unknown;
329 }
330 }
331
332 bool isHighEfficiency() const {
333 return getComputationEfficiency(CandidateKind, Index, Stride, Base) >=
334 OneInstOneVar;
335 }
336
337 // Verify that this candidate has valid delta components relative to the
338 // basis
339 bool hasValidDelta(const Candidate &Basis) const {
340 switch (DeltaKind) {
341 case IndexDelta:
342 // Index differs, Base and Stride must match
343 return Base == Basis.Base && StrideSCEV == Basis.StrideSCEV;
344 case StrideDelta:
345 // Stride differs, Base and Index must match
346 return Base == Basis.Base && Index == Basis.Index;
347 case BaseDelta:
348 // Base differs, Stride and Index must match
349 return StrideSCEV == Basis.StrideSCEV && Index == Basis.Index;
350 default:
351 return false;
352 }
353 }
354 };
355
356 bool runOnFunction(Function &F);
357
358private:
359 // Fetch straight-line basis for rewriting C, update C.Basis to point to it,
360 // and store the delta between C and its Basis in C.Delta.
361 void setBasisAndDeltaFor(Candidate &C);
362 // Returns whether the candidate can be folded into an addressing mode.
363 bool isFoldable(const Candidate &C, TargetTransformInfo *TTI);
364
365 // Checks whether I is in a candidate form. If so, adds all the matching forms
366 // to Candidates, and tries to find the immediate basis for each of them.
367 void allocateCandidatesAndFindBasis(Instruction *I);
368
369 // Allocate candidates and find bases for Add instructions.
370 void allocateCandidatesAndFindBasisForAdd(Instruction *I);
371
372 // Given I = LHS + RHS, factors RHS into i * S and makes (LHS + i * S) a
373 // candidate.
374 void allocateCandidatesAndFindBasisForAdd(Value *LHS, Value *RHS,
375 Instruction *I);
376 // Allocate candidates and find bases for Mul instructions.
377 void allocateCandidatesAndFindBasisForMul(Instruction *I);
378
379 // Splits LHS into Base + Index and, if succeeds, calls
380 // allocateCandidatesAndFindBasis.
381 void allocateCandidatesAndFindBasisForMul(Value *LHS, Value *RHS,
382 Instruction *I);
383
384 // Allocate candidates and find bases for GetElementPtr instructions.
385 void allocateCandidatesAndFindBasisForGEP(GetElementPtrInst *GEP);
386
387 // Adds the given form <CT, B, Idx, S> to Candidates, and finds its immediate
388 // basis.
389 void allocateCandidatesAndFindBasis(Candidate::Kind CT, const SCEV *B,
390 ConstantInt *Idx, Value *S,
391 Instruction *I);
392
393 // Rewrites candidate C with respect to Basis.
394 void rewriteCandidate(const Candidate &C);
395
396 // Emit code that computes the "bump" from Basis to C.
397 static Value *emitBump(const Candidate &Basis, const Candidate &C,
398 IRBuilder<> &Builder, const DataLayout *DL);
399
// Data layout and analyses; all four are set by the constructor.
400 const DataLayout *DL = nullptr;
401 DominatorTree *DT = nullptr;
402 ScalarEvolution *SE;
403 TargetTransformInfo *TTI = nullptr;
// Owns every allocated candidate. RewriteCandidates and CandidateDict store
// raw pointers to these elements, so their addresses have to stay stable while
// more candidates are appended (std::list does not move elements on
// push_back).
404 std::list<Candidate> Candidates;
405
406 // Map from SCEV to instructions that represent the value,
407 // instructions are sorted in depth-first order.
408 DenseMap<const SCEV *, SmallSetVector<Instruction *, 2>> SCEVToInsts;
409
410 // Record the dependency between instructions. If C.Basis == B, we would have
411 // {B.Ins -> {C.Ins, ...}}.
412 MapVector<Instruction *, std::vector<Instruction *>> DependencyGraph;
413
414 // Map between each instruction and its possible candidates.
415 DenseMap<Instruction *, SmallVector<Candidate *, 3>> RewriteCandidates;
416
417 // All instructions that have candidates sort in topological order based on
418 // dependency graph, from roots to leaves.
419 std::vector<Instruction *> SortedCandidateInsts;
420
421 // Record all instructions that are already rewritten and will be removed
422 // later.
423 std::vector<Instruction *> DeadInstructions;
424
425 // Classify candidates against Delta kind
426 class CandidateDictTy {
427 public:
428 using CandsTy = SmallVector<Candidate *, 8>;
429 using BBToCandsTy = DenseMap<const BasicBlock *, CandsTy>;
430
431 private:
432 // Index delta Basis must have the same (Base, StrideSCEV, Inst.Type)
433 using IndexDeltaKeyTy = std::tuple<const SCEV *, const SCEV *, Type *>;
434 DenseMap<IndexDeltaKeyTy, BBToCandsTy> IndexDeltaCandidates;
435
436 // Base delta Basis must have the same (StrideSCEV, Index, Inst.Type)
437 using BaseDeltaKeyTy = std::tuple<const SCEV *, ConstantInt *, Type *>;
438 DenseMap<BaseDeltaKeyTy, BBToCandsTy> BaseDeltaCandidates;
439
440 // Stride delta Basis must have the same (Base, Index, Inst.Type)
441 using StrideDeltaKeyTy = std::tuple<const SCEV *, ConstantInt *, Type *>;
442 DenseMap<StrideDeltaKeyTy, BBToCandsTy> StrideDeltaCandidates;
443
444 public:
445 // TODO: Disable index delta on GEP after we completely move
446 // from typed GEP to PtrAdd.
447 const BBToCandsTy *getCandidatesWithDeltaKind(const Candidate &C,
448 Candidate::DKind K) const {
449 assert(K != Candidate::InvalidDelta);
450 if (K == Candidate::IndexDelta) {
451 IndexDeltaKeyTy IndexDeltaKey(C.Base, C.StrideSCEV, C.Ins->getType());
452 auto It = IndexDeltaCandidates.find(IndexDeltaKey);
453 if (It != IndexDeltaCandidates.end())
454 return &It->second;
455 } else if (K == Candidate::BaseDelta) {
456 BaseDeltaKeyTy BaseDeltaKey(C.StrideSCEV, C.Index, C.Ins->getType());
457 auto It = BaseDeltaCandidates.find(BaseDeltaKey);
458 if (It != BaseDeltaCandidates.end())
459 return &It->second;
460 } else {
461 assert(K == Candidate::StrideDelta);
462 StrideDeltaKeyTy StrideDeltaKey(C.Base, C.Index, C.Ins->getType());
463 auto It = StrideDeltaCandidates.find(StrideDeltaKey);
464 if (It != StrideDeltaCandidates.end())
465 return &It->second;
466 }
467 return nullptr;
468 }
469
470 // Pointers to C must remain valid until CandidateDict is cleared.
471 void add(Candidate &C) {
472 Type *ValueType = C.Ins->getType();
473 BasicBlock *BB = C.Ins->getParent();
474 IndexDeltaKeyTy IndexDeltaKey(C.Base, C.StrideSCEV, ValueType);
475 BaseDeltaKeyTy BaseDeltaKey(C.StrideSCEV, C.Index, ValueType);
476 StrideDeltaKeyTy StrideDeltaKey(C.Base, C.Index, ValueType);
477 IndexDeltaCandidates[IndexDeltaKey][BB].push_back(&C);
478 BaseDeltaCandidates[BaseDeltaKey][BB].push_back(&C);
479 StrideDeltaCandidates[StrideDeltaKey][BB].push_back(&C);
480 }
481 // Remove all mappings from set
482 void clear() {
483 IndexDeltaCandidates.clear();
484 BaseDeltaCandidates.clear();
485 StrideDeltaCandidates.clear();
486 }
487 } CandidateDict;
488
// Returns SE's SCEV for V and records V in SCEVToInsts under that SCEV.
// NOTE(review): the embedded listing numbers jump from 490 to 493, so a guard
// around the insertion appears to be missing from this copy; as written V is
// cast to Instruction unconditionally -- confirm against upstream.
489 const SCEV *getAndRecordSCEV(Value *V) {
490 auto *S = SE->getSCEV(V);
493 SCEVToInsts[S].insert(cast<Instruction>(V));
494
495 return S;
496 }
497
498 // Get the nearest instruction before CI that represents the value of S,
499 // return nullptr if no instruction is associated with S or S is not a
500 // reusable expression.
// SCEVUnknown and SCEVConstant are answered directly with their underlying
// value; any other SCEV is looked up in SCEVToInsts and the returned
// instruction must dominate CI.
// NOTE(review): the embedded listing numbers skip 502, so the condition that
// guards the first `return nullptr` appears to be missing from this copy; as
// written everything after it is unreachable -- confirm against upstream.
501 Value *getNearestValueOfSCEV(const SCEV *S, const Instruction *CI) const {
503 return nullptr;
504
505 if (auto *SU = dyn_cast<SCEVUnknown>(S))
506 return SU->getValue();
507 if (auto *SC = dyn_cast<SCEVConstant>(S))
508 return SC->getValue();
509
510 auto It = SCEVToInsts.find(S);
511 if (It == SCEVToInsts.end())
512 return nullptr;
513
514 // Instructions are sorted in depth-first order, so search for the nearest
515 // instruction by walking the list in reverse order.
516 for (Instruction *I : reverse(It->second))
517 if (DT->dominates(I, CI))
518 return I;
519
520 return nullptr;
521 }
522
523 struct DeltaInfo {
524 Candidate *Cand;
525 Candidate::DKind DeltaKind;
526 Value *Delta;
527
528 DeltaInfo()
529 : Cand(nullptr), DeltaKind(Candidate::InvalidDelta), Delta(nullptr) {}
530 DeltaInfo(Candidate *Cand, Candidate::DKind DeltaKind, Value *Delta)
531 : Cand(Cand), DeltaKind(DeltaKind), Delta(Delta) {}
532 operator bool() const { return Cand != nullptr; }
533 };
534
535 friend raw_ostream &operator<<(raw_ostream &OS, const DeltaInfo &DI);
536
537 DeltaInfo compressPath(Candidate &C, Candidate *Basis) const;
538
539 Candidate *pickRewriteCandidate(Instruction *I) const;
540 void sortCandidateInstructions();
541 static Constant *getIndexDelta(Candidate &C, Candidate &Basis);
542 static bool isSimilar(Candidate &C, Candidate &Basis, Candidate::DKind K);
543
544 // Add Basis -> C in DependencyGraph and propagate
545 // C.Stride and C.Delta's dependency to C
546 void addDependency(Candidate &C, Candidate *Basis) {
547 if (Basis)
548 DependencyGraph[Basis->Ins].emplace_back(C.Ins);
549
550 // If any candidate of Inst has a basis, then Inst will be rewritten,
551 // C must be rewritten after rewriting Inst, so we need to propagate
552 // the dependency to C
553 auto PropagateDependency = [&](Instruction *Inst) {
554 if (auto CandsIt = RewriteCandidates.find(Inst);
555 CandsIt != RewriteCandidates.end() &&
556 llvm::any_of(CandsIt->second,
557 [](Candidate *Cand) { return Cand->Basis; }))
558 DependencyGraph[Inst].emplace_back(C.Ins);
559 };
560
561 // If C has a variable delta and the delta is a candidate,
562 // propagate its dependency to C
563 if (auto *DeltaInst = dyn_cast_or_null<Instruction>(C.Delta))
564 PropagateDependency(DeltaInst);
565
566 // If the stride is a candidate, propagate its dependency to C
567 if (auto *StrideInst = dyn_cast<Instruction>(C.Stride))
568 PropagateDependency(StrideInst);
569 };
570};
571
573 const StraightLineStrengthReduce::Candidate &C) {
574 OS << "Ins: " << *C.Ins << "\n Base: " << *C.Base
575 << "\n Index: " << *C.Index << "\n Stride: " << *C.Stride
576 << "\n StrideSCEV: " << *C.StrideSCEV;
577 if (C.Basis)
578 OS << "\n Delta: " << *C.Delta << "\n Basis: \n [ " << *C.Basis << " ]";
579 return OS;
580}
581
582[[maybe_unused]] LLVM_DUMP_METHOD inline raw_ostream &
583operator<<(raw_ostream &OS, const StraightLineStrengthReduce::DeltaInfo &DI) {
584 OS << "Cand: " << *DI.Cand << "\n";
585 OS << "Delta Kind: ";
586 switch (DI.DeltaKind) {
587 case StraightLineStrengthReduce::Candidate::IndexDelta:
588 OS << "Index";
589 break;
590 case StraightLineStrengthReduce::Candidate::BaseDelta:
591 OS << "Base";
592 break;
593 case StraightLineStrengthReduce::Candidate::StrideDelta:
594 OS << "Stride";
595 break;
596 default:
597 break;
598 }
599 OS << "\nDelta: " << *DI.Delta;
600 return OS;
601}
602
603} // end anonymous namespace
604
// Out-of-line definition of the legacy pass's static ID member.
605char StraightLineStrengthReduceLegacyPass::ID = 0;
606
// Legacy pass registration under the name "slsr".
// NOTE(review): the embedded listing numbers skip 609-611 between BEGIN and
// END; presumably the pass-dependency declarations live there -- confirm
// against upstream.
607INITIALIZE_PASS_BEGIN(StraightLineStrengthReduceLegacyPass, "slsr",
608 "Straight line strength reduction", false, false)
612INITIALIZE_PASS_END(StraightLineStrengthReduceLegacyPass, "slsr",
613 "Straight line strength reduction", false, false)
614
616 return new StraightLineStrengthReduceLegacyPass();
617}
618
619// A helper function that unifies the bitwidth of A and B.
620static void unifyBitWidth(APInt &A, APInt &B) {
621 if (A.getBitWidth() < B.getBitWidth())
622 A = A.sext(B.getBitWidth());
623 else if (A.getBitWidth() > B.getBitWidth())
624 B = B.sext(A.getBitWidth());
625}
626
627Constant *StraightLineStrengthReduce::getIndexDelta(Candidate &C,
628 Candidate &Basis) {
629 APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue();
630 unifyBitWidth(Idx, BasisIdx);
631 APInt IndexDelta = Idx - BasisIdx;
632 IntegerType *DeltaType =
633 IntegerType::get(C.Ins->getContext(), IndexDelta.getBitWidth());
634 return ConstantInt::get(DeltaType, IndexDelta);
635}
636
637bool StraightLineStrengthReduce::isSimilar(Candidate &C, Candidate &Basis,
638 Candidate::DKind K) {
639 bool SameType = false;
640 switch (K) {
641 case Candidate::StrideDelta:
642 SameType = C.StrideSCEV->getType() == Basis.StrideSCEV->getType();
643 break;
644 case Candidate::BaseDelta:
645 SameType = C.Base->getType() == Basis.Base->getType();
646 break;
647 case Candidate::IndexDelta:
648 SameType = true;
649 break;
650 default:;
651 }
652 return SameType && Basis.Ins != C.Ins &&
653 Basis.CandidateKind == C.CandidateKind;
654}
655
// Searches the candidate dictionary for a dominating basis of C and, on
// success, fills in C.Basis, C.DeltaKind and C.Delta. Index deltas are tried
// first; for non-Mul candidates base deltas and then stride deltas follow. A
// constant delta ends the search immediately, while the first variable delta
// found is kept as a fallback. C is left without a basis when nothing fits.
656void StraightLineStrengthReduce::setBasisAndDeltaFor(Candidate &C) {
// Calls IsTarget on each recorded candidate, newest first within a block,
// starting in C's block and moving up the dominator tree. Returns true as
// soon as IsTarget accepts one.
657 auto SearchFrom = [this, &C](const CandidateDictTy::BBToCandsTy &BBToCands,
658 auto IsTarget) -> bool {
659 // Search dominating candidates by walking the immediate-dominator chain
660 // from the candidate's defining block upward. Visiting blocks in this
661 // order ensures we prefer the closest dominating basis.
662 const BasicBlock *BB = C.Ins->getParent();
663 while (BB) {
664 auto It = BBToCands.find(BB);
665 if (It != BBToCands.end())
666 for (Candidate *Basis : reverse(It->second))
667 if (IsTarget(Basis))
668 return true;
669
670 const DomTreeNode *Node = DT->getNode(BB);
671 if (!Node)
672 break;
673 Node = Node->getIDom();
674 BB = Node ? Node->getBlock() : nullptr;
675 }
676 return false;
677 };
678
679 // Priority:
680 // Constant Delta from Index > Constant Delta from Base >
681 // Constant Delta from Stride > Variable Delta from Base or Stride
682 // TODO: Change the priority to align with the cost model.
683
684 // First, look for a constant index-diff basis
685 if (const auto *IndexDeltaCandidates =
686 CandidateDict.getCandidatesWithDeltaKind(C, Candidate::IndexDelta)) {
687 bool FoundConstDelta =
688 SearchFrom(*IndexDeltaCandidates, [&](Candidate *Basis) {
689 if (isSimilar(C, *Basis, Candidate::IndexDelta)) {
690 assert(DT->dominates(Basis->Ins, C.Ins));
691 auto *Delta = getIndexDelta(C, *Basis);
// An unprofitable basis is skipped and the search continues further up.
692 if (!C.isProfitableRewrite(Delta, Candidate::IndexDelta))
693 return false;
694 C.Basis = Basis;
695 C.DeltaKind = Candidate::IndexDelta;
696 C.Delta = Delta;
697 LLVM_DEBUG(dbgs() << "Found delta from Index " << *C.Delta << "\n");
698 return true;
699 }
700 return false;
701 });
702 if (FoundConstDelta)
703 return;
704 }
705
706 // No constant-index-diff basis found. look for the best possible base-diff
707 // or stride-diff basis
708 // Base/Stride diffs not supported for form (B + i) * S
709 if (C.CandidateKind == Candidate::Mul)
710 return;
711
// Builds the SearchFrom predicate for a base delta or a stride delta.
712 auto For = [this, &C](Candidate::DKind K) {
713 // return true if find a Basis with constant delta and stop searching,
714 // return false if did not find a Basis or the delta is not a constant
715 // and continue searching for a Basis with constant delta
716 return [K, this, &C](Candidate *Basis) -> bool {
717 if (!isSimilar(C, *Basis, K))
718 return false;
719
720 assert(DT->dominates(Basis->Ins, C.Ins));
721 const SCEV *BasisPart =
722 (K == Candidate::BaseDelta) ? Basis->Base : Basis->StrideSCEV;
723 const SCEV *CandPart =
724 (K == Candidate::BaseDelta) ? C.Base : C.StrideSCEV;
725 const SCEV *Diff = SE->getMinusSCEV(CandPart, BasisPart);
// The difference is only usable if some value already available at C.Ins
// represents it.
726 Value *AvailableVal = getNearestValueOfSCEV(Diff, C.Ins);
727 if (!AvailableVal)
728 return false;
729
730 // Record delta if none has been found yet, or the new delta is
731 // a constant that is better than the existing delta.
732 if (!C.Delta || isa<ConstantInt>(AvailableVal)) {
733 C.Delta = AvailableVal;
734 C.Basis = Basis;
735 C.DeltaKind = K;
736 }
737 return isa<ConstantInt>(C.Delta);
738 };
739 };
740
741 if (const auto *BaseDeltaCandidates =
742 CandidateDict.getCandidatesWithDeltaKind(C, Candidate::BaseDelta)) {
743 if (SearchFrom(*BaseDeltaCandidates, For(Candidate::BaseDelta))) {
744 LLVM_DEBUG(dbgs() << "Found delta from Base: " << *C.Delta << "\n");
745 return;
746 }
747 }
748
749 if (const auto *StrideDeltaCandidates =
750 CandidateDict.getCandidatesWithDeltaKind(C, Candidate::StrideDelta)) {
751 if (SearchFrom(*StrideDeltaCandidates, For(Candidate::StrideDelta))) {
752 LLVM_DEBUG(dbgs() << "Found delta from Stride: " << *C.Delta << "\n");
753 return;
754 }
755 }
756
757 // If we did not find a constant delta, we might have found a variable delta
758 if (C.Delta) {
759 LLVM_DEBUG({
760 dbgs() << "Found delta from ";
761 if (C.DeltaKind == Candidate::BaseDelta)
762 dbgs() << "Base: ";
763 else
764 dbgs() << "Stride: ";
765 dbgs() << *C.Delta << "\n";
766 });
767 assert(C.DeltaKind != Candidate::InvalidDelta && C.Basis);
768 }
769}
770
771// Compress the path from `Basis` to the deepest Basis in the Basis chain
772// to avoid non-profitable data dependency and improve ILP.
773// X = A + 1
774// Y = X + 1
775// Z = Y + 1
776// ->
777// X = A + 1
778// Y = A + 2
779// Z = A + 3
780// Return the delta info for C against the new Basis
// An empty DeltaInfo is returned when Basis is null, Basis has no basis of its
// own, C is a Mul candidate, or no ancestor further up the chain qualifies.
781auto StraightLineStrengthReduce::compressPath(Candidate &C,
782 Candidate *Basis) const
783 -> DeltaInfo {
784 if (!Basis || !Basis->Basis || C.CandidateKind == Candidate::Mul)
785 return {};
// Root is the farthest ancestor accepted so far; NewDelta and NewKind describe
// C relative to it.
786 Candidate *Root = Basis;
787 Value *NewDelta = nullptr;
788 auto NewKind = Candidate::InvalidDelta;
789
790 while (Root->Basis) {
791 Candidate *NextRoot = Root->Basis;
// Index delta: accepted only when the delta is 0 or 1, or when the stride is
// a constant.
792 if (C.Base == NextRoot->Base && C.StrideSCEV == NextRoot->StrideSCEV &&
793 isSimilar(C, *NextRoot, Candidate::IndexDelta)) {
794 ConstantInt *CI = cast<ConstantInt>(getIndexDelta(C, *NextRoot));
795 if (CI->isZero() || CI->isOne() || isa<SCEVConstant>(C.StrideSCEV)) {
796 Root = NextRoot;
797 NewKind = Candidate::IndexDelta;
798 NewDelta = CI;
799 continue;
800 }
801 }
802
// Otherwise try a stride delta, then a base delta; stop climbing when neither
// pair of components matches.
803 const SCEV *CandPart = nullptr;
804 const SCEV *BasisPart = nullptr;
805 auto CurrKind = Candidate::InvalidDelta;
806 if (C.Base == NextRoot->Base && C.Index == NextRoot->Index) {
807 CandPart = C.StrideSCEV;
808 BasisPart = NextRoot->StrideSCEV;
809 CurrKind = Candidate::StrideDelta;
810 } else if (C.StrideSCEV == NextRoot->StrideSCEV &&
811 C.Index == NextRoot->Index) {
812 CandPart = C.Base;
813 BasisPart = NextRoot->Base;
814 CurrKind = Candidate::BaseDelta;
815 } else
816 break;
817
818 assert(CandPart && BasisPart);
819 if (!isSimilar(C, *NextRoot, CurrKind))
820 break;
821
// Only a constant difference lets the climb continue.
822 if (auto DeltaVal =
823 dyn_cast<SCEVConstant>(SE->getMinusSCEV(CandPart, BasisPart))) {
824 Root = NextRoot;
825 NewDelta = DeltaVal->getValue();
826 NewKind = CurrKind;
827 } else
828 break;
829 }
830
831 if (Root != Basis) {
832 assert(NewKind != Candidate::InvalidDelta && NewDelta);
833 LLVM_DEBUG(dbgs() << "Found new Basis with " << *NewDelta
834 << " from path compression.\n");
835 return {Root, NewKind, NewDelta};
836 }
837
838 return {};
839}
840
841// Topologically sort candidate instructions based on their relationship in
842// dependency graph.
// Rebuilds SortedCandidateInsts from DependencyGraph with an in-degree
// worklist: instructions with no incoming edge are emitted first, and an
// instruction is queued once all of its predecessors have been emitted.
843void StraightLineStrengthReduce::sortCandidateInstructions() {
844 SortedCandidateInsts.clear();
845 // An instruction may have multiple candidates that get different Basis
846 // instructions, and each candidate can get dependencies from Basis and
847 // Stride when Stride will also be rewritten by SLSR. Hence, an instruction
848 // may have multiple dependencies. Use InDegree to ensure all dependencies
849 // processed before processing itself.
850 DenseMap<Instruction *, int> InDegree;
851 for (auto &KV : DependencyGraph) {
// Make sure every key has an entry, even when nothing points at it.
852 InDegree.try_emplace(KV.first, 0);
853
854 for (auto *Child : KV.second) {
855 InDegree[Child]++;
856 }
857 }
858 std::queue<Instruction *> WorkList;
859 DenseSet<Instruction *> Visited;
860
861 for (auto &KV : DependencyGraph)
862 if (InDegree[KV.first] == 0)
863 WorkList.push(KV.first);
864
865 while (!WorkList.empty()) {
866 Instruction *I = WorkList.front();
867 WorkList.pop();
868 if (!Visited.insert(I).second)
869 continue;
870
871 SortedCandidateInsts.push_back(I);
872
// NOTE(review): I may be an instruction that only ever appeared as a child;
// operator[] then presumably adds an empty entry for it to DependencyGraph,
// which the size comparison in the assert below seems to rely on -- confirm.
873 for (auto *Next : DependencyGraph[I]) {
874 auto &Degree = InDegree[Next];
875 if (--Degree == 0)
876 WorkList.push(Next);
877 }
878 }
879
880 assert(SortedCandidateInsts.size() == DependencyGraph.size() &&
881 "Dependency graph should not have cycles");
882}
883
884auto StraightLineStrengthReduce::pickRewriteCandidate(Instruction *I) const
885 -> Candidate * {
886 // Return the candidate of instruction I that has the highest profit.
887 auto It = RewriteCandidates.find(I);
888 if (It == RewriteCandidates.end())
889 return nullptr;
890
891 Candidate *BestC = nullptr;
892 auto BestEfficiency = Candidate::Unknown;
893 for (Candidate *C : reverse(It->second))
894 if (C->Basis) {
895 auto Efficiency = C->getRewriteEfficiency();
896 if (Efficiency > BestEfficiency) {
897 BestEfficiency = Efficiency;
898 BestC = C;
899 }
900 }
901
902 return BestC;
903}
904
906 const TargetTransformInfo *TTI) {
907 SmallVector<const Value *, 4> Indices(GEP->indices());
908 return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(),
910}
911
912// Returns whether (Base + Index * Stride) can be folded to an addressing mode.
913static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride,
915 // Index->getSExtValue() may crash if Index is wider than 64-bit.
916 return Index->getBitWidth() <= 64 &&
917 TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true,
918 Index->getSExtValue(), UnknownAddressSpace);
919}
920
// Returns whether candidate C can be folded into an addressing mode. Add
// candidates are delegated to isAddFoldable; kinds other than Add and GEP are
// never foldable.
// NOTE(review): the embedded listing numbers skip 926, so the statement
// executed for GEP candidates appears to be missing from this copy; as written
// the GEP test guards the final `return false` -- confirm against upstream.
921bool StraightLineStrengthReduce::isFoldable(const Candidate &C,
922 TargetTransformInfo *TTI) {
923 if (C.CandidateKind == Candidate::Add)
924 return isAddFoldable(C.Base, C.Index, C.Stride, TTI);
925 if (C.CandidateKind == Candidate::GEP)
927 return false;
928}
929
930void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
931 Candidate::Kind CT, const SCEV *B, ConstantInt *Idx, Value *S,
932 Instruction *I) {
933 // Record the SCEV of S that we may use it as a variable delta.
934 // Ensure that we rewrite C with a existing IR that reproduces delta value.
935
936 Candidate C(CT, B, Idx, S, I, getAndRecordSCEV(S));
937 // If we can fold I into an addressing mode, computing I is likely free or
938 // takes only one instruction. So, we don't need to analyze or rewrite it.
939 //
940 // Currently, this algorithm can at best optimize complex computations into
941 // a `variable +/* constant` form. However, some targets have stricter
942 // constraints on the their addressing mode.
943 // For example, a `variable + constant` can only be folded to an addressing
944 // mode if the constant falls within a certain range.
945 // So, we also check if the instruction is already high efficient enough
946 // for the strength reduction algorithm.
947 if (!isFoldable(C, TTI) && !C.isHighEfficiency()) {
948 setBasisAndDeltaFor(C);
949
950 // Compress unnecessary rewrite to improve ILP
951 if (auto Res = compressPath(C, C.Basis)) {
952 C.Basis = Res.Cand;
953 C.DeltaKind = Res.DeltaKind;
954 C.Delta = Res.Delta;
955 }
956 }
957 // Regardless of whether we find a basis for C, we need to push C to the
958 // candidate list so that it can be the basis of other candidates.
959 LLVM_DEBUG(dbgs() << "Allocated Candidate: " << C << "\n");
960 Candidates.push_back(C);
961 RewriteCandidates[C.Ins].push_back(&Candidates.back());
962 CandidateDict.add(Candidates.back());
963}
964
965void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
966 Instruction *I) {
967 switch (I->getOpcode()) {
968 case Instruction::Add:
969 allocateCandidatesAndFindBasisForAdd(I);
970 break;
971 case Instruction::Mul:
972 allocateCandidatesAndFindBasisForMul(I);
973 break;
974 case Instruction::GetElementPtr:
975 allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(I));
976 break;
977 }
978}
979
980void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
981 Instruction *I) {
982 // Try matching B + i * S.
983 if (!isa<IntegerType>(I->getType()))
984 return;
985
986 assert(I->getNumOperands() == 2 && "isn't I an add?");
987 Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
988 allocateCandidatesAndFindBasisForAdd(LHS, RHS, I);
989 if (LHS != RHS)
990 allocateCandidatesAndFindBasisForAdd(RHS, LHS, I);
991}
992
993void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
994 Value *LHS, Value *RHS, Instruction *I) {
995 Value *S = nullptr;
996 ConstantInt *Idx = nullptr;
997 if (match(RHS, m_Mul(m_Value(S), m_ConstantInt(Idx)))) {
998 // I = LHS + RHS = LHS + Idx * S
999 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
1000 } else if (match(RHS, m_Shl(m_Value(S), m_ConstantInt(Idx)))) {
1001 // I = LHS + RHS = LHS + (S << Idx) = LHS + S * (1 << Idx)
1002 APInt One(Idx->getBitWidth(), 1);
1003 Idx = ConstantInt::get(Idx->getContext(), One << Idx->getValue());
1004 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
1005 } else {
1006 // At least, I = LHS + 1 * RHS
1007 ConstantInt *One = ConstantInt::get(cast<IntegerType>(I->getType()), 1);
1008 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS,
1009 I);
1010 }
1011}
1012
1013// Returns true if A matches B + C where C is constant.
1014static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C) {
1015 return match(A, m_c_Add(m_Value(B), m_ConstantInt(C)));
1016}
1017
1018// Returns true if A matches B | C where C is constant.
1019static bool matchesOr(Value *A, Value *&B, ConstantInt *&C) {
1020 return match(A, m_c_Or(m_Value(B), m_ConstantInt(C)));
1021}
1022
1023void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
1024 Value *LHS, Value *RHS, Instruction *I) {
1025 Value *B = nullptr;
1026 ConstantInt *Idx = nullptr;
1027 if (matchesAdd(LHS, B, Idx)) {
1028 // If LHS is in the form of "Base + Index", then I is in the form of
1029 // "(Base + Index) * RHS".
1030 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
1031 } else if (matchesOr(LHS, B, Idx) && haveNoCommonBitsSet(B, Idx, *DL)) {
1032 // If LHS is in the form of "Base | Index" and Base and Index have no common
1033 // bits set, then
1034 // Base | Index = Base + Index
1035 // and I is thus in the form of "(Base + Index) * RHS".
1036 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
1037 } else {
1038 // Otherwise, at least try the form (LHS + 0) * RHS.
1039 ConstantInt *Zero = ConstantInt::get(cast<IntegerType>(I->getType()), 0);
1040 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
1041 I);
1042 }
1043}
1044
1045void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
1046 Instruction *I) {
1047 // Try matching (B + i) * S.
1048 // TODO: we could extend SLSR to float and vector types.
1049 if (!isa<IntegerType>(I->getType()))
1050 return;
1051
1052 assert(I->getNumOperands() == 2 && "isn't I a mul?");
1053 Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
1054 allocateCandidatesAndFindBasisForMul(LHS, RHS, I);
1055 if (LHS != RHS) {
1056 // Symmetrically, try to split RHS to Base + Index.
1057 allocateCandidatesAndFindBasisForMul(RHS, LHS, I);
1058 }
1059}
1060
1061void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
1062 GetElementPtrInst *GEP) {
1063 // TODO: handle vector GEPs
1064 if (GEP->getType()->isVectorTy())
1065 return;
1066
1068 for (Use &Idx : GEP->indices())
1069 IndexExprs.push_back(SE->getSCEV(Idx));
1070
1072 for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
1073 if (GTI.isStruct())
1074 continue;
1075
1076 const SCEV *OrigIndexExpr = IndexExprs[I - 1];
1077 IndexExprs[I - 1] = SE->getZero(OrigIndexExpr->getType());
1078
1079 // The base of this candidate is GEP's base plus the offsets of all
1080 // indices except this current one.
1081 const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs);
1082 Value *ArrayIdx = GEP->getOperand(I);
1083 uint64_t ElementSize = GTI.getSequentialElementStride(*DL);
1084 IntegerType *PtrIdxTy = cast<IntegerType>(DL->getIndexType(GEP->getType()));
1085 ConstantInt *ElementSizeIdx = ConstantInt::get(PtrIdxTy, ElementSize, true);
1086 if (ArrayIdx->getType()->getIntegerBitWidth() <=
1087 DL->getIndexSizeInBits(GEP->getAddressSpace())) {
1088 // Skip factoring if ArrayIdx is wider than the index size, because
1089 // ArrayIdx is implicitly truncated to the index size.
1090 allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx,
1091 ArrayIdx, GEP);
1092 }
1093 // When ArrayIdx is the sext of a value, we try to factor that value as
1094 // well. Handling this case is important because array indices are
1095 // typically sign-extended to the pointer index size.
1096 Value *TruncatedArrayIdx = nullptr;
1097 if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))) &&
1098 TruncatedArrayIdx->getType()->getIntegerBitWidth() <=
1099 DL->getIndexSizeInBits(GEP->getAddressSpace())) {
1100 // Skip factoring if TruncatedArrayIdx is wider than the pointer size,
1101 // because TruncatedArrayIdx is implicitly truncated to the pointer size.
1102 allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx,
1103 TruncatedArrayIdx, GEP);
1104 }
1105
1106 IndexExprs[I - 1] = OrigIndexExpr;
1107 }
1108}
1109
1110Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
1111 const Candidate &C,
1112 IRBuilder<> &Builder,
1113 const DataLayout *DL) {
1114 auto CreateMul = [&](Value *LHS, Value *RHS) {
1115 if (ConstantInt *CR = dyn_cast<ConstantInt>(RHS)) {
1116 const APInt &ConstRHS = CR->getValue();
1117 IntegerType *DeltaType =
1118 IntegerType::get(C.Ins->getContext(), ConstRHS.getBitWidth());
1119 if (ConstRHS.isPowerOf2()) {
1120 ConstantInt *Exponent =
1121 ConstantInt::get(DeltaType, ConstRHS.logBase2());
1122 return Builder.CreateShl(LHS, Exponent);
1123 }
1124 if (ConstRHS.isNegatedPowerOf2()) {
1125 ConstantInt *Exponent =
1126 ConstantInt::get(DeltaType, (-ConstRHS).logBase2());
1127 return Builder.CreateNeg(Builder.CreateShl(LHS, Exponent));
1128 }
1129 }
1130
1131 return Builder.CreateMul(LHS, RHS);
1132 };
1133
1134 Value *Delta = C.Delta;
1135 // If Delta is 0, C is a fully redundant of C.Basis,
1136 // just replace C.Ins with Basis.Ins
1137 if (ConstantInt *CI = dyn_cast<ConstantInt>(Delta);
1138 CI && CI->getValue().isZero())
1139 return nullptr;
1140
1141 if (C.DeltaKind == Candidate::IndexDelta) {
1142 APInt IndexDelta = cast<ConstantInt>(C.Delta)->getValue();
1143 // IndexDelta
1144 // X = B + i * S
1145 // Y = B + i` * S
1146 // = B + (i + IndexDelta) * S
1147 // = B + i * S + IndexDelta * S
1148 // = X + IndexDelta * S
1149 // Bump = (i' - i) * S
1150
1151 // Common case 1: if (i' - i) is 1, Bump = S.
1152 if (IndexDelta == 1)
1153 return C.Stride;
1154 // Common case 2: if (i' - i) is -1, Bump = -S.
1155 if (IndexDelta.isAllOnes())
1156 return Builder.CreateNeg(C.Stride);
1157
1158 IntegerType *DeltaType =
1159 IntegerType::get(Basis.Ins->getContext(), IndexDelta.getBitWidth());
1160 Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType);
1161
1162 return CreateMul(ExtendedStride, C.Delta);
1163 }
1164
1165 assert(C.DeltaKind == Candidate::StrideDelta ||
1166 C.DeltaKind == Candidate::BaseDelta);
1167 assert(C.CandidateKind != Candidate::Mul);
1168 // StrideDelta
1169 // X = B + i * S
1170 // Y = B + i * S'
1171 // = B + i * (S + StrideDelta)
1172 // = B + i * S + i * StrideDelta
1173 // = X + i * StrideDelta
1174 // Bump = i * (S' - S)
1175 //
1176 // BaseDelta
1177 // X = B + i * S
1178 // Y = B' + i * S
1179 // = (B + BaseDelta) + i * S
1180 // = X + BaseDelta
1181 // Bump = (B' - B).
1182 Value *Bump = C.Delta;
1183 if (C.DeltaKind == Candidate::StrideDelta) {
1184 // If this value is consumed by a GEP, promote StrideDelta before doing
1185 // StrideDelta * Index to ensure the same semantics as the original GEP.
1186 if (C.CandidateKind == Candidate::GEP) {
1187 auto *GEP = cast<GetElementPtrInst>(C.Ins);
1188 Type *NewScalarIndexTy =
1189 DL->getIndexType(GEP->getPointerOperandType()->getScalarType());
1190 Bump = Builder.CreateSExtOrTrunc(Bump, NewScalarIndexTy);
1191 }
1192 if (!C.Index->isOne()) {
1193 Value *ExtendedIndex =
1194 Builder.CreateSExtOrTrunc(C.Index, Bump->getType());
1195 Bump = CreateMul(Bump, ExtendedIndex);
1196 }
1197 }
1198 return Bump;
1199}
1200
1201void StraightLineStrengthReduce::rewriteCandidate(const Candidate &C) {
1202 if (!DebugCounter::shouldExecute(StraightLineStrengthReduceCounter))
1203 return;
1204
1205 const Candidate &Basis = *C.Basis;
1206 assert(C.Delta && C.CandidateKind == Basis.CandidateKind &&
1207 C.hasValidDelta(Basis));
1208
1209 IRBuilder<> Builder(C.Ins);
1210 Value *Bump = emitBump(Basis, C, Builder, DL);
1211 Value *Reduced = nullptr; // equivalent to but weaker than C.Ins
1212 // If delta is 0, C is a fully redundant of Basis, and Bump is nullptr,
1213 // just replace C.Ins with Basis.Ins
1214 if (!Bump)
1215 Reduced = Basis.Ins;
1216 else {
1217 switch (C.CandidateKind) {
1218 case Candidate::Add:
1219 case Candidate::Mul: {
1220 // C = Basis + Bump
1221 Value *NegBump;
1222 if (match(Bump, m_Neg(m_Value(NegBump)))) {
1223 // If Bump is a neg instruction, emit C = Basis - (-Bump).
1224 Reduced = Builder.CreateSub(Basis.Ins, NegBump);
1225 // We only use the negative argument of Bump, and Bump itself may be
1226 // trivially dead.
1228 } else {
1229 // It's tempting to preserve nsw on Bump and/or Reduced. However, it's
1230 // usually unsound, e.g.,
1231 //
1232 // X = (-2 +nsw 1) *nsw INT_MAX
1233 // Y = (-2 +nsw 3) *nsw INT_MAX
1234 // =>
1235 // Y = X + 2 * INT_MAX
1236 //
1237 // Neither + and * in the resultant expression are nsw.
1238 Reduced = Builder.CreateAdd(Basis.Ins, Bump);
1239 }
1240 break;
1241 }
1242 case Candidate::GEP: {
1243 bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
1244 // C = (char *)Basis + Bump
1245 Reduced = Builder.CreatePtrAdd(Basis.Ins, Bump, "", InBounds);
1246 break;
1247 }
1248 default:
1249 llvm_unreachable("C.CandidateKind is invalid");
1250 };
1251 Reduced->takeName(C.Ins);
1252 }
1253 C.Ins->replaceAllUsesWith(Reduced);
1254 DeadInstructions.push_back(C.Ins);
1255}
1256
1257bool StraightLineStrengthReduceLegacyPass::runOnFunction(Function &F) {
1258 if (skipFunction(F))
1259 return false;
1260
1261 auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1262 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
1263 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1264 return StraightLineStrengthReduce(DL, DT, SE, TTI).runOnFunction(F);
1265}
1266
1267bool StraightLineStrengthReduce::runOnFunction(Function &F) {
1268 LLVM_DEBUG(dbgs() << "SLSR on Function: " << F.getName() << "\n");
1269 // Traverse the dominator tree in the depth-first order. This order makes sure
1270 // all bases of a candidate are in Candidates when we process it.
1271 for (const auto Node : depth_first(DT))
1272 for (auto &I : *(Node->getBlock()))
1273 allocateCandidatesAndFindBasis(&I);
1274
1275 // Build the dependency graph and sort candidate instructions from dependency
1276 // roots to leaves
1277 for (auto &C : Candidates) {
1278 DependencyGraph.try_emplace(C.Ins);
1279 addDependency(C, C.Basis);
1280 }
1281 sortCandidateInstructions();
1282
1283 // Rewrite candidates in the topological order that rewrites a Candidate
1284 // always before rewriting its Basis
1285 for (Instruction *I : reverse(SortedCandidateInsts))
1286 if (Candidate *C = pickRewriteCandidate(I))
1287 rewriteCandidate(*C);
1288
1289 for (auto *DeadIns : DeadInstructions)
1290 // A dead instruction may be another dead instruction's op,
1291 // don't delete an instruction twice
1292 if (DeadIns->getParent())
1294
1295 bool Ret = !DeadInstructions.empty();
1296 DeadInstructions.clear();
1297 DependencyGraph.clear();
1298 RewriteCandidates.clear();
1299 SortedCandidateInsts.clear();
1300 // First clear all references to candidates in the list
1301 CandidateDict.clear();
1302 // Then destroy the list
1303 Candidates.clear();
1304 return Ret;
1305}
1306
1307PreservedAnalyses
1309 const DataLayout *DL = &F.getDataLayout();
1310 auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
1311 auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
1312 auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
1313
1314 if (!StraightLineStrengthReduce(DL, DT, SE, TTI).runOnFunction(F))
1315 return PreservedAnalyses::all();
1316
1322 return PA;
1323}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file implements a class to represent arbitrary precision integral constant values and operations...
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:638
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file provides an implementation of debug counters.
#define DEBUG_COUNTER(VARNAME, COUNTERNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool runOnFunction(Function &F, bool PostInlining)
Hexagon Common GEP
Module.h This file contains the declarations for the Module class.
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Machine Check Debug Module
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
static BinaryOperator * CreateMul(Value *S1, Value *S2, const Twine &Name, BasicBlock::iterator InsertBefore, Value *FlagsOp)
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallVector class.
static bool matchesOr(Value *A, Value *&B, ConstantInt *&C)
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI)
static void unifyBitWidth(APInt &A, APInt &B)
static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C)
static const unsigned UnknownAddressSpace
#define LLVM_DEBUG(...)
Definition Debug.h:114
This pass exposes codegen information to IR-level passes.
Value * RHS
Value * LHS
Class for arbitrary precision integers.
Definition APInt.h:78
bool isNegatedPowerOf2() const
Check if this APInt's negated value is a power of two greater than zero.
Definition APInt.h:450
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1489
unsigned logBase2() const
Definition APInt.h:1762
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
Represents analyses that only rely on functions' control flow.
Definition Analysis.h:73
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isOne() const
This is just a convenience method to make client code smaller for a common case.
Definition Constants.h:220
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:214
unsigned getBitWidth() const
getBitWidth - Return the scalar bitwidth of this constant.
Definition Constants.h:157
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:154
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:63
static bool shouldExecute(unsigned CounterName)
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
iterator end()
Definition DenseMap.h:81
Analysis pass which computes a DominatorTree.
Definition Dominators.h:283
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:321
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Value * CreatePtrAdd(Value *Ptr, Value *Offset, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Definition IRBuilder.h:2039
Value * CreateNeg(Value *V, const Twine &Name="", bool HasNSW=false)
Definition IRBuilder.h:1784
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1420
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1492
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1403
Value * CreateSExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a SExt or Trunc from the integer value V to DestTy.
Definition IRBuilder.h:2118
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1437
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:318
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
Definition Analysis.h:151
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
This class represents an analyzed expression in the program.
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< const SCEV * > IndexExprs)
Returns an expression for a GEP.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Analysis pass providing the TargetTransformInfo.
Wrapper pass for TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
@ TCC_Free
Expected to fold away in lowering.
LLVM_ABI unsigned getIntegerBitWidth() const
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1099
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
Definition Value.cpp:396
std::pair< iterator, bool > insert(const ValueT &V)
Definition DenseSet.h:202
TypeSize getSequentialElementStride(const DataLayout &DL) const
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Add, true > m_c_Add(const LHS &L, const RHS &R)
Matches a Add with LHS and RHS in either order.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
CastInst_match< OpTy, SExtInst > m_SExt(const OpTy &Op)
Matches SExt.
BinaryOp_match< LHS, RHS, Instruction::Or, true > m_c_Or(const LHS &L, const RHS &R)
Matches an Or with LHS and RHS in either order.
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI bool haveNoCommonBitsSet(const WithCache< const Value * > &LHSCache, const WithCache< const Value * > &RHSCache, const SimplifyQuery &SQ)
Return true if LHS and RHS have no common bits set.
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:533
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI void initializeStraightLineStrengthReduceLegacyPassPass(PassRegistry &)
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1732
auto reverse(ContainerTy &&C)
Definition STLExtras.h:406
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
generic_gep_type_iterator<> gep_type_iterator
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
TargetTransformInfo TTI
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
FunctionAddr VTableAddr Next
Definition InstrProf.h:141
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
gep_type_iterator gep_type_begin(const User *GEP)
PointerUnion< const Value *, const PseudoSourceValue * > ValueType
iterator_range< df_iterator< T > > depth_first(const T &G)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI FunctionPass * createStraightLineStrengthReducePass()