File: lib/Analysis/VectorUtils.cpp
Warning: line 996, column 11: Called C++ object pointer is null
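Note: the flagged expression is the Group->getIndex(B) call in InterleavedAccessInfo::analyzeInterleaving (original line 996, marked in the listing below); the analyzer believes Group may still be nullptr on some path reaching that call.
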
//===----------- VectorUtils.cpp - Vectorizer utility functions -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/VectorUtils.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopIterator.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

#define DEBUG_TYPE "vectorutils"

using namespace llvm;
using namespace llvm::PatternMatch;

/// Maximum factor for an interleaved memory access.
static cl::opt<unsigned> MaxInterleaveGroupFactor(
    "max-interleave-group-factor", cl::Hidden,
    cl::desc("Maximum factor for an interleaved access group (default = 8)"),
    cl::init(8));

/// Return true if all of the intrinsic's arguments and return type are scalars
/// for the scalar form of the intrinsic and vectors for the vector form of the
/// intrinsic.
bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
  switch (ID) {
  case Intrinsic::bswap: // Begin integer bit-manipulation.
  case Intrinsic::bitreverse:
  case Intrinsic::ctpop:
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::fshl:
  case Intrinsic::fshr:
  case Intrinsic::sadd_sat:
  case Intrinsic::ssub_sat:
  case Intrinsic::uadd_sat:
  case Intrinsic::usub_sat:
  case Intrinsic::smul_fix:
  case Intrinsic::umul_fix:
  case Intrinsic::sqrt: // Begin floating-point.
  case Intrinsic::sin:
  case Intrinsic::cos:
  case Intrinsic::exp:
  case Intrinsic::exp2:
  case Intrinsic::log:
  case Intrinsic::log10:
  case Intrinsic::log2:
  case Intrinsic::fabs:
  case Intrinsic::minnum:
  case Intrinsic::maxnum:
  case Intrinsic::minimum:
  case Intrinsic::maximum:
  case Intrinsic::copysign:
  case Intrinsic::floor:
  case Intrinsic::ceil:
  case Intrinsic::trunc:
  case Intrinsic::rint:
  case Intrinsic::nearbyint:
  case Intrinsic::round:
  case Intrinsic::pow:
  case Intrinsic::fma:
  case Intrinsic::fmuladd:
  case Intrinsic::powi:
  case Intrinsic::canonicalize:
    return true;
  default:
    return false;
  }
}

/// Identifies if the intrinsic has a scalar operand. It checks for the ctlz,
/// cttz, and powi intrinsics, whose second argument is scalar, and for the
/// smul_fix and umul_fix intrinsics, whose third (scale) argument is scalar.
bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx) {
  switch (ID) {
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::powi:
    return (ScalarOpdIdx == 1);
  case Intrinsic::smul_fix:
  case Intrinsic::umul_fix:
    return (ScalarOpdIdx == 2);
  default:
    return false;
  }
}
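
// Illustrative example (not part of the original source): for a call such as
// llvm.powi.f32(float %x, i32 %n), the exponent operand at index 1 must stay
// scalar when the call is vectorized, so the function returns true for
// (Intrinsic::powi, 1) and false for (Intrinsic::powi, 0).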

/// Returns the intrinsic ID for a call.
/// For the given call instruction, finds the corresponding intrinsic and
/// returns its ID; returns not_intrinsic if no vectorizable mapping is found.
Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI,
                                                const TargetLibraryInfo *TLI) {
  Intrinsic::ID ID = getIntrinsicForCallSite(CI, TLI);
  if (ID == Intrinsic::not_intrinsic)
    return Intrinsic::not_intrinsic;

  if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start ||
      ID == Intrinsic::lifetime_end || ID == Intrinsic::assume ||
      ID == Intrinsic::sideeffect)
    return ID;
  return Intrinsic::not_intrinsic;
}

/// Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
  const DataLayout &DL = Gep->getModule()->getDataLayout();
  unsigned LastOperand = Gep->getNumOperands() - 1;
  unsigned GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType());

  // Walk backwards and try to peel off zeros.
  while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
    // Find the type we're currently indexing into.
    gep_type_iterator GEPTI = gep_type_begin(Gep);
    std::advance(GEPTI, LastOperand - 2);

    // If it's a type with the same allocation size as the result of the GEP we
    // can peel off the zero index.
    if (DL.getTypeAllocSize(GEPTI.getIndexedType()) != GEPAllocSize)
      break;
    --LastOperand;
  }

  return LastOperand;
}
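
// Illustrative example (not part of the original source): for
//   %g = getelementptr [1 x double], [1 x double]* %base, i64 %i, i64 0
// the trailing zero index does not change the final address, because the
// indexed type [1 x double] has the same allocation size as the result
// element type double; it is peeled off and operand index 1 (%i) is
// returned as the candidate induction operand.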

/// If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if there is some other non-loop-invariant
/// operand, it returns that instead.
Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return Ptr;

  unsigned InductionOperand = getGEPInductionOperand(GEP);

  // Check that all of the gep indices are uniform except for our induction
  // operand.
  for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i)
    if (i != InductionOperand &&
        !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp))
      return Ptr;
  return GEP->getOperand(InductionOperand);
}

/// If a value has only one user that is a CastInst of the given type, return
/// it; otherwise return null.
Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
  Value *UniqueCast = nullptr;
  for (User *U : Ptr->users()) {
    CastInst *CI = dyn_cast<CastInst>(U);
    if (CI && CI->getType() == Ty) {
      if (!UniqueCast)
        UniqueCast = CI;
      else
        return nullptr;
    }
  }
  return UniqueCast;
}

/// Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
  auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
  if (!PtrTy || PtrTy->isAggregateType())
    return nullptr;

  // Try to strip a GEP instruction to make the value easier to analyze. If
  // OrigPtr is equal to Ptr we are analyzing the pointer itself; otherwise,
  // we are analyzing the index.
  Value *OrigPtr = Ptr;

  // The size of the pointer access.
  int64_t PtrAccessSize = 1;

  Ptr = stripGetElementPtr(Ptr, SE, Lp);
  const SCEV *V = SE->getSCEV(Ptr);

  if (Ptr != OrigPtr)
    // Strip off casts.
    while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V))
      V = C->getOperand();

  const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V);
  if (!S)
    return nullptr;

  V = S->getStepRecurrence(*SE);
  if (!V)
    return nullptr;

  // Strip off the access-size multiplication if we are still analyzing the
  // pointer.
  if (OrigPtr == Ptr) {
    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) {
      if (M->getOperand(0)->getSCEVType() != scConstant)
        return nullptr;

      const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt();

      // Huge step value - give up.
      if (APStepVal.getBitWidth() > 64)
        return nullptr;

      int64_t StepVal = APStepVal.getSExtValue();
      if (PtrAccessSize != StepVal)
        return nullptr;
      V = M->getOperand(1);
    }
  }

  // Strip off casts.
  Type *StripedOffRecurrenceCast = nullptr;
  if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) {
    StripedOffRecurrenceCast = C->getType();
    V = C->getOperand();
  }

  // Look for the loop-invariant symbolic value.
  const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
  if (!U)
    return nullptr;

  Value *Stride = U->getValue();
  if (!Lp->isLoopInvariant(Stride))
    return nullptr;

  // If we have stripped off the recurrence cast we have to make sure that we
  // return the value that is used in this loop so that we can replace it
  // later.
  if (StripedOffRecurrenceCast)
    Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast);

  return Stride;
}
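
// Illustrative sketch (assumed IR, not part of the original source): for an
// access A[i * Stride] written as
//   %mul = mul i64 %i, %stride
//   %gep = getelementptr double, double* %A, i64 %mul
// stripGetElementPtr yields %mul, whose SCEV is the add-recurrence
// {0,+,%stride}; the loop-invariant symbolic step %stride is returned.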

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then
/// extracted from the vector.
Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
  assert(V->getType()->isVectorTy() && "Not looking at a vector?");
  VectorType *VTy = cast<VectorType>(V->getType());
  unsigned Width = VTy->getNumElements();
  if (EltNo >= Width) // Out of range access.
    return UndefValue::get(VTy->getElementType());

  if (Constant *C = dyn_cast<Constant>(V))
    return C->getAggregateElement(EltNo);

  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
    // If this is an insert to a variable element, we don't know what it is.
    if (!isa<ConstantInt>(III->getOperand(2)))
      return nullptr;
    unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();

    // If this is an insert to the element we are looking for, return the
    // inserted value.
    if (EltNo == IIElt)
      return III->getOperand(1);

    // Otherwise, the insertelement doesn't modify the value, recurse on its
    // vector input.
    return findScalarElement(III->getOperand(0), EltNo);
  }

  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
    int InEl = SVI->getMaskValue(EltNo);
    if (InEl < 0)
      return UndefValue::get(VTy->getElementType());
    if (InEl < (int)LHSWidth)
      return findScalarElement(SVI->getOperand(0), InEl);
    return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
  }

  // Extract a value from a vector add operation with a constant zero.
  // TODO: Use getBinOpIdentity() to generalize this.
  Value *Val; Constant *C;
  if (match(V, m_Add(m_Value(Val), m_Constant(C))))
    if (Constant *Elt = C->getAggregateElement(EltNo))
      if (Elt->isNullValue())
        return findScalarElement(Val, EltNo);

  // Otherwise, we don't know.
  return nullptr;
}
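
// Illustrative example (not part of the original source): given
//   %v = insertelement <4 x float> %w, float %x, i32 2
// findScalarElement(%v, 2) returns %x, while findScalarElement(%v, 0)
// recurses into %w because the insert leaves element 0 untouched.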

/// Get the splat value if the input is a splat vector, or return nullptr.
/// This function is not fully general. It checks only 2 cases:
/// the input value is (1) a splat constant vector or (2) a sequence
/// of instructions that broadcast a single value into a vector.
const llvm::Value *llvm::getSplatValue(const Value *V) {

  if (auto *C = dyn_cast<Constant>(V))
    if (isa<VectorType>(V->getType()))
      return C->getSplatValue();

  auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V);
  if (!ShuffleInst)
    return nullptr;
  // All-zero (or undef) shuffle mask elements.
  for (int MaskElt : ShuffleInst->getShuffleMask())
    if (MaskElt != 0 && MaskElt != -1)
      return nullptr;
  // The first shuffle source is 'insertelement' with index 0.
  auto *InsertEltInst =
      dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
  if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) ||
      !cast<ConstantInt>(InsertEltInst->getOperand(2))->isZero())
    return nullptr;

  return InsertEltInst->getOperand(1);
}
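
// Illustrative example (not part of the original source): the canonical
// broadcast pattern
//   %ins   = insertelement <4 x i32> undef, i32 %x, i32 0
//   %splat = shufflevector <4 x i32> %ins, <4 x i32> undef,
//                          <4 x i32> zeroinitializer
// matches case (2) above, and %x is returned.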

MapVector<Instruction *, uint64_t>
llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                               const TargetTransformInfo *TTI) {

  // DemandedBits will give us every value's live-out bits. But we want
  // to ensure no extra casts would need to be inserted, so every DAG
  // of connected values must have the same minimum bitwidth.
  EquivalenceClasses<Value *> ECs;
  SmallVector<Value *, 16> Worklist;
  SmallPtrSet<Value *, 4> Roots;
  SmallPtrSet<Value *, 16> Visited;
  DenseMap<Value *, uint64_t> DBits;
  SmallPtrSet<Instruction *, 4> InstructionSet;
  MapVector<Instruction *, uint64_t> MinBWs;

  // Determine the roots. We work bottom-up, from truncs or icmps.
  bool SeenExtFromIllegalType = false;
  for (auto *BB : Blocks)
    for (auto &I : *BB) {
      InstructionSet.insert(&I);

      if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) &&
          !TTI->isTypeLegal(I.getOperand(0)->getType()))
        SeenExtFromIllegalType = true;

      // Only deal with non-vector integers up to 64 bits wide.
      if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) &&
          !I.getType()->isVectorTy() &&
          I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) {
        // Don't make work for ourselves. If we know the loaded type is legal,
        // don't add it to the worklist.
        if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType()))
          continue;

        Worklist.push_back(&I);
        Roots.insert(&I);
      }
    }
  // Early exit.
  if (Worklist.empty() || (TTI && !SeenExtFromIllegalType))
    return MinBWs;

  // Now proceed breadth-first, unioning values together.
  while (!Worklist.empty()) {
    Value *Val = Worklist.pop_back_val();
    Value *Leader = ECs.getOrInsertLeaderValue(Val);

    if (Visited.count(Val))
      continue;
    Visited.insert(Val);

    // Non-instructions terminate a chain successfully.
    if (!isa<Instruction>(Val))
      continue;
    Instruction *I = cast<Instruction>(Val);

    // If we encounter a type that is larger than 64 bits, we can't represent
    // it so bail out.
    if (DB.getDemandedBits(I).getBitWidth() > 64)
      return MapVector<Instruction *, uint64_t>();

    uint64_t V = DB.getDemandedBits(I).getZExtValue();
    DBits[Leader] |= V;
    DBits[I] = V;

    // Casts, loads and instructions outside of our range terminate a chain
    // successfully.
    if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) ||
        !InstructionSet.count(I))
      continue;

    // Unsafe casts terminate a chain unsuccessfully. We can't do anything
    // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to
    // transform anything that relies on them.
    if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) ||
        !I->getType()->isIntegerTy()) {
      DBits[Leader] |= ~0ULL;
      continue;
    }

    // We don't modify the types of PHIs. Reductions will already have been
    // truncated if possible, and inductions' sizes will have been chosen by
    // indvars.
    if (isa<PHINode>(I))
      continue;

    if (DBits[Leader] == ~0ULL)
      // All bits demanded, no point continuing.
      continue;

    for (Value *O : cast<User>(I)->operands()) {
      ECs.unionSets(Leader, O);
      Worklist.push_back(O);
    }
  }

  // Now we've discovered all values, walk them to see if there are
  // any users we didn't see. If there are, we can't optimize that
  // chain.
  for (auto &I : DBits)
    for (auto *U : I.first->users())
      if (U->getType()->isIntegerTy() && DBits.count(U) == 0)
        DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL;

  for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) {
    uint64_t LeaderDemandedBits = 0;
    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
      LeaderDemandedBits |= DBits[*MI];

    uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) -
                     llvm::countLeadingZeros(LeaderDemandedBits);
    // Round up to a power of 2.
    if (!isPowerOf2_64((uint64_t)MinBW))
      MinBW = NextPowerOf2(MinBW);

    // We don't modify the types of PHIs. Reductions will already have been
    // truncated if possible, and inductions' sizes will have been chosen by
    // indvars.
    // If we are required to shrink a PHI, abandon this entire equivalence
    // class.
    bool Abort = false;
    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
      if (isa<PHINode>(*MI) &&
          MinBW < (*MI)->getType()->getScalarSizeInBits()) {
        Abort = true;
        break;
      }
    if (Abort)
      continue;

    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) {
      if (!isa<Instruction>(*MI))
        continue;
      Type *Ty = (*MI)->getType();
      if (Roots.count(*MI))
        Ty = cast<Instruction>(*MI)->getOperand(0)->getType();
      if (MinBW < Ty->getScalarSizeInBits())
        MinBWs[cast<Instruction>(*MI)] = MinBW;
    }
  }

  return MinBWs;
}
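
// Illustrative example (not part of the original source, and assuming the
// i8 type is not already legal so the trunc root is kept): in a chain like
//   %e = zext i8 %a to i32
//   %s = add i32 %e, 42
//   %t = trunc i32 %s to i8
// only the low 8 bits of %s are demanded, so the add and the trunc root
// receive a minimum bitwidth of 8 in the returned map.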

/// Add all access groups in @p AccGroups to @p List.
template <typename ListT>
static void addToAccessGroupList(ListT &List, MDNode *AccGroups) {
  // Interpret an access group as a list containing itself.
  if (AccGroups->getNumOperands() == 0) {
    assert(isValidAsAccessGroup(AccGroups) && "Node must be an access group");
    List.insert(AccGroups);
    return;
  }

  for (auto &AccGroupListOp : AccGroups->operands()) {
    auto *Item = cast<MDNode>(AccGroupListOp.get());
    assert(isValidAsAccessGroup(Item) && "List item must be an access group");
    List.insert(Item);
  }
}

MDNode *llvm::uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2) {
  if (!AccGroups1)
    return AccGroups2;
  if (!AccGroups2)
    return AccGroups1;
  if (AccGroups1 == AccGroups2)
    return AccGroups1;

  SmallSetVector<Metadata *, 4> Union;
  addToAccessGroupList(Union, AccGroups1);
  addToAccessGroupList(Union, AccGroups2);

  if (Union.size() == 0)
    return nullptr;
  if (Union.size() == 1)
    return cast<MDNode>(Union.front());

  LLVMContext &Ctx = AccGroups1->getContext();
  return MDNode::get(Ctx, Union.getArrayRef());
}

MDNode *llvm::intersectAccessGroups(const Instruction *Inst1,
                                    const Instruction *Inst2) {
  bool MayAccessMem1 = Inst1->mayReadOrWriteMemory();
  bool MayAccessMem2 = Inst2->mayReadOrWriteMemory();

  if (!MayAccessMem1 && !MayAccessMem2)
    return nullptr;
  if (!MayAccessMem1)
    return Inst2->getMetadata(LLVMContext::MD_access_group);
  if (!MayAccessMem2)
    return Inst1->getMetadata(LLVMContext::MD_access_group);

  MDNode *MD1 = Inst1->getMetadata(LLVMContext::MD_access_group);
  MDNode *MD2 = Inst2->getMetadata(LLVMContext::MD_access_group);
  if (!MD1 || !MD2)
    return nullptr;
  if (MD1 == MD2)
    return MD1;

  // Use set for scalable 'contains' check.
  SmallPtrSet<Metadata *, 4> AccGroupSet2;
  addToAccessGroupList(AccGroupSet2, MD2);

  SmallVector<Metadata *, 4> Intersection;
  if (MD1->getNumOperands() == 0) {
    assert(isValidAsAccessGroup(MD1) && "Node must be an access group");
    if (AccGroupSet2.count(MD1))
      Intersection.push_back(MD1);
  } else {
    for (const MDOperand &Node : MD1->operands()) {
      auto *Item = cast<MDNode>(Node.get());
      assert(isValidAsAccessGroup(Item) && "List item must be an access group");
      if (AccGroupSet2.count(Item))
        Intersection.push_back(Item);
    }
  }

  if (Intersection.size() == 0)
    return nullptr;
  if (Intersection.size() == 1)
    return cast<MDNode>(Intersection.front());

  LLVMContext &Ctx = Inst1->getContext();
  return MDNode::get(Ctx, Intersection);
}

/// \returns \p Inst after propagating metadata from \p VL.
Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) {
  Instruction *I0 = cast<Instruction>(VL[0]);
  SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata;
  I0->getAllMetadataOtherThanDebugLoc(Metadata);

  for (auto Kind : {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
                    LLVMContext::MD_noalias, LLVMContext::MD_fpmath,
                    LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load,
                    LLVMContext::MD_access_group}) {
    MDNode *MD = I0->getMetadata(Kind);

    for (int J = 1, E = VL.size(); MD && J != E; ++J) {
      const Instruction *IJ = cast<Instruction>(VL[J]);
      MDNode *IMD = IJ->getMetadata(Kind);
      switch (Kind) {
      case LLVMContext::MD_tbaa:
        MD = MDNode::getMostGenericTBAA(MD, IMD);
        break;
      case LLVMContext::MD_alias_scope:
        MD = MDNode::getMostGenericAliasScope(MD, IMD);
        break;
      case LLVMContext::MD_fpmath:
        MD = MDNode::getMostGenericFPMath(MD, IMD);
        break;
      case LLVMContext::MD_noalias:
      case LLVMContext::MD_nontemporal:
      case LLVMContext::MD_invariant_load:
        MD = MDNode::intersect(MD, IMD);
        break;
      case LLVMContext::MD_access_group:
        MD = intersectAccessGroups(Inst, IJ);
        break;
      default:
        llvm_unreachable("unhandled metadata");
      }
    }

    Inst->setMetadata(Kind, MD);
  }

  return Inst;
}

Constant *
llvm::createBitMaskForGaps(IRBuilder<> &Builder, unsigned VF,
                           const InterleaveGroup<Instruction> &Group) {
  // All 1's means mask is not needed.
  if (Group.getNumMembers() == Group.getFactor())
    return nullptr;

  // TODO: support reversed access.
  assert(!Group.isReverse() && "Reversed group not supported.");

  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < VF; i++)
    for (unsigned j = 0; j < Group.getFactor(); ++j) {
      unsigned HasMember = Group.getMember(j) ? 1 : 0;
      Mask.push_back(Builder.getInt1(HasMember));
    }

  return ConstantVector::get(Mask);
}

Constant *llvm::createReplicatedMask(IRBuilder<> &Builder,
                                     unsigned ReplicationFactor, unsigned VF) {
  SmallVector<Constant *, 16> MaskVec;
  for (unsigned i = 0; i < VF; i++)
    for (unsigned j = 0; j < ReplicationFactor; j++)
      MaskVec.push_back(Builder.getInt32(i));

  return ConstantVector::get(MaskVec);
}
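
// Illustrative example (not part of the original source): ReplicationFactor = 3
// and VF = 2 produce the shuffle mask <0, 0, 0, 1, 1, 1>, replicating each
// of the VF elements ReplicationFactor times.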

Constant *llvm::createInterleaveMask(IRBuilder<> &Builder, unsigned VF,
                                     unsigned NumVecs) {
  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < VF; i++)
    for (unsigned j = 0; j < NumVecs; j++)
      Mask.push_back(Builder.getInt32(j * VF + i));

  return ConstantVector::get(Mask);
}
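
// Illustrative example (not part of the original source): VF = 4 and
// NumVecs = 2 produce <0, 4, 1, 5, 2, 6, 3, 7>, interleaving the lanes of
// two concatenated input vectors.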

Constant *llvm::createStrideMask(IRBuilder<> &Builder, unsigned Start,
                                 unsigned Stride, unsigned VF) {
  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < VF; i++)
    Mask.push_back(Builder.getInt32(Start + i * Stride));

  return ConstantVector::get(Mask);
}
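
// Illustrative example (not part of the original source): Start = 0,
// Stride = 2, VF = 4 produce <0, 2, 4, 6>, i.e. every second element
// starting at lane 0.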

Constant *llvm::createSequentialMask(IRBuilder<> &Builder, unsigned Start,
                                     unsigned NumInts, unsigned NumUndefs) {
  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < NumInts; i++)
    Mask.push_back(Builder.getInt32(Start + i));

  Constant *Undef = UndefValue::get(Builder.getInt32Ty());
  for (unsigned i = 0; i < NumUndefs; i++)
    Mask.push_back(Undef);

  return ConstantVector::get(Mask);
}
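
// Illustrative example (not part of the original source): Start = 0,
// NumInts = 4, NumUndefs = 4 produce <0, 1, 2, 3, undef, undef, undef,
// undef>, the form used by concatenateTwoVectors below to pad a shorter
// vector.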

/// A helper function for concatenating vectors. This function concatenates two
/// vectors having the same element type. If the second vector has fewer
/// elements than the first, it is padded with undefs.
static Value *concatenateTwoVectors(IRBuilder<> &Builder, Value *V1,
                                    Value *V2) {
  VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType());
  VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType());
  assert(VecTy1 && VecTy2 &&
         VecTy1->getScalarType() == VecTy2->getScalarType() &&
         "Expect two vectors with the same element type");

  unsigned NumElts1 = VecTy1->getNumElements();
  unsigned NumElts2 = VecTy2->getNumElements();
  assert(NumElts1 >= NumElts2 && "Unexpect the first vector has less elements");

  if (NumElts1 > NumElts2) {
    // Extend with UNDEFs.
    Constant *ExtMask =
        createSequentialMask(Builder, 0, NumElts2, NumElts1 - NumElts2);
    V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask);
  }

  Constant *Mask = createSequentialMask(Builder, 0, NumElts1 + NumElts2, 0);
  return Builder.CreateShuffleVector(V1, V2, Mask);
}

Value *llvm::concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs) {
  unsigned NumVecs = Vecs.size();
  assert(NumVecs > 1 && "Should be at least two vectors");

  SmallVector<Value *, 8> ResList;
  ResList.append(Vecs.begin(), Vecs.end());
  do {
    SmallVector<Value *, 8> TmpList;
    for (unsigned i = 0; i < NumVecs - 1; i += 2) {
      Value *V0 = ResList[i], *V1 = ResList[i + 1];
      assert((V0->getType() == V1->getType() || i == NumVecs - 2) &&
             "Only the last vector may have a different type");

      TmpList.push_back(concatenateTwoVectors(Builder, V0, V1));
    }

    // Push the last vector if the total number of vectors is odd.
    if (NumVecs % 2 != 0)
      TmpList.push_back(ResList[NumVecs - 1]);

    ResList = TmpList;
    NumVecs = ResList.size();
  } while (NumVecs > 1);

  return ResList[0];
}
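
// Illustrative example (not part of the original source): given three
// <2 x i32> inputs {V0, V1, V2}, the first round concatenates V0 and V1
// into a <4 x i32> and carries V2 over; the second round pads V2 with two
// undefs and concatenates again, producing the final <6 x i32> value.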

bool llvm::maskIsAllZeroOrUndef(Value *Mask) {
  auto *ConstMask = dyn_cast<Constant>(Mask);
  if (!ConstMask)
    return false;
  if (ConstMask->isNullValue() || isa<UndefValue>(ConstMask))
    return true;
  for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E;
       ++I) {
    if (auto *MaskElt = ConstMask->getAggregateElement(I))
      if (MaskElt->isNullValue() || isa<UndefValue>(MaskElt))
        continue;
    return false;
  }
  return true;
}

bool llvm::maskIsAllOneOrUndef(Value *Mask) {
  auto *ConstMask = dyn_cast<Constant>(Mask);
  if (!ConstMask)
    return false;
  if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
    return true;
  for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E;
       ++I) {
    if (auto *MaskElt = ConstMask->getAggregateElement(I))
      if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
        continue;
    return false;
  }
  return true;
}

/// TODO: This is a lot like known bits, but for
/// vectors. Is there something we can common this with?
APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {

  const unsigned VWidth = cast<VectorType>(Mask->getType())->getNumElements();
  APInt DemandedElts = APInt::getAllOnesValue(VWidth);
  if (auto *CV = dyn_cast<ConstantVector>(Mask))
    for (unsigned i = 0; i < VWidth; i++)
      if (CV->getAggregateElement(i)->isNullValue())
        DemandedElts.clearBit(i);
  return DemandedElts;
}
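
// Illustrative example (not part of the original source): for the constant
// mask <i1 1, i1 0, i1 1, i1 0>, bits 0 and 2 of the result are set,
// meaning only lanes 0 and 2 may be accessed.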

bool InterleavedAccessInfo::isStrided(int Stride) {
  unsigned Factor = std::abs(Stride);
  return Factor >= 2 && Factor <= MaxInterleaveGroupFactor;
}

void InterleavedAccessInfo::collectConstStrideAccesses(
    MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
    const ValueToValueMap &Strides) {
  auto &DL = TheLoop->getHeader()->getModule()->getDataLayout();

  // Since it's desired that the load/store instructions be maintained in
  // "program order" for the interleaved access analysis, we have to visit the
  // blocks in the loop in reverse postorder (i.e., in a topological order).
  // Such an ordering will ensure that any load/store that may be executed
  // before a second load/store will precede the second load/store in
  // AccessStrideInfo.
  LoopBlocksDFS DFS(TheLoop);
  DFS.perform(LI);
  for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO()))
    for (auto &I : *BB) {
      auto *LI = dyn_cast<LoadInst>(&I);
      auto *SI = dyn_cast<StoreInst>(&I);
      if (!LI && !SI)
        continue;

      Value *Ptr = getLoadStorePointerOperand(&I);
      // We don't check wrapping here because we don't know yet if Ptr will be
      // part of a full group or a group with gaps. Checking wrapping for all
      // pointers (even those that end up in groups with no gaps) will be
      // overly conservative. For full groups, wrapping should be ok since if
      // we would wrap around the address space we would do a memory access at
      // nullptr even without the transformation. The wrapping checks are
      // therefore deferred until after we've formed the interleaved groups.
      int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides,
                                    /*Assume=*/true, /*ShouldCheckWrap=*/false);

      const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
      PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
      uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType());

      // An alignment of 0 means target ABI alignment.
      unsigned Align = getLoadStoreAlignment(&I);
      if (!Align)
        Align = DL.getABITypeAlignment(PtrTy->getElementType());

      AccessStrideInfo[&I] = StrideDescriptor(Stride, Scev, Size, Align);
    }
}

// Analyze interleaved accesses and collect them into interleaved load and
// store groups.
//
// When generating code for an interleaved load group, we effectively hoist all
// loads in the group to the location of the first load in program order. When
// generating code for an interleaved store group, we sink all stores to the
// location of the last store. This code motion can change the order of load
// and store instructions and may break dependences.
//
// The code generation strategy mentioned above ensures that we won't violate
// any write-after-read (WAR) dependences.
//
// E.g., for the WAR dependence:  a = A[i];  // (1)
//                                A[i] = b;  // (2)
//
// The store group of (2) is always inserted at or below (2), and the load
// group of (1) is always inserted at or above (1). Thus, the instructions will
// never be reordered. All other dependences are checked to ensure the
// correctness of the instruction reordering.
//
// The algorithm visits all memory accesses in the loop in bottom-up program
// order. Program order is established by traversing the blocks in the loop in
// reverse postorder when collecting the accesses.
//
// We visit the memory accesses in bottom-up order because it can simplify the
// construction of store groups in the presence of write-after-write (WAW)
// dependences.
//
// E.g., for the WAW dependence:  A[i] = a;      // (1)
//                                A[i] = b;      // (2)
//                                A[i + 1] = c;  // (3)
//
// We will first create a store group with (3) and (2). (1) can't be added to
// this group because it and (2) are dependent. However, (1) can be grouped
// with other accesses that may precede it in program order. Note that a
// bottom-up order does not imply that WAW dependences should not be checked.
void InterleavedAccessInfo::analyzeInterleaving(
    bool EnablePredicatedInterleavedMemAccesses) {
  LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n");
  const ValueToValueMap &Strides = LAI->getSymbolicStrides();

  // Holds all accesses with a constant stride.
  MapVector<Instruction *, StrideDescriptor> AccessStrideInfo;
  collectConstStrideAccesses(AccessStrideInfo, Strides);

  if (AccessStrideInfo.empty())
    return;

  // Collect the dependences in the loop.
  collectDependences();

  // Holds all interleaved store groups temporarily.
  SmallSetVector<InterleaveGroup<Instruction> *, 4> StoreGroups;
  // Holds all interleaved load groups temporarily.
  SmallSetVector<InterleaveGroup<Instruction> *, 4> LoadGroups;

  // Search in bottom-up program order for pairs of accesses (A and B) that can
  // form interleaved load or store groups. In the algorithm below, access A
  // precedes access B in program order. We initialize a group for B in the
  // outer loop of the algorithm, and then in the inner loop, we attempt to
  // insert each A into B's group if:
  //
  //  1. A and B have the same stride,
  //  2. A and B have the same memory object size, and
  //  3. A belongs in B's group according to its distance from B.
  //
  // Special care is taken to ensure group formation will not break any
  // dependences.
  for (auto BI = AccessStrideInfo.rbegin(), E = AccessStrideInfo.rend();
       BI != E; ++BI) {
    Instruction *B = BI->first;
    StrideDescriptor DesB = BI->second;

    // Initialize a group for B if it has an allowable stride. Even if we don't
    // create a group for B, we continue with the bottom-up algorithm to ensure
    // we don't break any of B's dependences.
    InterleaveGroup<Instruction> *Group = nullptr;
    if (isStrided(DesB.Stride) &&
        (!isPredicated(B->getParent()) ||
         EnablePredicatedInterleavedMemAccesses)) {
      Group = getInterleaveGroup(B);
      if (!Group) {
        LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B
                          << '\n');
        Group = createInterleaveGroup(B, DesB.Stride, DesB.Align);
      }
      if (B->mayWriteToMemory())
        StoreGroups.insert(Group);
      else
        LoadGroups.insert(Group);
    }

    for (auto AI = std::next(BI); AI != E; ++AI) {
      Instruction *A = AI->first;
      StrideDescriptor DesA = AI->second;

      // Our code motion strategy implies that we can't have dependences
      // between accesses in an interleaved group and other accesses located
      // between the first and last member of the group. Note that this also
      // means that a group can't have more than one member at a given offset.
      // The accesses in a group can have dependences with other accesses, but
      // we must ensure we don't extend the boundaries of the group such that
      // we encompass those dependent accesses.
      //
      // For example, assume we have the sequence of accesses shown below in a
      // stride-2 loop:
      //
      //  (1, 2) is a group | A[i]   = a;  // (1)
      //                    | A[i-1] = b;  // (2) |
      //                      A[i-3] = c;  // (3)
      //                      A[i]   = d;  // (4) | (2, 4) is not a group
      //
      // Because accesses (2) and (3) are dependent, we can group (2) with (1)
      // but not with (4). If we did, the dependent access (3) would be within
      // the boundaries of the (2, 4) group.
      if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI)) {
        // If a dependence exists and A is already in a group, we know that A
        // must be a store since A precedes B and WAR dependences are allowed.
        // Thus, A would be sunk below B. We release A's group to prevent this
        // illegal code motion. A will then be free to form another group with
        // instructions that precede it.
        if (isInterleaved(A)) {
          InterleaveGroup<Instruction> *StoreGroup = getInterleaveGroup(A);
          StoreGroups.remove(StoreGroup);
          releaseGroup(StoreGroup);
        }

        // If a dependence exists and A is not already in a group (or it was
        // and we just released it), B might be hoisted above A (if B is a
        // load) or another store might be sunk below A (if B is a store). In
        // either case, we can't add additional instructions to B's group. B
        // will only form a group with instructions that it precedes.
        break;
      }

      // At this point, we've checked for illegal code motion. If either A or B
      // isn't strided, there's nothing left to do.
      if (!isStrided(DesA.Stride) || !isStrided(DesB.Stride))
        continue;

      // Ignore A if it's already in a group or isn't the same kind of memory
      // operation as B.
      // Note that mayReadFromMemory() isn't mutually exclusive with
      // mayWriteToMemory() in the case of atomic loads. We shouldn't see those
      // here; canVectorizeMemory() should have returned false - except when
      // optimization remarks were requested.
      if (isInterleaved(A) ||
          (A->mayReadFromMemory() != B->mayReadFromMemory()) ||
          (A->mayWriteToMemory() != B->mayWriteToMemory()))
        continue;

      // Check rules 1 and 2. Ignore A if its stride or size is different from
      // that of B.
      if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size)
        continue;

      // Ignore A if the memory objects of A and B don't belong to the same
      // address space.
      if (getLoadStoreAddressSpace(A) != getLoadStoreAddressSpace(B))
        continue;

      // Calculate the distance from A to B.
      const SCEVConstant *DistToB = dyn_cast<SCEVConstant>(
          PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev));
      if (!DistToB)
        continue;
      int64_t DistanceToB = DistToB->getAPInt().getSExtValue();

      // Check rule 3. Ignore A if its distance to B is not a multiple of the
      // size.
      if (DistanceToB % static_cast<int64_t>(DesB.Size))
        continue;

      // All members of a predicated interleave-group must have the same
      // predicate, and currently must reside in the same BB.
      BasicBlock *BlockA = A->getParent();
      BasicBlock *BlockB = B->getParent();
      if ((isPredicated(BlockA) || isPredicated(BlockB)) &&
          (!EnablePredicatedInterleavedMemAccesses || BlockA != BlockB))
        continue;

      // The index of A is the index of B plus A's distance to B in multiples
      // of the size.
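      // Note: the statement below is the one flagged by the static analyzer
      // report above (original line 996): Group is dereferenced here, and the
      // analyzer believes there is a path on which Group was never assigned a
      // non-null interleave group for B.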
      int IndexA =
          Group->getIndex(B) + DistanceToB / static_cast<int64_t>(DesB.Size);

      // Try to insert A into B's group.
      if (Group->insertMember(A, IndexA, DesA.Align)) {
        LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n'
                          << "    into the interleave group with" << *B
                          << '\n');
        InterleaveGroupMap[A] = Group;

        // Set the first load in program order as the insert position.
        if (A->mayReadFromMemory())
          Group->setInsertPos(A);
      }
    } // Iteration over A accesses.
  }   // Iteration over B accesses.

  // Remove interleaved store groups with gaps.
  for (auto *Group : StoreGroups)
    if (Group->getNumMembers() != Group->getFactor()) {
      LLVM_DEBUG(
          dbgs() << "LV: Invalidate candidate interleaved store group due "
                    "to gaps.\n");
      releaseGroup(Group);
    }
  // Remove interleaved groups with gaps (currently only loads) whose memory
  // accesses may wrap around. We have to revisit the getPtrStride analysis,
  // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does
  // not check wrapping (see documentation there).
  // FORNOW we use Assume=false;
  // TODO: Change to Assume=true but making sure we don't exceed the threshold
  // of runtime SCEV assumptions checks (thereby potentially failing to
  // vectorize altogether).
  // Additional optional optimizations:
  // TODO: If we are peeling the loop and we know that the first pointer
  // doesn't wrap then we can deduce that all pointers in the group don't wrap.
  // This means that we can forcefully peel the loop in order to only have to
  // check the first pointer for no-wrap. Once we change to Assume=true, we'll
  // only need at most one runtime check per interleaved group.
  for (auto *Group : LoadGroups) {
    // Case 1: A full group. We can skip the checks; for full groups, if the
    // wide load would wrap around the address space we would do a memory
    // access at nullptr even without the transformation.
    if (Group->getNumMembers() == Group->getFactor())
      continue;

    // Case 2: If the first and last members of the group don't wrap, this
    // implies that all the pointers in the group don't wrap.
    // So we check only group member 0 (which is always guaranteed to exist),
    // and group member Factor - 1; if the latter doesn't exist we rely on
    // peeling (if it is a non-reversed access -- see Case 3).
    Value *FirstMemberPtr = getLoadStorePointerOperand(Group->getMember(0));
    if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
                      /*ShouldCheckWrap=*/true)) {
      LLVM_DEBUG(
          dbgs() << "LV: Invalidate candidate interleaved group due to "
                    "first group member potentially pointer-wrapping.\n");
      releaseGroup(Group);
      continue;
    }
    Instruction *LastMember = Group->getMember(Group->getFactor() - 1);
    if (LastMember) {
      Value *LastMemberPtr = getLoadStorePointerOperand(LastMember);
      if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false,
                        /*ShouldCheckWrap=*/true)) {
        LLVM_DEBUG(
            dbgs() << "LV: Invalidate candidate interleaved group due to "
                      "last group member potentially pointer-wrapping.\n");
        releaseGroup(Group);
      }
    } else {
      // Case 3: A non-reversed interleaved load group with gaps: We need
      // to execute at least one scalar epilogue iteration. This will ensure
      // we don't speculatively access memory out-of-bounds. We only need
      // to look for a member at index factor - 1, since every group must have
      // a member at index zero.
      if (Group->isReverse()) {
        LLVM_DEBUG(
            dbgs() << "LV: Invalidate candidate interleaved group due to "
                      "a reverse access with gaps.\n");
        releaseGroup(Group);
        continue;
      }
      LLVM_DEBUG(
          dbgs() << "LV: Interleaved group requires epilogue iteration.\n");
      RequiresScalarEpilogue = true;
    }
  }
}

void InterleavedAccessInfo::invalidateGroupsRequiringScalarEpilogue() {
  // If no group had triggered the requirement to create an epilogue loop,
  // there is nothing to do.
  if (!requiresScalarEpilogue())
    return;

  // Avoid releasing a Group twice.
  SmallPtrSet<InterleaveGroup<Instruction> *, 4> DelSet;
  for (auto &I : InterleaveGroupMap) {
    InterleaveGroup<Instruction> *Group = I.second;
    if (Group->requiresScalarEpilogue())
      DelSet.insert(Group);
  }
  for (auto *Ptr : DelSet) {
    LLVM_DEBUG(
        dbgs()
        << "LV: Invalidate candidate interleaved group due to gaps that "
           "require a scalar epilogue (not allowed under optsize) and cannot "
           "be masked (not enabled). \n");
    releaseGroup(Ptr);
  }

  RequiresScalarEpilogue = false;
}

template <typename InstT>
void InterleaveGroup<InstT>::addMetadata(InstT *NewInst) const {
  llvm_unreachable("addMetadata can only be used for Instruction");
}

namespace llvm {
template <>
void InterleaveGroup<Instruction>::addMetadata(Instruction *NewInst) const {
  SmallVector<Value *, 4> VL;
  std::transform(Members.begin(), Members.end(), std::back_inserter(VL),
                 [](std::pair<int, Instruction *> p) { return p.second; });
  propagateMetadata(NewInst, VL);
}
} // namespace llvm