LLVM 22.0.0git
LowerMatrixIntrinsics.cpp
1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different numbers of rows/columns
16//   for tiles; consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
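//
// Illustrative example (not generated output; value names and types are made
// up): a 2x2 transpose expressed as
//
//   %t = call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %a,
//                                                       i32 2, i32 2)
//
// is rewritten by this pass into shufflevector instructions over the flat
// column-major value %a; no matrix type is ever materialized.
//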
19
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/ScopeExit.h"
25#include "llvm/ADT/Statistic.h"
33#include "llvm/IR/CFG.h"
34#include "llvm/IR/DataLayout.h"
37#include "llvm/IR/Function.h"
38#include "llvm/IR/IRBuilder.h"
39#include "llvm/IR/InstrTypes.h"
47#include "llvm/Support/Debug.h"
51
52#include <cmath>
53
54using namespace llvm;
55using namespace PatternMatch;
56
57#define DEBUG_TYPE "lower-matrix-intrinsics"
58
59STATISTIC(FlattenedMatrices, "Number of matrix flattenings");
60STATISTIC(ReshapedMatrices, "Number of matrix reshapes");
61STATISTIC(SplitMatrices, "Number of matrix splits");
62
63static cl::opt<bool>
64 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
65 cl::desc("Enable/disable fusing matrix instructions."));
66// TODO: Allow and use non-square tiles.
67static cl::opt<unsigned> TileSize(
68 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
69 cl::desc(
70 "Tile size for matrix instruction fusion using square-shaped tiles."));
71static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
72 cl::Hidden,
73 cl::desc("Generate loop nest for tiling."));
74static cl::opt<bool> ForceFusion(
75 "force-fuse-matrix", cl::init(false), cl::Hidden,
76 cl::desc("Force matrix instruction fusion even if not profitable."));
77static cl::opt<bool> AllowContractEnabled(
78 "matrix-allow-contract", cl::init(false), cl::Hidden,
79 cl::desc("Allow the use of FMAs if available and profitable. This may "
80 "result in different results, due to less rounding error."));
81
82static cl::opt<bool>
83 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
84 cl::desc("Enable/disable matrix shape verification."),
85 cl::init(false));
86
87enum class MatrixLayoutTy { ColumnMajor, RowMajor };
88
89static cl::opt<MatrixLayoutTy> MatrixLayout(
90 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
91 cl::desc("Sets the default matrix layout"),
92 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
93 "Use column-major layout"),
94 clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
95 "Use row-major layout")));
96
97static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
98 cl::init(false));
99
100/// Helper function to either return \p Scope, if it is a subprogram, or the
101/// subprogram attached to a local scope.
102static DISubprogram *getSubprogram(DIScope *Scope) {
103 if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
104 return Subprogram;
105 return cast<DILocalScope>(Scope)->getSubprogram();
106}
107
108/// Return true if V is a splat of a value (which is used when multiplying a
109/// matrix with a scalar).
110static bool isSplat(Value *V) {
111 if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
112 return SV->isZeroEltSplat();
113 return false;
114}
115
116/// Match any mul operation (fp or integer).
117template <typename LTy, typename RTy>
118auto m_AnyMul(const LTy &L, const RTy &R) {
119 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
120}
121
122/// Match any add operation (fp or integer).
123template <typename LTy, typename RTy>
124auto m_AnyAdd(const LTy &L, const RTy &R) {
125 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
126}
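// A minimal usage sketch (hypothetical caller code): these matchers bind both
// the integer and the FP form of the operation with a single pattern.
//
//   Value *A, *B;
//   if (match(V, m_AnyMul(m_Value(A), m_Value(B))))
//     ...; // V is either a 'mul' or an 'fmul' of A and B.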
127
128namespace {
129
130// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
131// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
132// assuming \p Stride elements between the starts of two consecutive vectors.
133// \p Stride must be >= \p NumElements.
134// For column-major matrixes, the function computes the address of a column
135// vector and \p NumElements must be set to the number of elements in a column
136// (= number of rows of the matrix). For row-major matrixes, the function
137// computes the address of a row vector and \p NumElements must be set to the
138// number of elements in a row (= number of columns of the matrix).
139//
140// Consider a 4x4 matrix in column-major layout like below
141//
142// 0 1 2 3
143// 0 v_0_0 v_0_1 v_0_2 v_0_3
144// 1 v_1_0 v_1_1 v_1_2 v_1_3
145// 2 v_2_0 v_2_1 v_2_2 v_2_3
146// 3 v_3_0 v_3_1 v_3_2 v_3_3
147
148// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
149// we need a pointer to the first element of the submatrix as base pointer.
150// Then we can use computeVectorAddr to compute the addresses for the columns
151// of the sub-matrix.
152//
153// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
154// -> just returns Base
155// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
156// -> returns Base + (1 * 4)
157// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
158// -> returns Base + (2 * 4)
159//
160// The graphic below illustrates the number of elements in a column (marked
161// with |) and the number of skipped elements (marked with }).
162//
163// v_0_0 v_0_1 {v_0_2 {v_0_3
164// Base Col 1 Col 2
165// | | |
166// v_1_0 |v_1_1 |v_1_2 |v_1_3
167// v_2_0 |v_2_1 |v_2_2 |v_2_3
168// v_3_0 {v_3_1 {v_3_2 v_3_3
169//
170Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
171 unsigned NumElements, Type *EltType,
172 IRBuilder<> &Builder) {
173
174 assert((!isa<ConstantInt>(Stride) ||
175 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
176 "Stride must be >= the number of elements in the result vector.");
177
178 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
179 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
180
181 // Get pointer to the start of the selected vector. Skip GEP creation,
182 // if we select vector 0.
183 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
184 VecStart = BasePtr;
185 else
186 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
187
188 return VecStart;
189}
190
191namespace {
192struct ShapeInfo {
193 unsigned NumRows;
194 unsigned NumColumns;
195
196 bool IsColumnMajor;
197
198 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
199 : NumRows(NumRows), NumColumns(NumColumns),
200 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
201
202 ShapeInfo(Value *NumRows, Value *NumColumns)
203 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
204 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
205
206 bool operator==(const ShapeInfo &other) {
207 return NumRows == other.NumRows && NumColumns == other.NumColumns;
208 }
209 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
210
211 /// Returns true if shape-information is defined, meaning both dimensions
212 /// are != 0.
213 operator bool() const {
214 assert(NumRows == 0 || NumColumns != 0);
215 return NumRows != 0;
216 }
217
218 unsigned getStride() const {
219 if (IsColumnMajor)
220 return NumRows;
221 return NumColumns;
222 }
223
224 unsigned getNumVectors() const {
225 if (IsColumnMajor)
226 return NumColumns;
227 return NumRows;
228 }
229
230 /// Returns the transposed shape.
231 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
232
233 friend raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI);
234
235 LLVM_DUMP_METHOD void dump() const { dbgs() << *this << '\n'; }
236};
237
238raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
239 return OS << SI.NumRows << 'x' << SI.NumColumns;
240}
241
242} // namespace
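// Example (assuming the default column-major layout): a 4x2 matrix has
// ShapeInfo{NumRows = 4, NumColumns = 2}; it is lowered as getNumVectors() == 2
// column vectors of getStride() == 4 elements each, and t() yields 2x4.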
243
244static bool isUniformShape(Value *V) {
245 auto *I = dyn_cast<Instruction>(V);
246 if (!I)
247 return true;
248
249 if (I->isBinaryOp())
250 return true;
251
252 if (auto *Cast = dyn_cast<CastInst>(V)) {
253 switch (Cast->getOpcode()) {
254 case llvm::Instruction::Trunc:
255 case llvm::Instruction::ZExt:
256 case llvm::Instruction::SExt:
257 case llvm::Instruction::FPToUI:
258 case llvm::Instruction::FPToSI:
259 case llvm::Instruction::UIToFP:
260 case llvm::Instruction::SIToFP:
261 case llvm::Instruction::FPTrunc:
262 case llvm::Instruction::FPExt:
263 return true;
264 case llvm::Instruction::AddrSpaceCast:
265 case CastInst::PtrToAddr:
266 case CastInst::PtrToInt:
267 case CastInst::IntToPtr:
268 return false;
269 case CastInst::BitCast: {
270 if (auto *SrcVTy = dyn_cast<FixedVectorType>(Cast->getSrcTy()))
271 if (auto *DestVTy = dyn_cast<FixedVectorType>(Cast->getDestTy()))
272 return SrcVTy->getNumElements() == DestVTy->getNumElements();
273 return false;
274 }
275 case llvm::Instruction::CastOpsEnd:
276 llvm_unreachable("not an actual cast op");
277 }
278 llvm_unreachable("unhandled cast opcode");
279 }
280
281 if (auto *II = dyn_cast<IntrinsicInst>(V))
282 switch (II->getIntrinsicID()) {
283 case Intrinsic::abs:
284 case Intrinsic::fabs:
285 return true;
286 default:
287 return false;
288 }
289
290 switch (I->getOpcode()) {
291 case Instruction::PHI:
292 case Instruction::FNeg:
293 return true;
294 default:
295 return false;
296 }
297}
298
299/// Return the ShapeInfo for the result of \p I, if it can be determined.
300static std::optional<ShapeInfo>
301computeShapeInfoForInst(Instruction *I,
302 const DenseMap<Value *, ShapeInfo> &ShapeMap) {
303 Value *M;
304 Value *N;
305 Value *K;
306 if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
307 m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
308 return ShapeInfo(M, K);
309 if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
310 m_Value(N)))) {
311 // Flip dimensions.
312 return ShapeInfo(N, M);
313 }
314 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
315 m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
316 m_Value(N))))
317 return ShapeInfo(N, M);
318 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
319 m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
320 return ShapeInfo(M, N);
321 Value *MatrixA;
322 if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
323 auto OpShape = ShapeMap.find(MatrixA);
324 if (OpShape != ShapeMap.end())
325 return OpShape->second;
326 }
327
328 if (isUniformShape(I) || isa<SelectInst>(I)) {
329 auto Ops = I->operands();
330 auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
331 // Find the first operand that has a known shape and use that.
332 for (auto &Op : ShapedOps) {
333 auto OpShape = ShapeMap.find(Op.get());
334 if (OpShape != ShapeMap.end())
335 return OpShape->second;
336 }
337 }
338 return std::nullopt;
339}
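// Illustrative example (made-up IR): for
//   %c = call <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(
//            <6 x double> %a, <12 x double> %b, i32 2, i32 3, i32 4)
// the deduced result shape is 2x4 (M x K), while a transpose intrinsic with
// shape arguments (i32 2, i32 3) yields the flipped shape 3x2.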
340
341/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
342///
343/// Currently, the lowering for each matrix intrinsic is done as follows:
344/// 1. Propagate the shape information from intrinsics to connected
345/// instructions.
346/// 2. Lower instructions with shape information (assuming column-major layout).
347/// The lowering works similarly using row-major layout.
348/// 2.1. Get column vectors for each argument. If we already lowered the
349/// definition of an argument, use the produced column vectors directly.
350/// If not, split the operand vector containing an embedded matrix into
351/// a set of column vectors.
352/// 2.2. Lower the instruction in terms of column major operations, which
353/// yields a set of column vectors containing result matrix. Note that we
354/// lower all instructions that have shape information. Besides the
355/// intrinsics, this includes stores for example.
356/// 2.3. Update uses of the lowered instruction. If we have shape information
357/// for a user, there is nothing to do, as we will look up the result
358/// column matrix when lowering the user. For other uses, we embed the
359/// result matrix in a flat vector and update the use.
360/// 2.4. Cache the result column matrix for the instruction we lowered
361/// 3. After we lowered all instructions in a function, remove the now
362/// obsolete instructions.
363///
364class LowerMatrixIntrinsics {
365 Function &Func;
366 const DataLayout &DL;
367 const TargetTransformInfo &TTI;
368 FunctionAnalysisManager *AM = nullptr;
369 AliasAnalysis *AA = nullptr;
370 DominatorTree *DT = nullptr;
371 LoopInfo *LI = nullptr;
372 OptimizationRemarkEmitter *ORE = nullptr;
373
374 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
375 struct OpInfoTy {
376 /// Number of stores emitted to generate this matrix.
377 unsigned NumStores = 0;
378 /// Number of loads emitted to generate this matrix.
379 unsigned NumLoads = 0;
380 /// Number of compute operations emitted to generate this matrix.
381 unsigned NumComputeOps = 0;
382 /// Most of the time transposes can be fused with matrix multiplies or can
383 /// be folded away via algebraic simplifications. This is the number of
384 /// transposes that we failed to make "free" via such optimizations.
385 unsigned NumExposedTransposes = 0;
386
387 OpInfoTy &operator+=(const OpInfoTy &RHS) {
388 NumStores += RHS.NumStores;
389 NumLoads += RHS.NumLoads;
390 NumComputeOps += RHS.NumComputeOps;
391 NumExposedTransposes += RHS.NumExposedTransposes;
392 return *this;
393 }
394 };
395
396 /// Wrapper class representing a matrix as a set of vectors, either in row or
397 /// column major layout. All vectors must have the same vector type.
398 class MatrixTy {
399 SmallVector<Value *, 16> Vectors;
400
401 OpInfoTy OpInfo;
402
403 bool IsColumnMajor = true;
404
405 public:
406 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
407 MatrixTy(ArrayRef<Value *> Vectors)
408 : Vectors(Vectors),
409 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
410 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
411 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
412
413 unsigned D = isColumnMajor() ? NumColumns : NumRows;
414 for (unsigned J = 0; J < D; ++J)
416 EltTy, isColumnMajor() ? NumRows : NumColumns)));
417 }
418
419 Value *getVector(unsigned i) const { return Vectors[i]; }
420 Value *getColumn(unsigned i) const {
421 assert(isColumnMajor() && "only supported for column-major matrixes");
422 return Vectors[i];
423 }
424 Value *getRow(unsigned i) const {
425 assert(!isColumnMajor() && "only supported for row-major matrixes");
426 return Vectors[i];
427 }
428
429 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
430
431 Type *getElementType() const { return getVectorTy()->getElementType(); }
432
433 unsigned getNumVectors() const {
434 if (isColumnMajor())
435 return getNumColumns();
436 return getNumRows();
437 }
438
439 unsigned getNumColumns() const {
440 if (isColumnMajor())
441 return Vectors.size();
442 else {
443 assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
444 return getVectorTy()->getNumElements();
445 }
446 }
447 unsigned getNumRows() const {
448 if (isColumnMajor()) {
449 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
450 return getVectorTy()->getNumElements();
451 } else
452 return Vectors.size();
453 }
454
455 void addVector(Value *V) { Vectors.push_back(V); }
456 FixedVectorType *getColumnTy() {
457 assert(isColumnMajor() && "only supported for column-major matrixes");
458 return getVectorTy();
459 }
460
461 FixedVectorType *getVectorTy() const {
462 return cast<FixedVectorType>(Vectors[0]->getType());
463 }
464
465 iterator_range<SmallVector<Value *, 8>::iterator> columns() {
466 assert(isColumnMajor() &&
467 "columns() only supported for column-major matrixes");
468 return make_range(Vectors.begin(), Vectors.end());
469 }
470
471 iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
472 return make_range(Vectors.begin(), Vectors.end());
473 }
474
475 /// Embed the vectors of the matrix into a flat vector by concatenating
476 /// them.
477 Value *embedInVector(IRBuilder<> &Builder) const {
478 return Vectors.size() == 1 ? Vectors[0]
479 : concatenateVectors(Builder, Vectors);
480 }
481
482 MatrixTy &addNumLoads(unsigned N) {
483 OpInfo.NumLoads += N;
484 return *this;
485 }
486
487 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
488
489 MatrixTy &addNumStores(unsigned N) {
490 OpInfo.NumStores += N;
491 return *this;
492 }
493
494 MatrixTy &addNumExposedTransposes(unsigned N) {
495 OpInfo.NumExposedTransposes += N;
496 return *this;
497 }
498
499 MatrixTy &addNumComputeOps(unsigned N) {
500 OpInfo.NumComputeOps += N;
501 return *this;
502 }
503
504 unsigned getNumStores() const { return OpInfo.NumStores; }
505 unsigned getNumLoads() const { return OpInfo.NumLoads; }
506 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
507
508 const OpInfoTy &getOpInfo() const { return OpInfo; }
509
510 bool isColumnMajor() const { return IsColumnMajor; }
511
512 unsigned getStride() const {
513 if (isColumnMajor())
514 return getNumRows();
515 return getNumColumns();
516 }
517
518 ShapeInfo shape() const { return {getNumRows(), getNumColumns()}; }
519
520 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
521 /// matrix is column-major, the result vector is extracted from a column
522 /// vector, otherwise from a row vector.
523 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
524 IRBuilder<> &Builder) const {
525 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
526 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
527 NumElts &&
528 "Extracted vector will contain poison values");
529 return Builder.CreateShuffleVector(
530 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
531 "block");
532 }
533 };
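  // Example: with the default column-major layout, a 4x4 matrix of floats is
  // held as four <4 x float> column vectors; with row-major layout the same
  // matrix is held as four <4 x float> row vectors instead.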
534
535 /// Maps instructions to their shape information. The shape information
536 /// describes the shape to be used while lowering. This matches the shape of
537 /// the result value of the instruction, with the only exceptions being store
538 /// instructions and the matrix_column_major_store intrinsics. For those, the
539 /// shape information indicates that those instructions should be lowered
540 /// using shape information as well. Note that extra care is needed when
541 /// erasing or RAUW'ing a value that is present in ShapeMap. If the
542 /// replacement is also a matrix operation, use
543 /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
544 /// ShapeMap. We don't use ValueMap, as there are also cases where we do not
545 /// want to add shape information for a replacement instruction. When directly
546 /// erasing a value with an entry in ShapeMap, use
547 /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
548 /// accordingly.
549 DenseMap<Value *, ShapeInfo> ShapeMap;
550
551 /// List of instructions to remove. While lowering, we do not replace all
552 /// users of a lowered instruction when shape information is available; those
553 /// lowered instructions need to be removed after lowering has finished.
554 SmallVector<Instruction *, 16> ToRemove;
555
556 /// Map from instructions to their produced column matrix.
557 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
558
559private:
560 static FastMathFlags getFastMathFlags(Instruction *Inst) {
561 FastMathFlags FMF;
562
563 if (isa<FPMathOperator>(*Inst))
564 FMF = Inst->getFastMathFlags();
565
566 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
567
568 return FMF;
569 }
570
571public:
572 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
573 FunctionAnalysisManager *AM)
574 : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
575
576 unsigned getNumOps(Type *VT) {
577 assert(isa<FixedVectorType>(VT) && "Expected vector type");
578 return getNumOps(VT->getScalarType(),
579 cast<FixedVectorType>(VT)->getNumElements());
580 }
581
582 /// Is this the minimal version executed in the backend pipelines?
583 bool isMinimal() const {
584 return !DT;
585 }
586
587 /// Return the estimated number of vector ops required for an operation on
588 /// \p VT * N.
589 unsigned getNumOps(Type *ST, unsigned N) {
590 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
591 double(TTI.getRegisterBitWidth(
592 TargetTransformInfo::RGK_FixedWidthVector)
593 .getFixedValue()));
594 }
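  // Worked example (hypothetical target with 128-bit vector registers): a
  // column of 4 floats (128 bits) counts as ceil(128 / 128) = 1 vector op,
  // while a column of 4 doubles (256 bits) counts as ceil(256 / 128) = 2.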
595
596 /// Return the set of vectors that a matrix value is lowered to.
597 ///
598 /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
599 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
600 /// into vectors.
601 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
602 IRBuilder<> &Builder) {
603 FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType());
604 assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
605 "The vector size must match the number of matrix elements");
606
607 // Check if we lowered MatrixVal using shape information. In that case,
608 // return the existing matrix, if it matches the requested shape
609 // information. If there is a mis-match, embed the result in a flat
610 // vector and split it later.
611 auto Found = Inst2ColumnMatrix.find(MatrixVal);
612 if (Found != Inst2ColumnMatrix.end()) {
613 MatrixTy &M = Found->second;
614 // Return the found matrix, if its shape matches the requested shape
615 // information
616 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
617 return M;
618
619 MatrixVal = M.embedInVector(Builder);
620 }
621
622 // Otherwise split MatrixVal.
623 SmallVector<Value *, 16> SplitVecs;
624 for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
625 MaskStart += SI.getStride()) {
626 Value *V = Builder.CreateShuffleVector(
627 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
628 "split");
629 SplitVecs.push_back(V);
630 }
631
632 if (Instruction *Inst = dyn_cast<Instruction>(MatrixVal)) {
633 if (Found != Inst2ColumnMatrix.end()) {
634 // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles
635 // that embedInVector created.
636 LLVM_DEBUG(dbgs() << "matrix reshape from " << Found->second.shape()
637 << " to " << SI << " using at least "
638 << SplitVecs.size() << " shuffles on behalf of:\n"
639 << *Inst << '\n');
640 ReshapedMatrices++;
641 } else if (!ShapeMap.contains(MatrixVal)) {
642 LLVM_DEBUG(
643 dbgs()
644 << "splitting a " << SI << " matrix with " << SplitVecs.size()
645 << " shuffles because we do not have a shape-aware lowering for "
646 "its def:\n"
647 << *Inst << '\n');
648 (void)Inst;
649 SplitMatrices++;
650 } else {
651 // The ShapeMap has it, so it's a case where we're being lowered
652 // before the def, and we expect that InstCombine will clean things up
653 // afterward.
654 }
655 }
656
657 return {SplitVecs};
658 }
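  // Example: asking for a 4x1 view of a flat <4 x float> value that was
  // previously lowered as a 2x2 matrix first embeds the two <2 x float>
  // columns back into a flat vector and then re-splits it; such cases are
  // counted by the ReshapedMatrices statistic.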
659
660 /// If \p V already has a known shape return false. Otherwise set the shape
661 /// for instructions that support it.
662 bool setShapeInfo(Value *V, ShapeInfo Shape) {
663 assert(Shape && "Shape not set");
664 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
665 return false;
666
667 auto SIter = ShapeMap.find(V);
668 if (SIter != ShapeMap.end()) {
669 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
670 SIter->second.NumColumns != Shape.NumColumns)) {
671 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
672 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
673 << Shape.NumColumns << ") for " << *V << "\n";
674 report_fatal_error(
675 "Matrix shape verification failed, compilation aborted!");
676 }
677
678 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
679 << SIter->second.NumRows << " "
680 << SIter->second.NumColumns << " for " << *V << "\n");
681 return false;
682 }
683
684 ShapeMap.insert({V, Shape});
685 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
686 << " for " << *V << "\n");
687 return true;
688 }
689
690 /// Returns true if shape information can be used for \p V. The supported
691 /// instructions must match the instructions that can be lowered by this pass.
692 bool supportsShapeInfo(Value *V) {
693 Instruction *Inst = dyn_cast<Instruction>(V);
694 if (!Inst)
695 return false;
696
697 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
698 if (II)
699 switch (II->getIntrinsicID()) {
700 case Intrinsic::matrix_multiply:
701 case Intrinsic::matrix_transpose:
702 case Intrinsic::matrix_column_major_load:
703 case Intrinsic::matrix_column_major_store:
704 return true;
705 default:
706 return isUniformShape(II);
707 }
708 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
709 isa<SelectInst>(V);
710 }
711
712 /// Propagate the shape information of instructions to their users.
713 /// The work list contains instructions for which we can compute the shape,
714 /// either based on the information provided by matrix intrinsics or known
715 /// shapes of operands.
716 SmallVector<Instruction *, 32>
717 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
718 SmallVector<Instruction *, 32> NewWorkList;
719 // Pop an element for which we are guaranteed to have at least one of the
720 // operand shapes. Add the shape for this and then add users to the work
721 // list.
722 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
723 while (!WorkList.empty()) {
724 Instruction *Inst = WorkList.pop_back_val();
725
726 // New entry, set the value and insert operands
727 bool Propagate = false;
728 if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
729 Propagate = setShapeInfo(Inst, *SI);
730
731 if (Propagate) {
732 NewWorkList.push_back(Inst);
733 for (auto *User : Inst->users())
734 if (ShapeMap.count(User) == 0)
735 WorkList.push_back(cast<Instruction>(User));
736 }
737 }
738
739 return NewWorkList;
740 }
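  // Illustrative walk-through (made-up IR): given
  //   %t = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %a,
  //                                                      i32 2, i32 3)
  //   %s = fadd <6 x float> %t, %t
  // the transpose seeds the worklist with shape 3x2 for %t, and forward
  // propagation assigns the same 3x2 shape to its fadd user %s.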
741
742 /// Propagate the shape to operands of instructions with shape information.
743 /// \p WorkList contains the instructions for which we already know the shape.
744 SmallVector<Instruction *, 32>
745 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
746 SmallVector<Instruction *, 32> NewWorkList;
747
748 auto pushInstruction = [](Value *V,
749 SmallVectorImpl<Instruction *> &WorkList) {
750 Instruction *I = dyn_cast<Instruction>(V);
751 if (I)
752 WorkList.push_back(I);
753 };
754 // Pop an element with known shape. Traverse the operands, if their shape
755 // derives from the result shape and is unknown, add it and add them to the
756 // worklist.
757 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
758 while (!WorkList.empty()) {
759 Value *V = WorkList.pop_back_val();
760
761 size_t BeforeProcessingV = WorkList.size();
762 if (!isa<Instruction>(V))
763 continue;
764
765 Value *MatrixA;
766 Value *MatrixB;
767 Value *M;
768 Value *N;
769 Value *K;
770 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
771 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
772 m_Value(N), m_Value(K)))) {
773 if (setShapeInfo(MatrixA, {M, N}))
774 pushInstruction(MatrixA, WorkList);
775
776 if (setShapeInfo(MatrixB, {N, K}))
777 pushInstruction(MatrixB, WorkList);
778
779 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
780 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
781 // Flip dimensions.
782 if (setShapeInfo(MatrixA, {M, N}))
783 pushInstruction(MatrixA, WorkList);
784 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
785 m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
786 m_Value(M), m_Value(N)))) {
787 if (setShapeInfo(MatrixA, {M, N})) {
788 pushInstruction(MatrixA, WorkList);
789 }
790 } else if (isa<LoadInst>(V) ||
791 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
792 // Nothing to do, no matrix input.
793 } else if (isa<StoreInst>(V)) {
794 // Nothing to do. We forward-propagated to this so we would just
795 // backward propagate to an instruction with an already known shape.
796 } else if (isUniformShape(V) || isa<SelectInst>(V)) {
797 auto Ops = cast<Instruction>(V)->operands();
798 auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
799 // Propagate to all operands.
800 ShapeInfo Shape = ShapeMap[V];
801 for (Use &U : ShapedOps) {
802 if (setShapeInfo(U.get(), Shape))
803 pushInstruction(U.get(), WorkList);
804 }
805 }
806 // After we discovered new shape info for new instructions in the
807 // worklist, we use their users as seeds for the next round of forward
808 // propagation.
809 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
810 for (User *U : WorkList[I]->users())
811 if (isa<Instruction>(U) && V != U)
812 NewWorkList.push_back(cast<Instruction>(U));
813 }
814 return NewWorkList;
815 }
816
817 /// (Op0 op Op1)^T -> Op0^T op Op1^T
818 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
819 /// them on both sides of \p Operation.
820 Instruction *distributeTransposes(
821 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
822 MatrixBuilder &Builder,
823 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
824 Operation) {
825 Value *T0 = Builder.CreateMatrixTranspose(
826 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
827 // We are being run after shape prop, add shape for newly created
828 // instructions so that we lower them later.
829 setShapeInfo(T0, Shape0.t());
830 Value *T1 = Builder.CreateMatrixTranspose(
831 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
832 setShapeInfo(T1, Shape1.t());
833 return Operation(T0, Shape0.t(), T1, Shape1.t());
834 }
835
836 /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
837 /// itself.
838 void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
839 ShapeMap.erase(Inst);
840 Inst->eraseFromParent();
841 }
842
843 /// Erase \p V from \p BB and move \p II forward to avoid invalidating
844 /// iterators.
845 void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
846 BasicBlock &BB) {
847 auto *Inst = cast<Instruction>(V);
848 // Still used, don't erase.
849 if (!Inst->use_empty())
850 return;
851 if (II != BB.rend() && Inst == &*II)
852 ++II;
853 eraseFromParentAndRemoveFromShapeMap(Inst);
854 }
855
856 /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
857 /// entry for \p Old and replace all uses of \p Old with \p New.
858 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
859 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
860 // with New. We should only add New if it supportsShapeInfo, so we insert
861 // it conditionally instead.
862 auto S = ShapeMap.find(&Old);
863 if (S != ShapeMap.end()) {
864 ShapeMap.erase(S);
865 if (supportsShapeInfo(New))
866 ShapeMap.insert({New, S->second});
867 }
868 Old.replaceAllUsesWith(New);
869 }
870
871 /// Sink a top-level transpose inside matmuls and adds.
872 /// This creates and erases instructions as needed, and returns the newly
873 /// created instruction while updating the iterator to avoid invalidation. If
874 /// this returns nullptr, no new instruction was created.
875 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II,
876 bool &Changed) {
877 BasicBlock &BB = *I.getParent();
878 IRBuilder<> IB(&I);
879 MatrixBuilder Builder(IB);
880
881 Value *TA, *TAMA, *TAMB;
882 ConstantInt *R, *K, *C;
883 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
884 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
885 return nullptr;
886
887 // Transpose of a transpose is a nop when the shapes match.
888 Value *TATA;
889 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(
890 m_Value(TATA), m_Specific(C), m_Specific(R)))) {
891 updateShapeAndReplaceAllUsesWith(I, TATA);
892 eraseFromParentAndMove(&I, II, BB);
893 eraseFromParentAndMove(TA, II, BB);
894 Changed = true;
895 return nullptr;
896 }
897
898 // k^T -> k
899 if (isSplat(TA)) {
900 updateShapeAndReplaceAllUsesWith(I, TA);
901 eraseFromParentAndMove(&I, II, BB);
902 Changed = true;
903 return nullptr;
904 }
905
906 // (A * B)^t -> B^t * A^t
907 // RxK KxC CxK KxR
908 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
909 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
910 m_ConstantInt(K), m_ConstantInt(C)))) {
911 auto NewInst = distributeTransposes(
912 TAMB, {K, C}, TAMA, {R, K}, Builder,
913 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
914 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
915 Shape0.NumColumns,
916 Shape1.NumColumns, "mmul");
917 });
918 updateShapeAndReplaceAllUsesWith(I, NewInst);
919 eraseFromParentAndMove(&I, II, BB);
920 eraseFromParentAndMove(TA, II, BB);
921 Changed = true;
922 return NewInst;
923 }
924
925 // Same as above, but with a mul, which occurs when multiplied
926 // with a scalar.
927 // (A * k)^t -> A^t * k
928 // R x C RxC
929 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
930 (isSplat(TAMA) || isSplat(TAMB))) {
931 IRBuilder<> LocalBuilder(&I);
932 // We know that the transposed operand is of shape RxC.
933 // And when multiplied with a scalar, the shape is preserved.
934 auto NewInst = distributeTransposes(
935 TAMA, {R, C}, TAMB, {R, C}, Builder,
936 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
937 bool IsFP = I.getType()->isFPOrFPVectorTy();
938 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
939 : LocalBuilder.CreateMul(T0, T1, "mmul");
940 auto *Result = cast<Instruction>(Mul);
941 setShapeInfo(Result, Shape0);
942 return Result;
943 });
944 updateShapeAndReplaceAllUsesWith(I, NewInst);
945 eraseFromParentAndMove(&I, II, BB);
946 eraseFromParentAndMove(TA, II, BB);
947 Changed = true;
948 return NewInst;
949 }
950
951 // (A + B)^t -> A^t + B^t
952 // RxC RxC CxR CxR
953 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
954 IRBuilder<> LocalBuilder(&I);
955 auto NewInst = distributeTransposes(
956 TAMA, {R, C}, TAMB, {R, C}, Builder,
957 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
958 bool IsFP = I.getType()->isFPOrFPVectorTy();
959 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
960 : LocalBuilder.CreateAdd(T0, T1, "madd");
961
962 auto *Result = cast<Instruction>(Add);
963 setShapeInfo(Result, Shape0);
964 return Result;
965 });
966 updateShapeAndReplaceAllUsesWith(I, NewInst);
967 eraseFromParentAndMove(&I, II, BB);
968 eraseFromParentAndMove(TA, II, BB);
969 Changed = true;
970 return NewInst;
971 }
972
973 return nullptr;
974 }
975
976 bool liftTranspose(Instruction &I) {
977 // Erase dead Instructions after lifting transposes from binops.
978 auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
979 if (T.use_empty())
980 eraseFromParentAndRemoveFromShapeMap(&T);
981 if (A->use_empty())
982 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
983 if (A != B && B->use_empty())
984 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
985 };
986
987 Value *A, *B, *AT, *BT;
988 ConstantInt *R, *K, *C;
989 // A^t * B ^t -> (B * A)^t
990 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
991 m_Value(A), m_Value(B), m_ConstantInt(R),
992 m_ConstantInt(K), m_ConstantInt(C))) &&
993 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
994 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
995 IRBuilder<> IB(&I);
996 MatrixBuilder Builder(IB);
997 Value *M = Builder.CreateMatrixMultiply(
998 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
999 setShapeInfo(M, {C, R});
1000 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
1001 R->getZExtValue());
1002 updateShapeAndReplaceAllUsesWith(I, NewInst);
1003 CleanupBinOp(I, A, B);
1004 return true;
1005 }
1006 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
1007 // the shape of the second transpose is different, there's a shape conflict
1008 // which gets resolved by picking the shape of the first operand.
1009 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
1010 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
1011 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
1012 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
1013 m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
1014 IRBuilder<> Builder(&I);
1015 auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");
1016 MatrixBuilder MBuilder(Builder);
1017 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
1018 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
1019 updateShapeAndReplaceAllUsesWith(I, NewInst);
1020 assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
1021 computeShapeInfoForInst(&I, ShapeMap) &&
1022 "Shape of new instruction doesn't match original shape.");
1023 CleanupBinOp(I, A, B);
1024 if (auto *AddI = dyn_cast<Instruction>(Add)) {
1025 setShapeInfo(AddI, {R, C});
1026 assert(
1027 computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
1028 ShapeMap[AddI] &&
1029 "Shape of updated addition doesn't match cached shape.");
1030 }
1031 return true;
1032 }
1033 return false;
1034 }
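  // Net effect of the sink/lift rewrites, in matrix terms (shapes permitting):
  //   transpose(A) * transpose(B)  ==>  transpose(B * A)
  // i.e. two exposed transposes become a single one that a consuming multiply
  // or add may later fold away entirely.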
1035
1036 /// Try moving transposes in order to fold them away or into multiplies.
1037 bool optimizeTransposes() {
1038 bool Changed = false;
1039 // First sink all transposes inside matmuls and adds, hoping that we end up
1040 // with NN, NT or TN variants.
1041 for (BasicBlock &BB : reverse(Func)) {
1042 for (auto II = BB.rbegin(); II != BB.rend();) {
1043 Instruction &I = *II;
1044 // We may remove II. By default continue on the next/prev instruction.
1045 ++II;
1046 if (Instruction *NewInst = sinkTranspose(I, II, Changed))
1047 II = std::next(BasicBlock::reverse_iterator(NewInst));
1048 }
1049 }
1050
1051 // If we have a TT matmul or a TT add, lift the transpose. We may be able
1052 // to fold into consuming multiply or add.
1053 for (BasicBlock &BB : Func) {
1054 for (Instruction &I : llvm::make_early_inc_range(BB)) {
1055 Changed |= liftTranspose(I);
1056 }
1057 }
1058 return Changed;
1059 }
1060
1061 bool Visit() {
1062 SmallVector<Instruction *, 32> WorkList;
1063
1064 // Initially only the shape of matrix intrinsics is known.
1065 // Initialize the work list with ops carrying shape information.
1066 for (BasicBlock &BB : Func)
1067 for (Instruction &Inst : BB) {
1068 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
1069 if (!II)
1070 continue;
1071
1072 switch (II->getIntrinsicID()) {
1073 case Intrinsic::matrix_multiply:
1074 case Intrinsic::matrix_transpose:
1075 case Intrinsic::matrix_column_major_load:
1076 case Intrinsic::matrix_column_major_store:
1077 WorkList.push_back(&Inst);
1078 break;
1079 default:
1080 break;
1081 }
1082 }
1083
1084 // Avoid unnecessary work if there are no matrix intrinsics in the function.
1085 if (WorkList.empty())
1086 return false;
1087
1088 if (AM) {
1089 ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func);
1090 AA = &AM->getResult<AAManager>(Func);
1091 DT = &AM->getResult<DominatorTreeAnalysis>(Func);
1092 LI = &AM->getResult<LoopAnalysis>(Func);
1093 }
1094
1095 // Propagate shapes until nothing changes any longer.
1096 while (!WorkList.empty()) {
1097 WorkList = propagateShapeForward(WorkList);
1098 WorkList = propagateShapeBackward(WorkList);
1099 }
1100
1101 bool Changed = false;
1102 if (!isMinimal()) {
1103 Changed |= optimizeTransposes();
1104 if (PrintAfterTransposeOpt) {
1105 dbgs() << "Dump after matrix transpose optimization:\n";
1106 Func.print(dbgs());
1107 }
1108 }
1109
1110 SmallVector<CallInst *, 16> MaybeFusableInsts;
1111 SmallVector<Instruction *, 16> MatrixInsts;
1112 SmallVector<IntrinsicInst *, 16> LifetimeEnds;
1113
1114 // First, collect all instructions with shape information and candidates for
1115 // fusion (currently only matrix multiplies).
1116 ReversePostOrderTraversal<Function *> RPOT(&Func);
1117 for (auto *BB : RPOT)
1118 for (Instruction &I : *BB) {
1119 if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
1120 LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
1121 if (!ShapeMap.contains(&I))
1122 continue;
1123 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
1124 MaybeFusableInsts.push_back(cast<CallInst>(&I));
1125 MatrixInsts.push_back(&I);
1126 }
1127
1128 // Second, try to lower any dot products
1129 SmallPtrSet<Instruction *, 16> FusedInsts;
1130 for (CallInst *CI : MaybeFusableInsts)
1131 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
1132
1133 // Third, try to fuse candidates.
1134 for (CallInst *CI : MaybeFusableInsts)
1135 if (!FusedInsts.contains(CI))
1136 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
1137
1138 Changed |= !FusedInsts.empty();
1139
1140 // Fourth, pre-process all the PHINode's. The incoming values will be
1141 // assigned later in VisitPHI.
1142 for (Instruction *Inst : MatrixInsts) {
1143 if (FusedInsts.count(Inst))
1144 continue;
1145
1146 auto *PHI = dyn_cast<PHINode>(Inst);
1147 if (!PHI)
1148 continue;
1149
1150 const ShapeInfo &SI = ShapeMap.at(Inst);
1151 auto *EltTy = cast<FixedVectorType>(PHI->getType())->getElementType();
1152 MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy);
1153
1154 IRBuilder<> Builder(Inst);
1155 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
1156 PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
1157 PHI->getNumIncomingValues(),
1158 PHI->getName()));
1159 assert(!Inst2ColumnMatrix.contains(PHI) && "map already contains phi?");
1160 Inst2ColumnMatrix[PHI] = PhiM;
1161 }
1162
1163 // Fifth, lower remaining instructions with shape information.
1164 for (Instruction *Inst : MatrixInsts) {
1165 if (FusedInsts.count(Inst))
1166 continue;
1167
1168 const ShapeInfo &SI = ShapeMap.at(Inst);
1169
1170 Value *Op1;
1171 Value *Op2;
1172 MatrixTy Result;
1173 IRBuilder<> Builder(Inst);
1174 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1175 Result = VisitBinaryOperator(BinOp, SI, Builder);
1176 else if (auto *Cast = dyn_cast<CastInst>(Inst))
1177 Result = VisitCastInstruction(Cast, SI, Builder);
1178 else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1179 Result = VisitUnaryOperator(UnOp, SI, Builder);
1180 else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
1181 Result = VisitIntrinsicInst(Intr, SI, Builder);
1182 else if (auto *Select = dyn_cast<SelectInst>(Inst))
1183 Result = VisitSelectInst(Select, SI, Builder);
1184 else if (match(Inst, m_Load(m_Value(Op1))))
1185 Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
1186 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1187 Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
1188 else if (auto *PHI = dyn_cast<PHINode>(Inst))
1189 Result = VisitPHI(PHI, SI, Builder);
1190 else
1191 continue;
1192
1193 finalizeLowering(Inst, Result, Builder);
1194 Changed = true;
1195 }
1196
1197 if (ORE) {
1198 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1199 RemarkGen.emitRemarks();
1200 }
1201
1202 // Delete the instructions backwards, as that reduces the number of
1203 // def-use and use-def chains that need to be updated.
1204 //
1205 // Because we add to ToRemove during fusion we can't guarantee that defs
1206 // are before uses. Change uses to poison temporarily as these should get
1207 // removed as well.
1208 //
1209 // For verification, we keep track of where we changed uses to poison in
1210 // PoisonedInsts and then check that we in fact remove them.
1211 SmallPtrSet<Instruction *, 16> PoisonedInsts;
1212 for (auto *Inst : reverse(ToRemove)) {
1213 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1214 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1215 PoisonedInsts.insert(Poisoned);
1216 U.set(PoisonValue::get(Inst->getType()));
1217 }
1218 Inst->eraseFromParent();
1219 PoisonedInsts.erase(Inst);
1220 }
1221 if (!PoisonedInsts.empty()) {
1222 // If we didn't remove all poisoned instructions, it's a hard error.
1223 dbgs() << "Poisoned but present instructions:\n";
1224 for (auto *I : PoisonedInsts)
1225 dbgs() << *I << "\n";
1226 llvm_unreachable("Poisoned but instruction not removed");
1227 }
1228
1229 return Changed;
1230 }
1231
1232 /// Replace intrinsic calls.
1233 MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI,
1234 IRBuilder<> &Builder) {
1235 assert(Inst->getCalledFunction() &&
1236 Inst->getCalledFunction()->isIntrinsic());
1237
1238 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1239 case Intrinsic::matrix_multiply:
1240 return LowerMultiply(Inst, Builder);
1241 case Intrinsic::matrix_transpose:
1242 return LowerTranspose(Inst, Builder);
1243 case Intrinsic::matrix_column_major_load:
1244 return LowerColumnMajorLoad(Inst, Builder);
1245 case Intrinsic::matrix_column_major_store:
1246 return LowerColumnMajorStore(Inst, Builder);
1247 case Intrinsic::abs:
1248 case Intrinsic::fabs: {
1249 MatrixTy Result;
1250 MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);
1251 Builder.setFastMathFlags(getFastMathFlags(Inst));
1252
1253 for (auto *Vector : M.vectors()) {
1254 switch (Inst->getIntrinsicID()) {
1255 case Intrinsic::abs:
1256 Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
1257 Inst->getOperand(1)));
1258 continue;
1259 case Intrinsic::fabs:
1260 Result.addVector(
1261 Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
1262 continue;
1263 default:
1264 llvm_unreachable("unexpected intrinsic");
1265 }
1266 }
1267
1268 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1269 Result.getNumVectors());
1270 }
1271 default:
1272 break;
1273 }
1274 llvm_unreachable(
1275 "only intrinsics supporting shape info should be seen here");
1276 }
1277
1278 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1279 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1280 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1281 /// non-ConstantInt strides, return the common alignment of the initial
1282 /// alignment and the element size in bytes.
1283 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1284 MaybeAlign A) const {
1285 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1286 if (Idx == 0)
1287 return InitialAlign;
1288
1289 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1290 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1291 uint64_t StrideInBytes =
1292 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1293 return commonAlignment(InitialAlign, Idx * StrideInBytes);
1294 }
1295 return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1296 }
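  // Worked example (hypothetical values): with an initial alignment of 16, a
  // double element type (8 bytes) and a constant stride of 3 elements, the
  // vector at Idx = 1 starts 24 bytes into the matrix, so the result is
  // commonAlignment(16, 24) = 8.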
1297
1298 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1299 /// vectors.
1300 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1301 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1302 auto *VType = cast<FixedVectorType>(Ty);
1303 Type *EltTy = VType->getElementType();
1304 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1305 Value *EltPtr = Ptr;
1306 MatrixTy Result;
1307 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1308 Value *GEP = computeVectorAddr(
1309 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1310 Stride, Shape.getStride(), EltTy, Builder);
1311 Value *Vector = Builder.CreateAlignedLoad(
1312 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1313 IsVolatile, "col.load");
1314
1315 Result.addVector(Vector);
1316 }
1317 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1318 Result.getNumVectors());
1319 }
1320
1321 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1322 /// starting at \p MatrixPtr[I][J].
1323 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1324 ShapeInfo MatrixShape, Value *I, Value *J,
1325 ShapeInfo ResultShape, Type *EltTy,
1326 IRBuilder<> &Builder) {
1327 Value *Offset = Builder.CreateAdd(
1328 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1329
1330 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1331 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1332 ResultShape.NumColumns);
1333
1334 return loadMatrix(TileTy, TileStart, Align,
1335 Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1336 ResultShape, Builder);
1337 }
1338
1339 /// Lower a load instruction with shape information.
1340 MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
1341 Value *Stride, bool IsVolatile, ShapeInfo Shape,
1342 IRBuilder<> &Builder) {
1343 return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,
1344 Builder);
1345 }
1346
1347 /// Lowers llvm.matrix.column.major.load.
1348 ///
1349 /// The intrinsic loads a matrix from memory using a stride between columns.
1350 MatrixTy LowerColumnMajorLoad(CallInst *Inst, IRBuilder<> &Builder) {
1351 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1352 "Intrinsic only supports column-major layout!");
1353 Value *Ptr = Inst->getArgOperand(0);
1354 Value *Stride = Inst->getArgOperand(1);
1355 return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1356 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1357 {Inst->getArgOperand(3), Inst->getArgOperand(4)}, Builder);
1358 }
1359
1360 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1361 /// MatrixPtr[I][J].
1362 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1363 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1364 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1365 Value *Offset = Builder.CreateAdd(
1366 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1367
1368 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1369 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1370 StoreVal.getNumColumns());
1371
1372 storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1373 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1374 }
1375
1376 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1377 /// vectors.
1378 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1379 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1380 IRBuilder<> &Builder) {
1381 auto *VType = cast<FixedVectorType>(Ty);
1382 Value *EltPtr = Ptr;
1383 for (auto Vec : enumerate(StoreVal.vectors())) {
1384 Value *GEP = computeVectorAddr(
1385 EltPtr,
1386 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1387 Vec.index()),
1388 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1389 Builder.CreateAlignedStore(Vec.value(), GEP,
1390 getAlignForIndex(Vec.index(), Stride,
1391 VType->getElementType(),
1392 MAlign),
1393 IsVolatile);
1394 }
1395 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1396 StoreVal.getNumVectors());
1397 }
1398
1399 /// Lower a store instruction with shape information.
1400 MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
1401 MaybeAlign A, Value *Stride, bool IsVolatile,
1402 ShapeInfo Shape, IRBuilder<> &Builder) {
1403 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1404 return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,
1405 Builder);
1406 }
1407
1408 /// Lowers llvm.matrix.column.major.store.
1409 ///
1410 /// The intrinsic stores a matrix back to memory using a stride between columns.
1411 MatrixTy LowerColumnMajorStore(CallInst *Inst, IRBuilder<> &Builder) {
1412 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1413 "Intrinsic only supports column-major layout!");
1414 Value *Matrix = Inst->getArgOperand(0);
1415 Value *Ptr = Inst->getArgOperand(1);
1416 Value *Stride = Inst->getArgOperand(2);
1417 return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1418 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1419 {Inst->getArgOperand(4), Inst->getArgOperand(5)},
1420 Builder);
1421 }
1422
1423 // Set elements I..I+NumElts-1 to Block
1424 Value *insertVector(Value *Col, unsigned I, Value *Block,
1425 IRBuilder<> &Builder) {
1426
1427 // First, bring Block to the same size as Col
1428 unsigned BlockNumElts =
1429 cast<FixedVectorType>(Block->getType())->getNumElements();
1430 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1431 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1432
1433 Block = Builder.CreateShuffleVector(
1434 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1435
1436 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1437 // 8, 4, 5, 6
1438 SmallVector<int, 16> Mask;
1439 unsigned i;
1440 for (i = 0; i < I; i++)
1441 Mask.push_back(i);
1442
1443 unsigned VecNumElts =
1444 cast<FixedVectorType>(Col->getType())->getNumElements();
1445 for (; i < I + BlockNumElts; i++)
1446 Mask.push_back(i - I + VecNumElts);
1447
1448 for (; i < VecNumElts; i++)
1449 Mask.push_back(i);
1450
1451 return Builder.CreateShuffleVector(Col, Block, Mask);
1452 }
1453
1454 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1455 IRBuilder<> &Builder, bool AllowContraction,
1456 unsigned &NumComputeOps) {
1457 NumComputeOps += getNumOps(A->getType());
1458 if (!Sum)
1459 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1460
1461 if (UseFPOp) {
1462 if (AllowContraction) {
1463 // Use fmuladd for floating point operations and let the backend decide
1464 // if that's profitable.
1465 return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(),
1466 {A, B, Sum});
1467 }
1468 NumComputeOps += getNumOps(A->getType());
1469 Value *Mul = Builder.CreateFMul(A, B);
1470 return Builder.CreateFAdd(Sum, Mul);
1471 }
1472
1473 NumComputeOps += getNumOps(A->getType());
1474 Value *Mul = Builder.CreateMul(A, B);
1475 return Builder.CreateAdd(Sum, Mul);
1476 }
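  // For example, with an FP element type and contraction allowed, each partial
  // product is emitted as a single intrinsic call such as
  //   %acc = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a,
  //                                               <4 x float> %b,
  //                                               <4 x float> %sum)
  // (illustrative IR); otherwise a separate fmul/mul and fadd/add pair is used.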
1477
1478 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1479 /// users with shape information, there's nothing to do: they will use the
1480 /// cached value when they are lowered. For other users, \p Matrix is
1481 /// flattened and the uses are updated to use it. Also marks \p Inst for
1482 /// deletion.
1483 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1484 IRBuilder<> &Builder) {
1485 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1486 (void)inserted;
1487 assert((inserted.second || isa<PHINode>(Inst)) &&
1488 "multiple matrix lowering mapping");
1489
1490 ToRemove.push_back(Inst);
1491 Value *Flattened = nullptr;
1492 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1493 if (ShapeMap.contains(U.getUser()))
1494 continue;
1495
1496 if (!Flattened) {
1497 Flattened = Matrix.embedInVector(Builder);
1498 LLVM_DEBUG(
1499 if (Instruction *User = dyn_cast<Instruction>(U.getUser())) dbgs()
1500 << "flattening a " << Matrix.shape() << " matrix:\n"
1501 << *Inst
1502 << "\nbecause we do not have a shape-aware lowering for its "
1503 "user:\n"
1504 << *User << '\n';);
1505 FlattenedMatrices++;
1506 }
1507 U.set(Flattened);
1508 }
1509 }
1510
1511 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1512 /// vectors. Lowers to a vector reduction add instead of sequential adds if
1513 /// reassociation is enabled.
1514 void lowerDotProduct(CallInst *MatMul,
1515 SmallPtrSet<Instruction *, 16> &FusedInsts,
1516 FastMathFlags FMF) {
1517 if (FusedInsts.contains(MatMul) ||
1518 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1519 return;
1520 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1521 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1522
1523 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1524 return;
1525
1526 Value *LHS = MatMul->getArgOperand(0);
1527 Value *RHS = MatMul->getArgOperand(1);
1528
1529 Type *ElementType = cast<FixedVectorType>(LHS->getType())->getElementType();
1530 bool IsIntVec = ElementType->isIntegerTy();
1531
1532 // Floating point reductions require reassociation.
1533 if (!IsIntVec && !FMF.allowReassoc())
1534 return;
1535
1536 auto CanBeFlattened = [](Value *Op) {
1537 if (match(Op, m_BinOp()))
1538 return true;
1539 return match(
1540 Op, m_OneUse(m_CombineOr(
1541 m_Load(m_Value()),
1542 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1543 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1544 m_Value(), m_SpecificInt(1))))));
1545 };
1546 // Returns the cost benefit of using \p Op with the dot product lowering. If
1547 // the returned cost is < 0, the argument is cheaper to use in the
1548 // dot-product lowering.
1549 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1550 if (!ShapeMap.contains(Op))
1551 return InstructionCost::getInvalid();
1552
1553 if (!isa<Instruction>(Op))
1554 return InstructionCost(0);
1555
1556 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1557 Type *EltTy = VecTy->getElementType();
1558
1559 if (!CanBeFlattened(Op)) {
1560 InstructionCost EmbedCost(0);
1561 // Roughly estimate the cost for embedding the columns into a vector.
1562 for (unsigned I = 1; I < N; ++I)
1563 EmbedCost += TTI.getShuffleCost(
1566 return EmbedCost;
1567 }
1568
1569 if (match(Op, m_BinOp()) && ShapeMap.contains(Op)) {
1570 InstructionCost OriginalCost =
1571 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1572 EltTy) *
1573 N;
1574 InstructionCost NewCost = TTI.getArithmeticInstrCost(
1575 cast<Instruction>(Op)->getOpcode(), VecTy);
1576 return NewCost - OriginalCost;
1577 }
1578
1579 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1580 // The transpose can be skipped for the dot product lowering, roughly
1581 // estimate the savings as the cost of embedding the columns in a
1582 // vector.
1583 InstructionCost EmbedCost(0);
1584 for (unsigned I = 1; I < N; ++I)
1585 EmbedCost -= TTI.getShuffleCost(
1588 return EmbedCost;
1589 }
1590
1591 // Costs for loads.
1592 if (N == 1)
1593 return InstructionCost(0);
1594
1595 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1596 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1597 };
1598
1599 // Iterate over LHS and operations feeding LHS and check if it is profitable
1600 // to flatten the visited ops. For each op, we compute the difference
1601 // between the flattened and matrix versions.
1602 SmallPtrSet<Value *, 4> Seen;
1603 SmallVector<Value *> WorkList;
1604 SmallVector<Value *> ToFlatten;
1605 WorkList.push_back(LHS);
1606 InstructionCost LHSCost(0);
1607 while (!WorkList.empty()) {
1608 Value *Op = WorkList.pop_back_val();
1609 if (!Seen.insert(Op).second)
1610 continue;
1611
1612 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1613 if (OpCost + LHSCost >= LHSCost)
1614 continue;
1615
1616 LHSCost += OpCost;
1617 ToFlatten.push_back(Op);
1618 if (auto *I = dyn_cast<Instruction>(Op))
1619 WorkList.append(I->op_begin(), I->op_end());
1620 }
1621
1622 // We compare the costs of a vector.reduce.add to sequential add.
1623 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1624 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1625 InstructionCost ReductionCost =
1626 TTI.getArithmeticReductionCost(
1627 AddOpCode, cast<FixedVectorType>(LHS->getType()),
1628 IsIntVec ? std::nullopt : std::optional(FMF)) +
1629 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1630 InstructionCost SequentialAddCost =
1631 TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1632 (LShape.NumColumns - 1) +
1633 TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1634 (LShape.NumColumns);
1635 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1636 return;
1637
1638 FusedInsts.insert(MatMul);
1639 IRBuilder<> Builder(MatMul);
1640 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1641 this](Value *Op) {
1642 // Matmul must be the only user of loads because we don't use LowerLoad
1643 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1644 // instead of a single vector load).
1645 if (!CanBeFlattened(Op))
1646 return;
1647
1648 if (match(Op, m_BinOp())) {
1649 auto It = ShapeMap.find(Op);
1650 if (It != ShapeMap.end()) {
1651 It->second = It->second.t();
1652 return;
1653 }
1654 }
1655
1656 FusedInsts.insert(cast<Instruction>(Op));
1657 // If the vector uses the builtin column-major load, lower it to a LoadInst
1658 Value *Arg;
1659 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1660 m_Value(Arg)))) {
1661 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1662 Op->replaceAllUsesWith(NewLoad);
1663 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
1664 return;
1665 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1666 m_Value(Arg)))) {
1667 ToRemove.push_back(cast<Instruction>(Op));
1668 Op->replaceAllUsesWith(Arg);
1669 return;
1670 }
1671 };
1672
1673 for (auto *V : ToFlatten)
1674 FlattenArg(V);
1675
1676 LHS = MatMul->getArgOperand(0);
1677
1678 // Insert mul/fmul and llvm.vector.reduce.fadd
1679 Value *Mul =
1680 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1681
1682 Value *Result;
1683 if (IsIntVec)
1684 Result = Builder.CreateAddReduce(Mul);
1685 else {
1686 Result = Builder.CreateFAddReduce(
1687 ConstantFP::get(
1688 cast<FixedVectorType>(LHS->getType())->getElementType(), 0.0),
1689 Mul);
1690 cast<Instruction>(Result)->setFastMathFlags(FMF);
1691 }
1692
1693 // pack scalar back into a matrix and then replace matmul inst
1694 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1695 Result, uint64_t(0));
1696 MatMul->replaceAllUsesWith(Result);
1697 FusedInsts.insert(MatMul);
1698 ToRemove.push_back(MatMul);
1699 }
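// Illustrative sketch of the dot-product lowering above: for a 1x4 * 4x1
// float multiply, the flattened row and column are multiplied element-wise
// and reduced, roughly producing IR of the form (assuming reassociation is
// allowed):
//   %mul = fmul fast <4 x float> %lhs, %rhs
//   %dot = call fast float @llvm.vector.reduce.fadd.v4f32(float 0.0,
//                                                         <4 x float> %mul)
// instead of four scalar multiplies chained through sequential fadds.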
1700
1701 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1702 /// addition.
1703 ///
1704 /// We can fold a transpose into the operand that is used to extract scalars.
1705 /// This is the first operand with row-major and the second with
1706 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1707 /// operand is transposed.
1708 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1709 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1710 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1711 const unsigned VF = std::max<unsigned>(
1712 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1713 .getFixedValue() /
1714 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1715 1U);
1716 unsigned R = Result.getNumRows();
1717 unsigned C = Result.getNumColumns();
1718 unsigned M = A.getNumColumns();
1719
1720 bool IsFP = Result.getElementType()->isFloatingPointTy();
1721 assert(A.isColumnMajor() == B.isColumnMajor() &&
1722 Result.isColumnMajor() == A.isColumnMajor() &&
1723 "operands must agree on matrix layout");
1724 unsigned NumComputeOps = 0;
1725
1726 Builder.setFastMathFlags(FMF);
1727
1728 if (A.isColumnMajor()) {
1729 // Multiply columns from the first operand with scalars from the second
1730 // operand. Then move along the K axis and accumulate the columns. With
1731 // this the adds can be vectorized without reassociation.
1732 for (unsigned J = 0; J < C; ++J) {
1733 unsigned BlockSize = VF;
1734 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1735 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1736
1737 for (unsigned I = 0; I < R; I += BlockSize) {
1738 // Gradually lower the vectorization factor to cover the remainder.
1739 while (I + BlockSize > R)
1740 BlockSize /= 2;
1741
1742 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1743 : nullptr;
1744 for (unsigned K = 0; K < M; ++K) {
1745 Value *L = A.extractVector(I, K, BlockSize, Builder);
1746 Value *RH = Builder.CreateExtractElement(
1747 B.getColumn(IsScalarMatrixTransposed ? K : J),
1748 IsScalarMatrixTransposed ? J : K);
1749 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1750 Sum =
1751 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1752 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1753 }
1754 Result.setVector(J,
1755 insertVector(Result.getVector(J), I, Sum, Builder));
1756 }
1757 }
1758 } else {
1759 // Multiply rows from the second operand with scalars from the first
1760 // operand. Then move along the K axis and accumulate the rows. With this
1761 // the adds can be vectorized without reassociation.
1762 for (unsigned I = 0; I < R; ++I) {
1763 unsigned BlockSize = VF;
1764 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1765 for (unsigned J = 0; J < C; J += BlockSize) {
1766 // Gradually lower the vectorization factor to cover the remainder.
1767 while (J + BlockSize > C)
1768 BlockSize /= 2;
1769
1770 Value *Sum = nullptr;
1771 for (unsigned K = 0; K < M; ++K) {
1772 Value *R = B.extractVector(K, J, BlockSize, Builder);
1773 Value *LH = Builder.CreateExtractElement(
1774 A.getVector(IsScalarMatrixTransposed ? K : I),
1775 IsScalarMatrixTransposed ? I : K);
1776 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1777 Sum =
1778 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1779 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1780 }
1781 Result.setVector(I,
1782 insertVector(Result.getVector(I), J, Sum, Builder));
1783 }
1784 }
1785 }
1786 Result.addNumComputeOps(NumComputeOps);
1787 }
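// Rough example of the column-major path above: for a 4x2 * 2x2 multiply with
// VF >= 4, each result column J is built as
//   Result.col(J) = A.col(0) * splat(B[0][J]) + A.col(1) * splat(B[1][J])
// i.e. one vector multiply (or fmuladd) per K step, so the additions stay
// vectorized without requiring reassociation.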
1788
1789 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1790 /// copying it to a new location. The new location, or otherwise the original
1791 /// one, is returned.
1792 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1793 CallInst *MatMul) {
1794 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1795 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1796
1797 // If we can statically determine noalias we're good.
1798 if (AA->isNoAlias(LoadLoc, StoreLoc))
1799 return Load->getPointerOperand();
1800
1801 // Create code to check if the memory locations of the Load and Store
1802 // overlap and if they do, copy Load's operand to a new buffer.
1803
1804 // First, create new blocks for the 2nd part of the check and the copy.
1805 BasicBlock *Check0 = MatMul->getParent();
1806 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1807 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1808 // as we adjust Check0 and Check1's branches.
1809 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1810 for (BasicBlock *Succ : successors(Check0))
1811 DTUpdates.push_back({DT->Delete, Check0, Succ});
1812
1813 BasicBlock *Check1 =
1814 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1815 nullptr, "alias_cont");
1816 BasicBlock *Copy =
1817 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1818 nullptr, "copy");
1819 BasicBlock *Fusion =
1820 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1821 nullptr, "no_alias");
1822
1823 // Check if the loaded memory location begins before the end of the store
1824 // location. If the condition holds, they might overlap, otherwise they are
1825 // guaranteed to not overlap.
1826 IRBuilder<> Builder(MatMul);
1827 Check0->getTerminator()->eraseFromParent();
1828 Builder.SetInsertPoint(Check0);
1829 Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
1830 Value *StoreBegin = Builder.CreatePtrToInt(
1831 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1832 Value *StoreEnd = Builder.CreateAdd(
1833 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1834 "store.end", true, true);
1835 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1836 IntPtrTy, "load.begin");
1837 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1838 Fusion);
1839
1840 // Check if the store begins before the end of the load location. If the
1841 // condition holds, they alias, otherwise they are guaranteed to not
1842 // overlap.
1843 Check1->getTerminator()->eraseFromParent();
1844 Builder.SetInsertPoint(Check1, Check1->begin());
1845 Value *LoadEnd = Builder.CreateAdd(
1846 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1847 "load.end", true, true);
1848 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1849 Fusion);
1850
1851 // Copy load operand to new alloca.
1852 Builder.SetInsertPoint(Copy, Copy->begin());
1853 auto *VT = cast<FixedVectorType>(Load->getType());
1854 // Use an array type for the alloca, to avoid potentially huge alignment
1855 // requirements for large vector types.
1856 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1857 AllocaInst *Alloca =
1858 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1859
1860 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
1861 Load->getAlign(), LoadLoc.Size.getValue());
1862 Builder.SetInsertPoint(Fusion, Fusion->begin());
1863 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1864 PHI->addIncoming(Load->getPointerOperand(), Check0);
1865 PHI->addIncoming(Load->getPointerOperand(), Check1);
1866 PHI->addIncoming(Alloca, Copy);
1867
1868 // Adjust DT.
1869 DTUpdates.push_back({DT->Insert, Check0, Check1});
1870 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1871 DTUpdates.push_back({DT->Insert, Check1, Copy});
1872 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1873 DT->applyUpdates(DTUpdates);
1874 return PHI;
1875 }
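// The runtime check emitted above has roughly this CFG shape (names follow
// the labels passed to SplitBlock):
//   Check0:     br (load.begin < store.end), alias_cont, no_alias
//   alias_cont: br (store.begin < load.end), copy, no_alias
//   copy:       alloca + memcpy of the loaded matrix, fall through to no_alias
//   no_alias:   phi of the original pointer (from Check0/Check1) or the alloca
// Only if both interval checks pass is the copy taken.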
1876
1877 bool isFusionProfitable(CallInst *MatMul) {
1878 if (ForceFusion)
1879 return true;
1880
1881 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1882 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1883
1884 const unsigned R = LShape.NumRows;
1885 const unsigned C = RShape.NumColumns;
1886 const unsigned M = LShape.NumColumns;
1887 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
1888
1889 const unsigned VF = std::max<unsigned>(
1890 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1891 .getFixedValue() /
1892 EltType->getPrimitiveSizeInBits().getFixedValue(),
1893 1U);
1894
1895 // Cost model for tiling
1896 //
1897 // For tiling to be beneficial, we need reuse either along the R or
1898 // the C axis. We vectorize along the R axis so that means at least
1899 // 3 elements.
1900 // TODO: Also consider cost of copying if operands alias.
1901 if (R <= VF && C == 1)
1902 return false;
1903 // Then we need enough elements to exceed the number of vector
1904 // registers we have. Note that this is an oversimplification since
1905 // fusing also takes some extra loads which may exceed the number of
1906 // reloads necessary.
1907 unsigned Op0Regs = (R + VF - 1) / VF * M;
1908 unsigned Op1Regs = (M + VF - 1) / VF * C;
1909 return Op0Regs + Op1Regs >
1910 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1911 }
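// Worked example for the heuristic above (illustrative numbers): with
// R = C = M = 8 and VF = 4, Op0Regs = ceil(8/4) * 8 = 16 and Op1Regs = 16,
// so tiling is chosen only if 16 + 16 = 32 exceeds the number of vector
// registers reported by TTI.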
1912
1913 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1914 MatrixTy Res;
1915 auto *ColumType = FixedVectorType::get(EltType, R);
1916 for (unsigned I = 0; I < C; ++I)
1917 Res.addVector(ConstantAggregateZero::get(ColumType));
1918 return Res;
1919 }
1920
1921 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1922 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1923 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
1924
1925 // Create the main tiling loop nest.
1926 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1927 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1928 Instruction *InsertI = cast<Instruction>(MatMul);
1929 BasicBlock *Start = InsertI->getParent();
1930 BasicBlock *End =
1931 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1932 IRBuilder<> Builder(MatMul);
1933 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1934
1935 Type *TileVecTy =
1936 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1937 MatrixTy TileResult;
1938 // Insert in the inner loop header.
1939 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1940 // Create PHI nodes for the result columns to accumulate across iterations.
1941 SmallVector<PHINode *, 4> ColumnPhis;
1942 for (unsigned I = 0; I < TileSize; I++) {
1943 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1944 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1945 TI.RowLoop.Header->getSingleSuccessor());
1946 TileResult.addVector(Phi);
1947 ColumnPhis.push_back(Phi);
1948 }
1949
1950 // Insert in the inner loop body, which computes
1951 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1952 Builder.SetInsertPoint(InnerBody->getTerminator());
1953 // Load tiles of the operands.
1954 MatrixTy A =
1955 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1956 {TileSize, TileSize}, EltType, Builder);
1957 MatrixTy B =
1958 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1959 {TileSize, TileSize}, EltType, Builder);
1960 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1961 getFastMathFlags(MatMul));
1962 // Store result after the inner loop is done.
1963 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1964 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1965 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1966 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1967
1968 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1969 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1970
1971 // Force unrolling of a few iterations of the inner loop, to make sure there
1972 // is enough work per iteration.
1973 // FIXME: The unroller should make this decision directly instead, but
1974 // currently the cost-model is not up to the task.
1975 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1976 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1977 "llvm.loop.unroll.count", InnerLoopUnrollCount);
1978 }
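// The loop nest built above is, conceptually (C-like sketch, TileSize = 4):
//   for (C = 0; C < Columns; C += 4)      // ColumnLoop
//     for (R = 0; R < Rows; R += 4)       // RowLoop
//       for (K = 0; K < Inner; K += 4)    // KLoop, accumulating via the PHIs
//         TileResult += A[R..R+3][K..K+3] * B[K..K+3][C..C+3];
// with the accumulated tile stored back once per RowLoop iteration.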
1979
1980 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1981 StoreInst *Store,
1982 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1983 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1984 "Tiling only supported for column-major matrixes at the moment!");
1985 if (!isFusionProfitable(MatMul))
1986 return;
1987
1988 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1989 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1990
1991 const unsigned R = LShape.NumRows;
1992 const unsigned C = RShape.NumColumns;
1993 const unsigned M = LShape.NumColumns;
1994 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
1995
1996 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1997 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1998 Value *CPtr = Store->getPointerOperand();
1999
2000 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
2001 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
2002 else {
2003 IRBuilder<> Builder(Store);
2004 for (unsigned J = 0; J < C; J += TileSize)
2005 for (unsigned I = 0; I < R; I += TileSize) {
2006 const unsigned TileR = std::min(R - I, unsigned(TileSize));
2007 const unsigned TileC = std::min(C - J, unsigned(TileSize));
2008 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
2009
2010 for (unsigned K = 0; K < M; K += TileSize) {
2011 const unsigned TileM = std::min(M - K, unsigned(TileSize));
2012 MatrixTy A =
2013 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
2014 LShape, Builder.getInt64(I), Builder.getInt64(K),
2015 {TileR, TileM}, EltType, Builder);
2016 MatrixTy B =
2017 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
2018 RShape, Builder.getInt64(K), Builder.getInt64(J),
2019 {TileM, TileC}, EltType, Builder);
2020 emitMatrixMultiply(Res, A, B, Builder, true, false,
2021 getFastMathFlags(MatMul));
2022 }
2023 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
2024 Builder.getInt64(I), Builder.getInt64(J), EltType,
2025 Builder);
2026 }
2027 }
2028
2029 // Mark eliminated instructions as fused and remove them.
2030 FusedInsts.insert(Store);
2031 FusedInsts.insert(MatMul);
2032 eraseFromParentAndRemoveFromShapeMap(Store);
2033 eraseFromParentAndRemoveFromShapeMap(MatMul);
2034 if (LoadOp0->use_empty()) {
2035 FusedInsts.insert(LoadOp0);
2036 eraseFromParentAndRemoveFromShapeMap(LoadOp0);
2037 }
2038 if (LoadOp1 != LoadOp0 && LoadOp1->use_empty()) {
2039 FusedInsts.insert(LoadOp1);
2040 eraseFromParentAndRemoveFromShapeMap(LoadOp1);
2041 }
2042 }
2043
2044 /// Try to lower matrix multiply chains by fusing operations.
2045 ///
2046 /// Call finalizeLowering on lowered instructions. Instructions that are
2047 /// completely eliminated by fusion are added to \p FusedInsts.
2048 void
2049 LowerMatrixMultiplyFused(CallInst *MatMul,
2050 SmallPtrSetImpl<Instruction *> &FusedInsts,
2051 SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
2052 if (!FuseMatrix || !DT)
2053 return;
2054
2055 assert(AA && LI && "Analyses should be available");
2056
2057 Value *A = MatMul->getArgOperand(0);
2058 Value *B = MatMul->getArgOperand(1);
2059
2060 // We can fold the transpose into the operand that is used to fetch scalars.
2061 Value *T;
2062 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
2063 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
2064 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
2065 IRBuilder<> Builder(MatMul);
2066 auto *EltType =
2067 cast<FixedVectorType>(MatMul->getType())->getElementType();
2068 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2069 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2070 const unsigned R = LShape.NumRows;
2071 const unsigned M = LShape.NumColumns;
2072 const unsigned C = RShape.NumColumns;
2073
2074 MatrixTy MA;
2075 MatrixTy MB;
2076
2077 Value *Transpose;
2078 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
2079 MA = getMatrix(A, ShapeInfo(R, M), Builder);
2080 MB = getMatrix(T, ShapeInfo(C, M), Builder);
2081 Transpose = B;
2082 } else {
2083 MA = getMatrix(T, ShapeInfo(R, M), Builder);
2084 MB = getMatrix(B, ShapeInfo(C, M), Builder);
2085 Transpose = A;
2086 }
2087
2088 // Initialize the output
2089 MatrixTy Result(R, C, EltType);
2090
2091 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
2092 getFastMathFlags(MatMul));
2093
2094 FusedInsts.insert(MatMul);
2095 if (Transpose->hasOneUse()) {
2096 FusedInsts.insert(cast<Instruction>(Transpose));
2097 ToRemove.push_back(cast<Instruction>(Transpose));
2098 // TODO: add a fake entry for the folded instruction so that this is
2099 // included in the expression in the remark.
2100 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
2101 }
2102 finalizeLowering(MatMul, Result, Builder);
2103 return;
2104 }
2105
2106 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
2107 return;
2108
2109 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
2110 // since the single store user will be lowered as part of this.
2111 auto *LoadOp0 = dyn_cast<LoadInst>(A);
2112 auto *LoadOp1 = dyn_cast<LoadInst>(B);
2113 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
2114 if (LoadOp0 && LoadOp1 && Store) {
2115 // The store address must dominate the MatMul instruction, otherwise
2116 // we create invalid IR.
2117 SetVector<Value *> WorkList;
2118 WorkList.insert(Store->getOperand(1));
2119 SmallVector<Instruction *> ToHoist;
2120 for (unsigned I = 0; I != WorkList.size(); ++I) {
2121 Value *Current = WorkList[I];
2122 auto *CurrI = dyn_cast<Instruction>(Current);
2123 if (!CurrI)
2124 continue;
2125 if (isa<PHINode>(CurrI))
2126 return;
2127 if (DT->dominates(CurrI, MatMul))
2128 continue;
2129 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
2130 return;
2131 ToHoist.push_back(CurrI);
2132 WorkList.insert_range(CurrI->operands());
2133 }
2134
2135 sort(ToHoist, [this](Instruction *A, Instruction *B) {
2136 return DT->dominates(A, B);
2137 });
2138 for (Instruction *I : ToHoist)
2139 I->moveBefore(MatMul->getIterator());
2140
2141 // Deal with lifetime.end calls that might be between Load0/Load1 and the
2142 // store. To avoid introducing loads to dead objects (i.e. after the
2143 // lifetime has been terminated by @llvm.lifetime.end), either sink them
2144 // after the store if in the same block, or remove the lifetime.end marker
2145 // otherwise. This might pessimize further optimizations, by extending the
2146 // lifetime of the object until the function returns, but should be
2147 // conservatively correct.
2148 MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
2149 MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
2150 BasicBlock *StoreParent = Store->getParent();
2151 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
2152 LoadOp1->getParent() == StoreParent;
2153 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
2154 IntrinsicInst *End = LifetimeEnds[Idx];
2155 auto Inc = make_scope_exit([&Idx]() { Idx++; });
2156 // If the lifetime.end is guaranteed to be before the loads or after the
2157 // store, it won't interfere with fusion.
2158 if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
2159 continue;
2160 if (DT->dominates(Store, End))
2161 continue;
2162 // If all fusable ops are in the same block and the lifetime.end is in a
2163 // different block, it won't interfere with fusion.
2164 if (FusableOpsInSameBlock && End->getParent() != StoreParent)
2165 continue;
2166
2167 // If the loads don't alias the lifetime.end, it won't interfere with
2168 // fusion.
2169 MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 0, nullptr);
2170 if (!EndLoc.Ptr)
2171 continue;
2172 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
2173 continue;
2174
2175 // If both lifetime.end and the store are in the same block, extend the
2176 // lifetime until after the store, so the new lifetime covers the loads
2177 // we introduce later.
2178 if (End->getParent() == StoreParent) {
2179 End->moveAfter(Store);
2180 continue;
2181 }
2182
2183 // Otherwise remove the conflicting lifetime.end marker.
2184 ToRemove.push_back(End);
2185 std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
2186 LifetimeEnds.pop_back();
2187 Inc.release();
2188 }
2189
2190 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
2191 return;
2192 }
2193 }
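// In short, the fusion above handles two patterns: (1) multiplies where one
// operand is a transpose, which is folded into the scalar-extraction side,
// and (2) load -> multiply -> store chains, which are tiled via
// emitSIMDTiling so the full operand matrices never have to be materialized
// at once.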
2194
2195 /// Lowers llvm.matrix.multiply.
2196 MatrixTy LowerMultiply(CallInst *MatMul, IRBuilder<> &Builder) {
2197 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
2198 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2199 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2200
2201 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
2202 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
2203 assert(Lhs.getElementType() == Rhs.getElementType() &&
2204 "Matrix multiply argument element types do not match.");
2205
2206 const unsigned R = LShape.NumRows;
2207 const unsigned C = RShape.NumColumns;
2208 assert(LShape.NumColumns == RShape.NumRows);
2209
2210 // Initialize the output
2211 MatrixTy Result(R, C, EltType);
2212 assert(Lhs.getElementType() == Result.getElementType() &&
2213 "Matrix multiply result element type does not match arguments.");
2214
2215 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
2216 getFastMathFlags(MatMul));
2217 return Result;
2218 }
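// Example of the non-fused lowering above (illustrative IR): a 2x2 * 2x2
// single-precision multiply
//   %c = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(
//            <4 x float> %a, <4 x float> %b, i32 2, i32 2, i32 2)
// is computed as two column vectors, each built as
//   a.col0 * splat(b[0][j]) + a.col1 * splat(b[1][j])
// and recombined into a flat <4 x float> only when a non-matrix user needs it.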
2219
2220 /// Lowers llvm.matrix.transpose.
2221 MatrixTy LowerTranspose(CallInst *Inst, IRBuilder<> &Builder) {
2222 MatrixTy Result;
2223 Value *InputVal = Inst->getArgOperand(0);
2224 FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType());
2225 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
2226 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
2227
2228 const unsigned NewNumVecs =
2229 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
2230 const unsigned NewNumElts =
2231 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
2232
2233 for (unsigned I = 0; I < NewNumVecs; ++I) {
2234 // Build a single result vector. First initialize it.
2235 Value *ResultVector = PoisonValue::get(
2236 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
2237 // Go through the old elements and insert them into the resulting vector.
2238 for (auto J : enumerate(InputMatrix.vectors())) {
2239 Value *Elt = Builder.CreateExtractElement(J.value(), I);
2240 // Row and column indices are transposed.
2241 ResultVector =
2242 Builder.CreateInsertElement(ResultVector, Elt, J.index());
2243 }
2244 Result.addVector(ResultVector);
2245 }
2246
2247 // TODO: Improve estimate of operations needed for transposes. Currently we
2248 // just count the insertelement/extractelement instructions, but do not
2249 // account for later simplifications/combines.
2250 return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2251 .addNumExposedTransposes(1);
2252 }
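// Sketch of the transpose lowering above for a column-major 2x3 input: the
// result has 2 columns of 3 elements each, and element (r, c) of the input
// becomes element (c, r) of the output via extractelement/insertelement
// pairs, which later combines can typically turn into shuffles.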
2253
2254 /// Lower load instructions.
2255 MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2256 IRBuilder<> &Builder) {
2257 return LowerLoad(Inst, Ptr, Inst->getAlign(),
2258 Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
2259 Builder);
2260 }
2261
2262 MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2263 Value *Ptr, IRBuilder<> &Builder) {
2264 return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2265 Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
2266 Builder);
2267 }
2268
2269 MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {
2270 auto BlockIP = Inst->getParent()->getFirstInsertionPt();
2271 Builder.SetInsertPoint(BlockIP);
2272 MatrixTy PhiM = getMatrix(Inst, SI, Builder);
2273
2274 for (auto [IncomingV, IncomingB] :
2275 llvm::zip_equal(Inst->incoming_values(), Inst->blocks())) {
2276 // getMatrix() may insert some instructions to help with reshaping. The
2277 // safest place for those is at the top of the block after the rest of the
2278 // PHIs. Even better, put it in the incoming block if we can.
2279 Builder.SetInsertPoint(BlockIP);
2280 if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV))
2281 if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef())
2282 Builder.SetInsertPoint(*MaybeIP);
2283
2284 MatrixTy OpM = getMatrix(IncomingV, SI, Builder);
2285
2286 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {
2287 PHINode *NewPHI = cast<PHINode>(PhiM.getVector(VI));
2288 NewPHI->addIncoming(OpM.getVector(VI), IncomingB);
2289 }
2290 }
2291
2292 // finalizeLowering() may also insert instructions in some cases. The safe
2293 // place for those is at the end of the initial block of PHIs.
2294 Builder.SetInsertPoint(BlockIP);
2295 return PhiM;
2296 }
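// Example: a phi over a 2x2 column-major matrix value is rewritten into two
// <2 x elt> phis, one per column, each merging the corresponding column
// vectors of the incoming values; reshaping code for incoming operands is
// placed after their definitions (or at the block's first insertion point).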
2297
2298 /// Lower binary operators.
2299 MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,
2300 IRBuilder<> &Builder) {
2301 Value *Lhs = Inst->getOperand(0);
2302 Value *Rhs = Inst->getOperand(1);
2303
2304 MatrixTy Result;
2305 MatrixTy A = getMatrix(Lhs, SI, Builder);
2306 MatrixTy B = getMatrix(Rhs, SI, Builder);
2307 assert(A.isColumnMajor() == B.isColumnMajor() &&
2308 Result.isColumnMajor() == A.isColumnMajor() &&
2309 "operands must agree on matrix layout");
2310
2311 Builder.setFastMathFlags(getFastMathFlags(Inst));
2312
2313 for (auto [AV, BV] : llvm::zip_equal(A.vectors(), B.vectors()))
2314 Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), AV, BV));
2315
2316 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2317 Result.getNumVectors());
2318 }
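// Example: an fadd of two 2x2 float matrices in column-major layout lowers to
// two <2 x float> fadd instructions, one per column vector, with the original
// instruction's fast-math flags propagated to each.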
2319
2320 /// Lower unary operators.
2321 MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI,
2322 IRBuilder<> &Builder) {
2323 Value *Op = Inst->getOperand(0);
2324
2325 MatrixTy Result;
2326 MatrixTy M = getMatrix(Op, SI, Builder);
2327
2328 Builder.setFastMathFlags(getFastMathFlags(Inst));
2329
2330 // Helper to perform unary op on vectors.
2331 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2332 switch (Inst->getOpcode()) {
2333 case Instruction::FNeg:
2334 return Builder.CreateFNeg(Op);
2335 default:
2336 llvm_unreachable("Unsupported unary operator for matrix");
2337 }
2338 };
2339
2340 for (auto *Vector : M.vectors())
2341 Result.addVector(BuildVectorOp(Vector));
2342
2343 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2344 Result.getNumVectors());
2345 }
2346
2347 /// Lower cast instructions.
2348 MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape,
2349 IRBuilder<> &Builder) {
2350 Value *Op = Inst->getOperand(0);
2351
2352 MatrixTy Result;
2353 MatrixTy M = getMatrix(Op, Shape, Builder);
2354
2355 Builder.setFastMathFlags(getFastMathFlags(Inst));
2356
2357 auto *OrigVTy = cast<VectorType>(Inst->getType());
2358 auto *NewVTy = VectorType::get(OrigVTy->getElementType(),
2359 ElementCount::getFixed(M.getStride()));
2360
2361 for (auto *Vector : M.vectors())
2362 Result.addVector(Builder.CreateCast(Inst->getOpcode(), Vector, NewVTy));
2363
2364 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2365 Result.getNumVectors());
2366 }
2367
2368 /// Lower selects.
2369 MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape,
2370 IRBuilder<> &Builder) {
2371 Value *Cond = Inst->getOperand(0);
2372 Value *OpA = Inst->getOperand(1);
2373 Value *OpB = Inst->getOperand(2);
2374
2375 MatrixTy Result;
2376 MatrixTy A = getMatrix(OpA, Shape, Builder);
2377 MatrixTy B = getMatrix(OpB, Shape, Builder);
2378
2379 SmallVector<Value*> CondV;
2380 if (isa<FixedVectorType>(Cond->getType())) {
2381 MatrixTy C = getMatrix(Cond, Shape, Builder);
2382 llvm::copy(C.vectors(), std::back_inserter(CondV));
2383 } else {
2384 CondV.resize(A.getNumVectors());
2385 llvm::fill(CondV, Cond);
2386 }
2387
2388 for (auto [CV, AV, BV] : llvm::zip_equal(CondV, A.vectors(), B.vectors()))
2389 Result.addVector(Builder.CreateSelect(CV, AV, BV));
2390
2391 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2392 Result.getNumVectors());
2393 }
2394
2395 /// Helper to linearize a matrix expression tree into a string. Currently
2396 /// matrix expressions are linearized by starting at an expression leaf and
2397 /// linearizing bottom up.
2398 struct ExprLinearizer {
2399 unsigned LengthToBreak = 100;
2400 std::string Str;
2401 raw_string_ostream Stream;
2402 unsigned LineLength = 0;
2403 const DataLayout &DL;
2404
2405 /// Mapping from instructions to matrixes. It is used to identify
2406 /// matrix instructions.
2407 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2408
2409 /// Mapping from values to the leaves of all expressions that the value is
2410 /// part of.
2411 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2412
2413 /// Set of matrix expressions in the scope of a given DISubprogram.
2414 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2415
2416 /// Leaf node of the expression to linearize.
2417 Value *Leaf;
2418
2419 /// Used to keep track of sub-expressions that get reused while linearizing
2420 /// the expression. Re-used sub-expressions are marked as (reused).
2421 SmallPtrSet<Value *, 8> ReusedExprs;
2422
2423 ExprLinearizer(const DataLayout &DL,
2424 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2425 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2426 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2427 Value *Leaf)
2428 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2429 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2430
2431 void indent(unsigned N) {
2432 LineLength += N;
2433 for (unsigned i = 0; i < N; i++)
2434 Stream << " ";
2435 }
2436
2437 void lineBreak() {
2438 Stream << "\n";
2439 LineLength = 0;
2440 }
2441
2442 void maybeIndent(unsigned Indent) {
2443 if (LineLength >= LengthToBreak)
2444 lineBreak();
2445
2446 if (LineLength == 0)
2447 indent(Indent);
2448 }
2449
2450 void write(StringRef S) {
2451 LineLength += S.size();
2452 Stream << S;
2453 }
2454
2455 Value *getUnderlyingObjectThroughLoads(Value *V) {
2456 if (Value *Ptr = getPointerOperand(V))
2457 return getUnderlyingObjectThroughLoads(Ptr);
2458 else if (V->getType()->isPointerTy())
2459 return getUnderlyingObject(V);
2460 return V;
2461 }
2462
2463 /// Returns true if \p V is a matrix value in the given subprogram.
2464 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2465
2466 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2467 /// \p SS.
2468 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2469 auto M = Inst2Matrix.find(V);
2470 if (M == Inst2Matrix.end())
2471 SS << "unknown";
2472 else {
2473 SS << M->second.getNumRows();
2474 SS << "x";
2475 SS << M->second.getNumColumns();
2476 }
2477 }
2478
2479 /// Write the called function name. Handles calls to llvm.matrix.*
2480 /// specially: we write the name, followed by the dimensions of the input
2481 /// matrixes, followed by the scalar type name.
2482 void writeFnName(CallInst *CI) {
2483 if (!CI->getCalledFunction())
2484 write("<no called fn>");
2485 else {
2486 StringRef Name = CI->getCalledFunction()->getName();
2487 if (!Name.starts_with("llvm.matrix")) {
2488 write(Name);
2489 return;
2490 }
2491 auto *II = cast<IntrinsicInst>(CI);
2492 write(Intrinsic::getBaseName(II->getIntrinsicID())
2493 .drop_front(StringRef("llvm.matrix.").size()));
2494 write(".");
2495 std::string Tmp;
2496 raw_string_ostream SS(Tmp);
2497
2498 switch (II->getIntrinsicID()) {
2499 case Intrinsic::matrix_multiply:
2500 prettyPrintMatrixType(II->getOperand(0), SS);
2501 SS << ".";
2502 prettyPrintMatrixType(II->getOperand(1), SS);
2503 SS << "." << *II->getType()->getScalarType();
2504 break;
2505 case Intrinsic::matrix_transpose:
2506 prettyPrintMatrixType(II->getOperand(0), SS);
2507 SS << "." << *II->getType()->getScalarType();
2508 break;
2509 case Intrinsic::matrix_column_major_load:
2510 prettyPrintMatrixType(II, SS);
2511 SS << "." << *II->getType()->getScalarType();
2512 break;
2513 case Intrinsic::matrix_column_major_store:
2514 prettyPrintMatrixType(II->getOperand(0), SS);
2515 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2516 break;
2517 default:
2518 llvm_unreachable("Unhandled case");
2519 }
2520 write(Tmp);
2521 }
2522 }
2523
2524 unsigned getNumShapeArgs(CallInst *CI) const {
2525 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2526 switch (II->getIntrinsicID()) {
2527 case Intrinsic::matrix_multiply:
2528 return 3;
2529 case Intrinsic::matrix_transpose:
2530 return 2;
2531 case Intrinsic::matrix_column_major_load:
2532 case Intrinsic::matrix_column_major_store:
2533 return 3;
2534 default:
2535 return 0;
2536 }
2537 }
2538 return 0;
2539 }
2540
2541 /// Special printing for values: for pointers, we print whether they refer to
2542 /// an external (function) address or a stack address; for other values we
2543 /// either print the constant or "scalar"/"matrix".
2544 void write(Value *V) {
2545 V = getUnderlyingObjectThroughLoads(V);
2546 if (V->getType()->isPointerTy()) {
2547 if (isa<AllocaInst>(V)) {
2548 Stream << "stack addr";
2549 LineLength += StringRef("stack addr").size();
2550 } else {
2551 Stream << "addr";
2552 LineLength += StringRef("addr").size();
2553 }
2554 if (!V->getName().empty()) {
2555 Stream << " %" << V->getName() << "";
2556 LineLength += V->getName().size() + 2;
2557 }
2558 return;
2559 }
2560
2561 std::string Tmp;
2562 raw_string_ostream TmpStream(Tmp);
2563
2564 if (auto *CI = dyn_cast<ConstantInt>(V))
2565 TmpStream << CI->getValue();
2566 else if (isa<Constant>(V))
2567 TmpStream << "constant";
2568 else {
2569 if (isMatrix(V))
2570 TmpStream << "matrix";
2571 else
2572 TmpStream << "scalar";
2573 }
2574 Tmp = std::string(StringRef(Tmp).trim());
2575 LineLength += Tmp.size();
2576 Stream << Tmp;
2577 }
2578
2579 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2580 /// Expressions that are re-used multiple times are prefixed with (reused)
2581 /// at the re-used root instruction.
2582 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2583 bool ParentShared) {
2584 auto *I = cast<Instruction>(Expr);
2585 maybeIndent(Indent);
2586 SmallVector<Value *, 8> Ops;
2587
2588 // Is Expr shared with other expression leaves?
2589 bool ExprShared = false;
2590
2591 // Deal with shared subtrees. Mark them as shared, if required.
2592 if (!ParentShared) {
2593 auto SI = Shared.find(Expr);
2594 assert(SI != Shared.end() && SI->second.count(Leaf));
2595
2596 for (Value *S : SI->second) {
2597 if (S == Leaf)
2598 continue;
2599 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2600 write("shared with remark at line " + std::to_string(DL.getLine()) +
2601 " column " + std::to_string(DL.getCol()) + " (");
2602 }
2603 ExprShared = SI->second.size() > 1;
2604 }
2605
2606 bool Reused = !ReusedExprs.insert(Expr).second;
2607 if (Reused && !ParentReused)
2608 write("(reused) ");
2609
2610 if (auto *CI = dyn_cast<CallInst>(I)) {
2611 writeFnName(CI);
2612
2613 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2614 } else if (isa<BitCastInst>(Expr)) {
2615 // Special case bitcasts, which are used to materialize matrixes from
2616 // non-matrix ops.
2617 write("matrix");
2618 return;
2619 } else {
2620 Ops.append(I->value_op_begin(), I->value_op_end());
2621 write(I->getOpcodeName());
2622 }
2623
2624 write("(");
2625
2626 unsigned NumOpsToBreak = 1;
2627 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2628 NumOpsToBreak = 2;
2629
2630 for (Value *Op : Ops) {
2631 if (Ops.size() > NumOpsToBreak)
2632 lineBreak();
2633
2634 maybeIndent(Indent + 1);
2635 if (isMatrix(Op))
2636 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2637 else
2638 write(Op);
2639 if (Op != Ops.back())
2640 write(", ");
2641 }
2642
2643 write(")");
2644 }
2645
2646 const std::string &getResult() {
2647 return Str;
2648 }
2649 };
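// A linearized remark expression produced by the helper above might look
// roughly like the following (illustrative):
//   store(
//    multiply.2x6.6x2.double(
//     load(addr %A),
//     load(addr %B)),
//    addr %C)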
2650
2651 /// Generate remarks for matrix operations in a function. To generate remarks
2652 /// for matrix expressions, the following approach is used:
2653 /// 1. Use the inlined-at debug information to group matrix operations to the
2654 /// DISubprograms they are contained in.
2655 /// 2. Collect leaves of matrix expressions (done in
2656 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2657 /// mapping. Leaves are lowered matrix instructions without other matrix
2658 /// users (like stores) in the current subprogram.
2659 /// 3. For each leaf, create a remark containing a linearizied version of the
2660 /// matrix expression. The expression is linearized by a recursive
2661 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2662 /// that multiple leaves can share sub-expressions. Shared subexpressions
2663 /// are explicitly marked as shared().
2664 struct RemarkGenerator {
2665 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2666 OptimizationRemarkEmitter &ORE;
2667 Function &Func;
2668 const DataLayout &DL;
2669
2670 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2671 OptimizationRemarkEmitter &ORE, Function &Func)
2672 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2673 DL(Func.getDataLayout()) {}
2674
2675 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2676 /// instructions in Inst2Matrix returning void or without any users in
2677 /// \p ExprsInSubprogram. Currently that should only include stores.
2678 SmallVector<Value *, 4>
2679 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2680 SmallVector<Value *, 4> Leaves;
2681 for (auto *Expr : ExprsInSubprogram)
2682 if (Expr->getType()->isVoidTy() ||
2683 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2684 return ExprsInSubprogram.count(U);
2685 }))
2686 Leaves.push_back(Expr);
2687 return Leaves;
2688 }
2689
2690 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2691 /// to all visited expressions in \p Shared. Limit the matrix operations to
2692 /// the ones in \p ExprsInSubprogram.
2693 void collectSharedInfo(Value *Leaf, Value *V,
2694 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2695 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2696
2697 if (!ExprsInSubprogram.count(V))
2698 return;
2699
2700 Shared[V].insert(Leaf);
2701
2702 for (Value *Op : cast<Instruction>(V)->operand_values())
2703 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2704 }
2705
2706 /// Calculate the exclusive and shared op counts for the expression
2707 /// starting at \p Root. Expressions used multiple times are counted once.
2708 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2709 std::pair<OpInfoTy, OpInfoTy>
2710 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2711 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2712 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2713 if (!ExprsInSubprogram.count(Root))
2714 return {};
2715
2716 // Already counted this expression. Stop.
2717 if (!ReusedExprs.insert(Root).second)
2718 return {};
2719
2720 OpInfoTy SharedCount;
2721 OpInfoTy Count;
2722
2723 auto I = Shared.find(Root);
2724 auto CM = Inst2Matrix.find(Root);
2725 if (I->second.size() == 1)
2726 Count = CM->second.getOpInfo();
2727 else
2728 SharedCount = CM->second.getOpInfo();
2729
2730 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2731 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2732 Count += C.first;
2733 SharedCount += C.second;
2734 }
2735 return {Count, SharedCount};
2736 }
2737
2738 void emitRemarks() {
2739 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2740 return;
2741
2742 // Map matrix operations to their containing subprograms, by traversing
2743 // the inlinedAt chain. If the function does not have a DISubprogram, we
2744 // only map them to the containing function.
2745 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2746 for (const auto &KV : Inst2Matrix) {
2747 if (Func.getSubprogram()) {
2748 auto *I = cast<Instruction>(KV.first);
2749 DILocation *Context = I->getDebugLoc();
2750 while (Context) {
2751 Subprog2Exprs[getSubprogram(Context->getScope())].push_back(
2752 KV.first);
2753 Context = DebugLoc(Context).getInlinedAt();
2754 }
2755 } else {
2756 Subprog2Exprs[nullptr].push_back(KV.first);
2757 }
2758 }
2759 for (auto &KV : Subprog2Exprs) {
2760 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2761 KV.second.end());
2762 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2763
2764 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2765 for (Value *Leaf : Leaves)
2766 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2767
2768 // Generate remarks for each leaf.
2769 for (auto *L : Leaves) {
2770
2771 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2772 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2773 while (Context) {
2774 if (getSubprogram(Context->getScope()) == KV.first) {
2775 Loc = Context;
2776 break;
2777 }
2778 Context = DebugLoc(Context).getInlinedAt();
2779 }
2780
2781 SmallPtrSet<Value *, 8> ReusedExprs;
2782 OpInfoTy Counts, SharedCounts;
2783 std::tie(Counts, SharedCounts) =
2784 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2785
2786 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2788
2789 Rem << "Lowered with ";
2790 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2791 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2792 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2793 << " compute ops, "
2794 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2795 << " exposed transposes";
2796
2797 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2798 SharedCounts.NumComputeOps > 0) {
2799 Rem << ",\nadditionally "
2800 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2801 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2802 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2803 << " compute ops"
2804 << " are shared with other expressions";
2805 }
2806
2807 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2808 ORE.emit(Rem);
2809 }
2810 }
2811 }
2812
2813 std::string
2814 linearize(Value *L,
2815 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2816 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2817 const DataLayout &DL) {
2818 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2819 Lin.linearizeExpr(L, 0, false, false);
2820 return Lin.getResult();
2821 }
2822 };
2823};
2824} // namespace
2825
2826 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2827 FunctionAnalysisManager &AM) {
2828 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2829
2830 LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);
2831 if (LMT.Visit()) {
2832 PreservedAnalyses PA;
2833 if (!Minimal) {
2834 PA.preserve<LoopAnalysis>();
2835 PA.preserve<DominatorTreeAnalysis>();
2836 }
2837 return PA;
2838 }
2839 return PreservedAnalyses::all();
2840}
2841
2842 void LowerMatrixIntrinsicsPass::printPipeline(
2843 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2844 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2845 OS, MapClassName2PassName);
2846 OS << '<';
2847 if (Minimal)
2848 OS << "minimal";
2849 OS << '>';
2850}
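// Usage note (illustrative): the pass can be run directly via opt, e.g.
//   opt -passes=lower-matrix-intrinsics example.ll -S
// and the minimal variant prints as, and parses from,
//   -passes='lower-matrix-intrinsics<minimal>'
// which skips the fusion and tiling paths that require the extra analyses.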
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
AMDGPU Register Bank Select
Rewrite undef for PHI
static const Function * getParent(const Value *V)
BitTracker BT
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:638
Hexagon Common GEP
hexagon Hexagon specific predictive commoning for HVX vectors
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
iv users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
Live Register Matrix
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
auto m_AnyAdd(const LTy &L, const RTy &R)
Match any add operation (fp or integer).
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
#define DEBUG_TYPE
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define T
#define T1
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
const SmallVectorImpl< MachineOperand > & Cond
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition SROA.cpp:2600
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition SROA.cpp:2622
This file contains some templates that are useful if you are working with the STL at all.
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
static const int BlockSize
Definition TarWriter.cpp:33
This pass exposes codegen information to IR-level passes.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Value * RHS
Value * LHS
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
reverse_iterator rbegin()
Definition BasicBlock.h:475
InstListType::reverse_iterator reverse_iterator
Definition BasicBlock.h:172
reverse_iterator rend()
Definition BasicBlock.h:477
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Value * getArgOperand(unsigned i) const
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Instruction::CastOps getOpcode() const
Return the opcode of this CastInst.
Definition InstrTypes.h:612
static LLVM_ABI ConstantAggregateZero * get(Type *Ty)
LLVM_ABI DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Base class for scope-like contexts.
Subprogram description. Uses SubclassData1.
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:165
iterator end()
Definition DenseMap.h:81
Analysis pass which computes a DominatorTree.
Definition Dominators.h:284
static constexpr ElementCount getFixed(ScalarTy MinVal)
Definition TypeSize.h:309
void setAllowContract(bool B=true)
Definition FMF.h:90
bool allowReassoc() const
Flag queries.
Definition FMF.h:64
bool allowContract() const
Definition FMF.h:69
unsigned getNumElements() const
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition Type.cpp:803
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition Function.h:244
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition Function.h:249
LLVM_ABI CallInst * CreateFAddReduce(Value *Acc, Value *Src)
Create a sequential vector fadd reduction intrinsic of the source vector.
Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:2345
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2571
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition IRBuilder.h:1830
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2559
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition IRBuilder.h:1864
CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, const AAMDNodes &AAInfo=AAMDNodes())
Create and insert a memcpy between the specified pointers.
Definition IRBuilder.h:687
Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition IRBuilder.h:1613
LLVM_ABI Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)
Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...
Definition IRBuilder.h:611
Value * CreateCast(Instruction::CastOps Op, Value *V, Type *DestTy, const Twine &Name="", MDNode *FPMathTag=nullptr, FMFSource FMFSource={})
Definition IRBuilder.h:2238
void setFastMathFlags(FastMathFlags NewFMF)
Set the fast-math flags to be used with generated fp-math operators.
Definition IRBuilder.h:345
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Definition IRBuilder.h:1923
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition IRBuilder.h:527
LLVM_ABI Value * CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with 2 operands which is mangled on the first type.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition IRBuilder.h:2494
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
Definition IRBuilder.h:533
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition IRBuilder.h:1197
LLVM_ABI CallInst * CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with 1 operand which is mangled on its type.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition IRBuilder.h:1847
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition IRBuilder.h:2593
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1403
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
Definition IRBuilder.h:2194
Value * CreateBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1708
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
Definition IRBuilder.h:1883
Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition IRBuilder.h:1651
Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1790
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1437
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2780
LLVM_ABI void moveAfter(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...
LLVM_ABI void setFastMathFlags(FastMathFlags FMF)
Convenience function for setting multiple fast-math flags on this instruction, which must be an opera...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports these flags.
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Align getAlign() const
Return the alignment of the access that is being performed.
TypeSize getValue() const
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrices LHS and RHS.
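These two MatrixBuilder helpers are the builder-side counterpart of the intrinsics this pass lowers. A minimal sketch of emitting (A * B) followed by a transpose; the 4x3 and 3x5 shapes and the value names are assumptions chosen for illustration.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MatrixBuilder.h"

// Sketch only: emit llvm.matrix.multiply followed by llvm.matrix.transpose.
// A is a flattened 4x3 matrix, B a flattened 3x5 matrix (hypothetical values).
static llvm::Value *emitMulThenTranspose(llvm::IRBuilder<> &Builder,
                                         llvm::Value *A, llvm::Value *B) {
  llvm::MatrixBuilder MB(Builder);
  // 4x3 * 3x5 -> 4x5, passed around as a flat <20 x ty> vector.
  llvm::CallInst *Prod =
      MB.CreateMatrixMultiply(A, B, /*LHSRows=*/4, /*LHSColumns=*/3,
                              /*RHSColumns=*/5, "prod");
  // Transpose the 4x5 result into a 5x4 matrix.
  return MB.CreateMatrixTranspose(Prod, /*Rows=*/4, /*Columns=*/5, "prod.t");
}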
static LLVM_ABI MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory referenced by the given instruction.
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
const Value * Ptr
The address of the start of the location.
static LLVM_ABI MemoryLocation getForArgument(const CallBase *Call, unsigned ArgIdx, const TargetLibraryInfo *TLI)
Return a location representing a particular argument of a call.
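Fusing a multiply with its feeding loads and consuming store is only safe when those accesses cannot alias, and MemoryLocation plus AAResults is how such a query is phrased. A hedged sketch of the pattern (the pass's actual legality checks are more involved); the function and parameter names are illustrative.
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/Instructions.h"

// Sketch only: true if neither input load can alias the result store.
static bool accessesAreIndependent(llvm::AAResults &AA, llvm::LoadInst *LoadA,
                                   llvm::LoadInst *LoadB, llvm::StoreInst *Store) {
  llvm::MemoryLocation LocA = llvm::MemoryLocation::get(LoadA);
  llvm::MemoryLocation LocB = llvm::MemoryLocation::get(LoadB);
  llvm::MemoryLocation LocS = llvm::MemoryLocation::get(Store);
  return AA.isNoAlias(LocA, LocS) && AA.isNoAlias(LocB, LocS);
}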
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
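This is the usual new-pass-manager contract: a pass's run() reports which analyses survive its transformation. A generic sketch of the pattern, not the exact set this pass preserves; LoopAnalysis is used only as an example.
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/PassManager.h"

// Sketch only: typical run() epilogue for a function pass.
static llvm::PreservedAnalyses reportPreserved(bool Changed) {
  if (!Changed)
    return llvm::PreservedAnalyses::all();   // nothing touched, keep everything
  llvm::PreservedAnalyses PA;
  PA.preserve<llvm::LoopAnalysis>();         // example: loop structure kept intact
  return PA;
}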
size_type size() const
Determine the number of elements in the SetVector.
Definition SetVector.h:104
void insert_range(Range &&R)
Definition SetVector.h:193
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition SetVector.h:279
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition SetVector.h:168
bool erase(PtrType Ptr)
Remove pointer from the set.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
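The insert/count pair is what makes SetVector a convenient deduplicating worklist: set-like membership with deterministic iteration order. A small illustrative sketch of that idiom, with hypothetical names; growing the container while indexing into it ensures each instruction is visited at most once.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/Casting.h"

// Sketch only: visit Root and, transitively, the instructions it depends on.
static void visitTransitively(llvm::Instruction *Root,
                              llvm::function_ref<void(llvm::Instruction *)> Visit) {
  llvm::SmallSetVector<llvm::Instruction *, 16> Worklist;
  Worklist.insert(Root);
  // insert() refuses duplicates, so nothing is queued (or visited) twice.
  for (unsigned I = 0; I != Worklist.size(); ++I) {
    llvm::Instruction *Cur = Worklist[I];
    Visit(Cur);
    for (llvm::Value *Op : Cur->operands())
      if (auto *OpI = llvm::dyn_cast<llvm::Instruction>(Op))
        Worklist.insert(OpI);
  }
}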
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void resize(size_type N)
void push_back(const T &Elt)
Align getAlign() const
bool isVolatile() const
Return true if this is a store to a volatile memory location.
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition StringRef.h:55
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition StringRef.h:619
constexpr size_t size() const
size - Get the string size.
Definition StringRef.h:154
Analysis pass providing the TargetTransformInfo.
@ TCK_RecipThroughput
Reciprocal throughput.
@ SK_Splice
Concatenates elements from the first input vector with elements of the second input vector.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition Type.h:352
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:198
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition Type.cpp:231
bool isVoidTy() const
Return true if this is 'void'.
Definition Type.h:139
UnaryOps getOpcode() const
Definition InstrTypes.h:154
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
user_iterator user_begin()
Definition Value.h:402
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:546
iterator_range< user_iterator > users()
Definition Value.h:426
bool use_empty() const
Definition Value.h:346
iterator_range< use_iterator > uses()
Definition Value.h:380
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
Type * getElementType() const
constexpr ScalarTy getFixedValue() const
Definition TypeSize.h:200
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:130
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm.ssa.copy".
OneUse_match< SubPat > m_OneUse(const SubPat &SP)
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
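These matchers compose into declarative checks over the IR: a whole use-def shape can be recognised with one match() call. A hedged sketch built only from matchers shown in this list; the helper name and the out-parameters are illustrative.
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// Sketch only: is V a store whose stored value is a single-use load?
// On success, SrcPtr and DstPtr are bound to the load and store pointers.
static bool matchForwardedLoad(Value *V, Value *&SrcPtr, Value *&DstPtr) {
  return match(V, m_Store(m_OneUse(m_Load(m_Value(SrcPtr))), m_Value(DstPtr)));
}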
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to the ValuesClass constructor.
initializer< Ty > init(const Ty &Val)
ElementType
The element type of an SRV or UAV resource.
Definition DXILABI.h:60
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:318
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition DWP.cpp:477
void fill(R &&Range, T &&Value)
Provide wrappers to std::fill which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1725
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1657
detail::zippy< detail::zip_first, T, U, Args... > zip_equal(T &&t, U &&u, Args &&...args)
zip iterator that assumes that all iteratees have the same length.
Definition STLExtras.h:841
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition ScopeExit.h:59
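make_scope_exit pairs a state change with guaranteed cleanup on every exit path, for example toggling an IRBuilder's fast-math flags only for the duration of one helper. A hedged sketch under that assumption; the helper name is hypothetical.
#include "llvm/ADT/ScopeExit.h"
#include "llvm/IR/IRBuilder.h"

// Sketch only: temporarily enable contraction, restoring the old flags on return.
static llvm::Value *withContract(llvm::IRBuilder<> &Builder, llvm::Value *A,
                                 llvm::Value *B) {
  llvm::FastMathFlags Saved = Builder.getFastMathFlags();
  auto Restore = llvm::make_scope_exit([&] { Builder.setFastMathFlags(Saved); });
  llvm::FastMathFlags FMF = Saved;
  FMF.setAllowContract(true);
  Builder.setFastMathFlags(FMF);
  return Builder.CreateFMul(A, B, "prod");  // emitted while contraction is enabled
}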
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B, C, ...), such that A is the 0-based index of the item in the sequence, and B, C, ..., are the values from the original input ranges.
Definition STLExtras.h:2452
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
auto successors(const MachineBasicBlock *BB)
bool operator!=(uint64_t V1, const APInt &V2)
Definition APInt.h:2113
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & operator+=(DynamicAPInt &A, int64_t B)
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
Definition STLExtras.h:634
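The early-increment adaptor is what allows erasing instructions while walking a block, since the iterator is advanced before the current element can be invalidated. A minimal sketch of that pattern; the dead-instruction test is illustrative.
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"

// Sketch only: delete every trivially dead instruction in BB.
static bool eraseDeadInstructions(llvm::BasicBlock &BB) {
  bool Changed = false;
  for (llvm::Instruction &I : llvm::make_early_inc_range(BB)) {
    if (I.use_empty() && !I.isTerminator() && !I.mayHaveSideEffects()) {
      I.eraseFromParent();   // safe: the range already moved past I
      Changed = true;
    }
  }
  return Changed;
}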
LLVM_ABI Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
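concatenateVectors is the inverse direction of the column shuffles shown earlier: it glues per-column vectors back into one flattened matrix value. A hedged sketch, assuming the columns are same-typed vectors; the helper name is hypothetical.
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"

// Sketch only: rebuild a flattened (column-major) matrix from its column vectors.
static llvm::Value *flattenColumns(llvm::IRBuilderBase &Builder,
                                   llvm::ArrayRef<llvm::Value *> Columns) {
  if (Columns.size() == 1)
    return Columns.front();                          // already flat
  return llvm::concatenateVectors(Builder, Columns); // emits concatenating shuffles
}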
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
LLVM_ABI void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1712
auto reverse(ContainerTy &&C)
Definition STLExtras.h:408
LLVM_ABI Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, OnCuIndexOverflow OverflowOptValue)
Definition DWP.cpp:622
void sort(IteratorTy Start, IteratorTy End)
Definition STLExtras.h:1624
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:167
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference sizeof(SmallVector<T, 0>).
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
Definition Casting.h:548
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
TargetTransformInfo TTI
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
@ Mul
Product of integers.
@ Add
Sum of integers.
DWARFExpression::Operation Op
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
OutputIt copy(R &&Range, OutputIt Out)
Definition STLExtras.h:1815
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:565
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition Alignment.h:212
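When a large matrix access is split into per-column or per-tile pieces, the base alignment only holds for the first piece; commonAlignment computes what can still be guaranteed at a byte offset from it. A hedged sketch with illustrative names and values.
#include "llvm/Support/Alignment.h"

// Sketch only: alignment valid for an access `ByteOffset` bytes past a base
// pointer known to be `Base`-aligned.
static llvm::Align alignForOffset(llvm::Align Base, uint64_t ByteOffset) {
  // E.g. Base = 16, ByteOffset = 8 -> Align(8); ByteOffset = 32 -> Align(16).
  return llvm::commonAlignment(Base, ByteOffset);
}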
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal.address from the specified value, returning the original object being addressed.
AAResults AliasAnalysis
Temporary typedef for legacy code that uses a generic AliasAnalysis pointer or reference.
LLVM_ABI llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:853
#define N
A CRTP mix-in to automatically provide informational APIs needed for passes.
Definition PassManager.h:70