1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different number of rows/columns
16// columns for tiles, consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
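//
// Illustrative sketch of what this pass does (shapes chosen arbitrarily for
// the example): a call such as
//   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
//            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
// is rewritten into plain vector IR: shufflevectors that split %a and %b into
// 2-element column vectors, fmul/fadd (or llvm.fmuladd) operations that
// combine them, and a final concatenation of the result columns.
//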
19
22#include "llvm/ADT/SmallSet.h"
31#include "llvm/IR/CFG.h"
32#include "llvm/IR/DataLayout.h"
34#include "llvm/IR/Function.h"
35#include "llvm/IR/IRBuilder.h"
42#include "llvm/Support/Debug.h"
46
47#include <cmath>
48
49using namespace llvm;
50using namespace PatternMatch;
51
52#define DEBUG_TYPE "lower-matrix-intrinsics"
53
54static cl::opt<bool>
55 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
56 cl::desc("Enable/disable fusing matrix instructions."));
57// TODO: Allow and use non-square tiles.
58static cl::opt<unsigned> TileSize(
59 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
60 cl::desc(
61 "Tile size for matrix instruction fusion using square-shaped tiles."));
62static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
63 cl::Hidden,
64 cl::desc("Generate loop nest for tiling."));
65static cl::opt<bool> ForceFusion(
66 "force-fuse-matrix", cl::init(false), cl::Hidden,
67 cl::desc("Force matrix instruction fusion even if not profitable."));
68static cl::opt<bool> AllowContractEnabled(
69 "matrix-allow-contract", cl::init(false), cl::Hidden,
70 cl::desc("Allow the use of FMAs if available and profitable. This may "
71 "result in different results, due to less rounding error."));
72
73static cl::opt<bool>
74 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
75 cl::desc("Enable/disable matrix shape verification."),
76 cl::init(false));
77
78enum class MatrixLayoutTy { ColumnMajor, RowMajor };
79
80static cl::opt<MatrixLayoutTy> MatrixLayout(
81 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
82 cl::desc("Sets the default matrix layout"),
83 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
84 "Use column-major layout"),
85 clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
86 "Use row-major layout")));
87
88static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
89 cl::init(false));
90
91/// Helper function to either return \p Scope, if it is a subprogram, or the
92/// subprogram attached to a local scope.
93static DISubprogram *getSubprogram(DIScope *Scope) {
94 if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
95 return Subprogram;
96 return cast<DILocalScope>(Scope)->getSubprogram();
97}
98
99/// Erase \p V from \p BB and move \p II forward to avoid invalidating
100/// iterators.
101static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
102 BasicBlock &BB) {
103 auto *Inst = cast<Instruction>(V);
104 // Still used, don't erase.
105 if (!Inst->use_empty())
106 return;
107 if (II != BB.rend() && Inst == &*II)
108 ++II;
109 Inst->eraseFromParent();
110}
111
112/// Return true if V is a splat of a value (which is used when multiplying a
113/// matrix with a scalar).
114static bool isSplat(Value *V) {
115 if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
116 return SV->isZeroEltSplat();
117 return false;
118}
119
120/// Match any mul operation (fp or integer).
121template <typename LTy, typename RTy>
122auto m_AnyMul(const LTy &L, const RTy &R) {
123 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
124}
125
126/// Match any add operation (fp or integer).
127template <typename LTy, typename RTy>
128auto m_AnyAdd(const LTy &L, const RTy &R) {
129 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
130}
131
132namespace {
133
134// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
135// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
136// assuming \p Stride elements between the starts of two consecutive vectors.
137// \p Stride must be >= \p NumElements.
138// For column-major matrixes, the function computes the address of a column
139// vector and \p NumElements must be set to the number of elements in a column
140// (= number of rows of the matrix). For row-major matrixes, the function
141// computes the address of a row vector and \p NumElements must be set to the
142// number of elements in a row (= number of columns of the matrix).
143//
144// Consider a 4x4 matrix in column-major layout like the one below
145//
146// 0 1 2 3
147// 0 v_0_0 v_0_1 v_0_2 v_0_3
148// 1 v_1_0 v_1_1 v_1_2 v_1_3
149// 2 v_2_0 v_2_1 v_2_2 v_2_3
150// 3 v_3_0 v_3_1 v_3_2 v_3_3
151
152// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
153// we need a pointer to the first element of the submatrix as base pointer.
154// Then we can use computeVectorAddr to compute the addresses for the columns
155// of the sub-matrix.
156//
157// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
158// -> just returns Base
159// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
160// -> returns Base + (1 * 4)
161// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
162// -> returns Base + (2 * 4)
163//
164// The graphic below illustrates the number of elements in a column (marked
165// with |) and the number of skipped elements (marked with {).
166//
167// v_0_0 v_0_1 {v_0_2 {v_0_3
168// Base Col 1 Col 2
169// | | |
170// v_1_0 |v_1_1 |v_1_2 |v_1_3
171// v_2_0 |v_2_1 |v_2_2 |v_2_3
172// v_3_0 {v_3_1 {v_3_2 v_3_3
173//
174Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
175 unsigned NumElements, Type *EltType,
176 IRBuilder<> &Builder) {
177
178 assert((!isa<ConstantInt>(Stride) ||
179 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
180 "Stride must be >= the number of elements in the result vector.");
181
182 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
183 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
184
185 // Get pointer to the start of the selected vector. Skip GEP creation,
186 // if we select vector 0.
187 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
188 VecStart = BasePtr;
189 else
190 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
191
192 return VecStart;
193}
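// A rough sketch of the IR emitted for the "Column 1" case above with a
// non-constant stride (value names invented for illustration):
//   %vec.start = mul i64 1, %stride
//   %vec.gep   = getelementptr double, ptr %base, i64 %vec.start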
194
195namespace {
196struct ShapeInfo {
197 unsigned NumRows;
198 unsigned NumColumns;
199
200 bool IsColumnMajor;
201
202 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
203 : NumRows(NumRows), NumColumns(NumColumns),
204 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
205
206 ShapeInfo(Value *NumRows, Value *NumColumns)
207 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
208 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
209
210 bool operator==(const ShapeInfo &other) {
211 return NumRows == other.NumRows && NumColumns == other.NumColumns;
212 }
213 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
214
215 /// Returns true if shape-information is defined, meaning both dimensions
216 /// are != 0.
217 operator bool() const {
218 assert(NumRows == 0 || NumColumns != 0);
219 return NumRows != 0;
220 }
221
222 unsigned getStride() const {
223 if (IsColumnMajor)
224 return NumRows;
225 return NumColumns;
226 }
227
228 unsigned getNumVectors() const {
229 if (IsColumnMajor)
230 return NumColumns;
231 return NumRows;
232 }
233
234 /// Returns the transposed shape.
235 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
236};
237} // namespace
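// Worked example: for a 4 x 2 matrix in the default column-major layout,
// ShapeInfo{4, 2} gives getStride() == 4 (the distance in elements between the
// starts of consecutive columns of the flattened vector), getNumVectors() == 2
// (one vector per column) and t() == ShapeInfo{2, 4}.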
238
239static bool isUniformShape(Value *V) {
240 Instruction *I = dyn_cast<Instruction>(V);
241 if (!I)
242 return true;
243
244 switch (I->getOpcode()) {
245 case Instruction::FAdd:
246 case Instruction::FSub:
247 case Instruction::FMul: // Scalar multiply.
248 case Instruction::FNeg:
249 case Instruction::Add:
250 case Instruction::Mul:
251 case Instruction::Sub:
252 return true;
253 default:
254 return false;
255 }
256}
257
258/// Return the ShapeInfo for the result of \p I, if it can be determined.
259static std::optional<ShapeInfo>
260computeShapeInfoForInst(Instruction *I,
261 const ValueMap<Value *, ShapeInfo> &ShapeMap) {
262 Value *M;
263 Value *N;
264 Value *K;
265 if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
266 m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
267 return ShapeInfo(M, K);
268 if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
269 m_Value(N)))) {
270 // Flip dimensions.
271 return ShapeInfo(N, M);
272 }
273 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
274 m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
275 m_Value(N))))
276 return ShapeInfo(N, M);
277 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
278 m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
279 return ShapeInfo(M, N);
280 Value *MatrixA;
281 if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
282 auto OpShape = ShapeMap.find(MatrixA);
283 if (OpShape != ShapeMap.end())
284 return OpShape->second;
285 }
286
287 if (isUniformShape(I)) {
288 // Find the first operand that has a known shape and use that.
289 for (auto &Op : I->operands()) {
290 auto OpShape = ShapeMap.find(Op.get());
291 if (OpShape != ShapeMap.end())
292 return OpShape->second;
293 }
294 }
295 return std::nullopt;
296}
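// For example (operand values omitted), a call
//   llvm.matrix.multiply(%A, %B, i32 4, i32 2, i32 8)
// describes a 4x2 * 2x8 product and yields ShapeInfo(4, 8), while
//   llvm.matrix.transpose(%A, i32 4, i32 2)
// yields the flipped ShapeInfo(2, 4).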
297
298/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
299///
300/// Currently, the lowering for each matrix intrinsic is done as follows:
301/// 1. Propagate the shape information from intrinsics to connected
302/// instructions.
303/// 2. Lower instructions with shape information (assuming column-major layout).
304/// The lowering works similarly using row-major layout.
305/// 2.1. Get column vectors for each argument. If we already lowered the
306/// definition of an argument, use the produced column vectors directly.
307/// If not, split the operand vector containing an embedded matrix into
308/// a set of column vectors,
309/// 2.2. Lower the instruction in terms of column major operations, which
310/// yields a set of column vectors containing result matrix. Note that we
311/// lower all instructions that have shape information. Besides the
312/// intrinsics, this includes stores for example.
313/// 2.3. Update uses of the lowered instruction. If we have shape information
314/// for a user, there is nothing to do, as we will look up the result
315/// column matrix when lowering the user. For other uses, we embed the
316/// result matrix in a flat vector and update the use.
317/// 2.4. Cache the result column matrix for the instruction we lowered
318/// 3. After we lowered all instructions in a function, remove the now
319/// obsolete instructions.
320///
321class LowerMatrixIntrinsics {
322 Function &Func;
323 const DataLayout &DL;
324 const TargetTransformInfo &TTI;
325 AliasAnalysis *AA;
326 DominatorTree *DT;
327 LoopInfo *LI;
328 OptimizationRemarkEmitter *ORE;
329
330 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
331 struct OpInfoTy {
332 /// Number of stores emitted to generate this matrix.
333 unsigned NumStores = 0;
334 /// Number of loads emitted to generate this matrix.
335 unsigned NumLoads = 0;
336 /// Number of compute operations emitted to generate this matrix.
337 unsigned NumComputeOps = 0;
338 /// Most of the time transposes can be fused with matrix multiplies or can
339 /// be folded away via algebraic simplifications. This is the number of
340 /// transposes that we failed to make "free" via such optimizations.
341 unsigned NumExposedTransposes = 0;
342
343 OpInfoTy &operator+=(const OpInfoTy &RHS) {
344 NumStores += RHS.NumStores;
345 NumLoads += RHS.NumLoads;
346 NumComputeOps += RHS.NumComputeOps;
347 NumExposedTransposes += RHS.NumExposedTransposes;
348 return *this;
349 }
350 };
351
352 /// Wrapper class representing a matrix as a set of vectors, either in row or
353 /// column major layout. All vectors must have the same vector type.
354 class MatrixTy {
355 SmallVector<Value *, 16> Vectors;
356
357 OpInfoTy OpInfo;
358
359 bool IsColumnMajor = true;
360
361 public:
362 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
363 MatrixTy(ArrayRef<Value *> Vectors)
364 : Vectors(Vectors.begin(), Vectors.end()),
365 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
366 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
367 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
368
369 unsigned D = isColumnMajor() ? NumColumns : NumRows;
370 for (unsigned J = 0; J < D; ++J)
371 Vectors.push_back(ConstantAggregateZero::get(FixedVectorType::get(
372 EltTy, isColumnMajor() ? NumRows : NumColumns)));
373 }
374
375 Value *getVector(unsigned i) const { return Vectors[i]; }
376 Value *getColumn(unsigned i) const {
377 assert(isColumnMajor() && "only supported for column-major matrixes");
378 return Vectors[i];
379 }
380 Value *getRow(unsigned i) const {
381 assert(!isColumnMajor() && "only supported for row-major matrixes");
382 return Vectors[i];
383 }
384
385 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
386
387 Type *getElementType() const { return getVectorTy()->getElementType(); }
388
389 unsigned getNumVectors() const {
390 if (isColumnMajor())
391 return getNumColumns();
392 return getNumRows();
393 }
394
395 unsigned getNumColumns() const {
396 if (isColumnMajor())
397 return Vectors.size();
398 else {
399 assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
400 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
401 }
402 }
403 unsigned getNumRows() const {
404 if (isColumnMajor()) {
405 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
406 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
407 } else
408 return Vectors.size();
409 }
410
411 void addVector(Value *V) { Vectors.push_back(V); }
412 VectorType *getColumnTy() {
413 assert(isColumnMajor() && "only supported for column-major matrixes");
414 return getVectorTy();
415 }
416
417 VectorType *getVectorTy() const {
418 return cast<VectorType>(Vectors[0]->getType());
419 }
420
421 iterator_range<SmallVector<Value *, 16>::iterator> columns() {
422 assert(isColumnMajor() &&
423 "columns() only supported for column-major matrixes");
424 return make_range(Vectors.begin(), Vectors.end());
425 }
426
427 iterator_range<SmallVector<Value *, 16>::iterator> vectors() {
428 return make_range(Vectors.begin(), Vectors.end());
429 }
430
431 /// Embed the vectors of the matrix into a flat vector by concatenating
432 /// them.
433 Value *embedInVector(IRBuilder<> &Builder) const {
434 return Vectors.size() == 1 ? Vectors[0]
435 : concatenateVectors(Builder, Vectors);
436 }
437
438 MatrixTy &addNumLoads(unsigned N) {
439 OpInfo.NumLoads += N;
440 return *this;
441 }
442
443 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
444
445 MatrixTy &addNumStores(unsigned N) {
446 OpInfo.NumStores += N;
447 return *this;
448 }
449
450 MatrixTy &addNumExposedTransposes(unsigned N) {
451 OpInfo.NumExposedTransposes += N;
452 return *this;
453 }
454
455 MatrixTy &addNumComputeOps(unsigned N) {
456 OpInfo.NumComputeOps += N;
457 return *this;
458 }
459
460 unsigned getNumStores() const { return OpInfo.NumStores; }
461 unsigned getNumLoads() const { return OpInfo.NumLoads; }
462 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
463
464 const OpInfoTy &getOpInfo() const { return OpInfo; }
465
466 bool isColumnMajor() const { return IsColumnMajor; }
467
468 unsigned getStride() const {
469 if (isColumnMajor())
470 return getNumRows();
471 return getNumColumns();
472 }
473
474 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
475 /// matrix is column-major, the result vector is extracted from a column
476 /// vector, otherwise from a row vector.
477 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
478 IRBuilder<> &Builder) const {
479 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
480 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
481 NumElts &&
482 "Extracted vector will contain poison values");
483 return Builder.CreateShuffleVector(
484 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
485 "block");
486 }
487 };
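  // Usage sketch (column-major, numbers chosen for illustration): a 4 x 2
  // matrix of floats is held as two <4 x float> column vectors, so
  // getNumVectors() == 2 and getStride() == 4; extractVector(2, 1, 2, Builder)
  // shuffles elements 2..3 out of column 1.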
488
489 /// Maps instructions to their shape information. The shape information
490 /// describes the shape to be used while lowering. This matches the shape of
491 /// the result value of the instruction, with the only exceptions being store
492 /// instructions and the matrix_column_major_store intrinsics. For those, the
493 /// shape information indicates that those instructions should be lowered
494 /// using shape information as well. A ValueMap is used so that when
495 /// sub-passes like optimizeTransposes perform RAUW, the map stays
496 /// up-to-date.
497 ValueMap<Value *, ShapeInfo> ShapeMap;
498
499 /// List of instructions to remove. While lowering, we do not replace all
500 /// uses of a lowered instruction, if shape information is available, and
501 /// those instructions need to be removed after lowering has finished.
502 SmallVector<Instruction *, 16> ToRemove;
503
504 /// Map from instructions to their produced column matrix.
505 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
506
507private:
508 static FastMathFlags getFastMathFlags(Instruction *Inst) {
509 FastMathFlags FMF;
510
511 if (isa<FPMathOperator>(*Inst))
512 FMF = Inst->getFastMathFlags();
513
514 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
515
516 return FMF;
517 }
518
519public:
520 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
521 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
522 OptimizationRemarkEmitter *ORE)
523 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
524 LI(LI), ORE(ORE) {}
525
526 unsigned getNumOps(Type *VT) {
527 assert(isa<VectorType>(VT) && "Expected vector type");
528 return getNumOps(VT->getScalarType(),
529 cast<FixedVectorType>(VT)->getNumElements());
530 }
531
532 /// Is this the minimal version executed in the backend pipelines.
533 bool isMinimal() const {
534 return !DT;
535 }
536
537 /// Return the estimated number of vector ops required for an operation on
538 /// \p VT * N.
539 unsigned getNumOps(Type *ST, unsigned N) {
540 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
541 double(TTI.getRegisterBitWidth(
542 TargetTransformInfo::RGK_FixedWidthVector)
543 .getFixedValue()));
544 }
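  // Worked example, assuming a target that reports 128-bit fixed-width vector
  // registers: getNumOps(float, 4) == 1 (one <4 x float> fits in a register),
  // while getNumOps(double, 8) == 4.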
545
546 /// Return the set of vectors that a matrix value is lowered to.
547 ///
548 /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
549 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
550 /// into vectors.
551 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
552 IRBuilder<> &Builder) {
553 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
554 assert(VType && "MatrixVal must be a vector type");
555 assert(cast<FixedVectorType>(VType)->getNumElements() ==
556 SI.NumRows * SI.NumColumns &&
557 "The vector size must match the number of matrix elements");
558
559 // Check if we lowered MatrixVal using shape information. In that case,
560 // return the existing matrix, if it matches the requested shape
561 // information. If there is a mismatch, embed the result in a flat
562 // vector and split it later.
563 auto Found = Inst2ColumnMatrix.find(MatrixVal);
564 if (Found != Inst2ColumnMatrix.end()) {
565 MatrixTy &M = Found->second;
566 // Return the found matrix, if its shape matches the requested shape
567 // information
568 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
569 return M;
570
571 MatrixVal = M.embedInVector(Builder);
572 }
573
574 // Otherwise split MatrixVal.
575 SmallVector<Value *, 16> SplitVecs;
576 for (unsigned MaskStart = 0;
577 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
578 MaskStart += SI.getStride()) {
579 Value *V = Builder.CreateShuffleVector(
580 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
581 "split");
582 SplitVecs.push_back(V);
583 }
584
585 return {SplitVecs};
586 }
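  // For example, a flat <8 x double> value requested with a 4 x 2 column-major
  // shape is split into two <4 x double> values using shufflevectors with the
  // masks <0, 1, 2, 3> and <4, 5, 6, 7>.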
587
588 /// If \p V already has a known shape return false. Otherwise set the shape
589 /// for instructions that support it.
590 bool setShapeInfo(Value *V, ShapeInfo Shape) {
591 assert(Shape && "Shape not set");
592 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
593 return false;
594
595 auto SIter = ShapeMap.find(V);
596 if (SIter != ShapeMap.end()) {
597 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
598 SIter->second.NumColumns != Shape.NumColumns)) {
599 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
600 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
601 << Shape.NumColumns << ") for " << *V << "\n";
602 report_fatal_error(
603 "Matrix shape verification failed, compilation aborted!");
604 }
605
606 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
607 << SIter->second.NumRows << " "
608 << SIter->second.NumColumns << " for " << *V << "\n");
609 return false;
610 }
611
612 ShapeMap.insert({V, Shape});
613 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
614 << " for " << *V << "\n");
615 return true;
616 }
617
618 /// Returns true if shape information can be used for \p V. The supported
619 /// instructions must match the instructions that can be lowered by this pass.
620 bool supportsShapeInfo(Value *V) {
621 Instruction *Inst = dyn_cast<Instruction>(V);
622 if (!Inst)
623 return false;
624
625 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
626 if (II)
627 switch (II->getIntrinsicID()) {
628 case Intrinsic::matrix_multiply:
629 case Intrinsic::matrix_transpose:
630 case Intrinsic::matrix_column_major_load:
631 case Intrinsic::matrix_column_major_store:
632 return true;
633 default:
634 return false;
635 }
636 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
637 }
638
639 /// Propagate the shape information of instructions to their users.
640 /// The work list contains instructions for which we can compute the shape,
641 /// either based on the information provided by matrix intrinsics or known
642 /// shapes of operands.
643 SmallVector<Instruction *, 32>
644 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
645 SmallVector<Instruction *, 32> NewWorkList;
646 // Pop an element for which we are guaranteed to have at least one of the
647 // operand shapes. Add the shape for this and then add users to the work
648 // list.
649 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
650 while (!WorkList.empty()) {
651 Instruction *Inst = WorkList.pop_back_val();
652
653 // New entry, set the value and insert operands
654 bool Propagate = false;
655 if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
656 Propagate = setShapeInfo(Inst, *SI);
657
658 if (Propagate) {
659 NewWorkList.push_back(Inst);
660 for (auto *User : Inst->users())
661 if (ShapeMap.count(User) == 0)
662 WorkList.push_back(cast<Instruction>(User));
663 }
664 }
665
666 return NewWorkList;
667 }
668
669 /// Propagate the shape to operands of instructions with shape information.
670 /// \p WorkList contains the instructions for which we already know the shape.
671 SmallVector<Instruction *, 32>
672 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
673 SmallVector<Instruction *, 32> NewWorkList;
674
675 auto pushInstruction = [](Value *V,
676 SmallVectorImpl<Instruction *> &WorkList) {
677 Instruction *I = dyn_cast<Instruction>(V);
678 if (I)
679 WorkList.push_back(I);
680 };
681 // Pop an element with known shape. Traverse the operands, if their shape
682 // derives from the result shape and is unknown, add it and add them to the
683 // worklist.
684 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
685 while (!WorkList.empty()) {
686 Value *V = WorkList.pop_back_val();
687
688 size_t BeforeProcessingV = WorkList.size();
689 if (!isa<Instruction>(V))
690 continue;
691
692 Value *MatrixA;
693 Value *MatrixB;
694 Value *M;
695 Value *N;
696 Value *K;
697 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
698 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
699 m_Value(N), m_Value(K)))) {
700 if (setShapeInfo(MatrixA, {M, N}))
701 pushInstruction(MatrixA, WorkList);
702
703 if (setShapeInfo(MatrixB, {N, K}))
704 pushInstruction(MatrixB, WorkList);
705
706 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
707 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
708 // Flip dimensions.
709 if (setShapeInfo(MatrixA, {M, N}))
710 pushInstruction(MatrixA, WorkList);
711 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
712 m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
713 m_Value(M), m_Value(N)))) {
714 if (setShapeInfo(MatrixA, {M, N})) {
715 pushInstruction(MatrixA, WorkList);
716 }
717 } else if (isa<LoadInst>(V) ||
718 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
719 // Nothing to do, no matrix input.
720 } else if (isa<StoreInst>(V)) {
721 // Nothing to do. We forward-propagated to this so we would just
722 // backward propagate to an instruction with an already known shape.
723 } else if (isUniformShape(V)) {
724 // Propagate to all operands.
725 ShapeInfo Shape = ShapeMap[V];
726 for (Use &U : cast<Instruction>(V)->operands()) {
727 if (setShapeInfo(U.get(), Shape))
728 pushInstruction(U.get(), WorkList);
729 }
730 }
731 // After we discovered new shape info for new instructions in the
732 // worklist, we use their users as seeds for the next round of forward
733 // propagation.
734 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
735 for (User *U : WorkList[I]->users())
736 if (isa<Instruction>(U) && V != U)
737 NewWorkList.push_back(cast<Instruction>(U));
738 }
739 return NewWorkList;
740 }
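  // For instance, once %c = llvm.matrix.multiply(%a, %b, i32 4, i32 2, i32 8)
  // has a known shape, backward propagation records 4 x 2 for %a and 2 x 8 for
  // %b and continues from their defining instructions.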
741
742 /// (Op0 op Op1)^T -> Op0^T op Op1^T
743 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
744 /// them on both sides of \p Operation.
745 Instruction *distributeTransposes(
746 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
747 MatrixBuilder &Builder,
748 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
749 Operation) {
750 Value *T0 = Builder.CreateMatrixTranspose(
751 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
752 // We are being run after shape prop, add shape for newly created
753 // instructions so that we lower them later.
754 setShapeInfo(T0, Shape0.t());
755 Value *T1 = Builder.CreateMatrixTranspose(
756 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
757 setShapeInfo(T1, Shape1.t());
758 return Operation(T0, Shape0.t(), T1, Shape1.t());
759 }
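  // Example: for (A * B)^T with A of shape RxK and B of shape KxC, the caller
  // passes Op0 = B with {K, C} and Op1 = A with {R, K}; the callback then
  // builds multiply(B^T, A^T), a CxK * KxR product whose CxR result equals
  // (A * B)^T.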
760
761 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
762 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
763 // with New. We should only add New if it supportsShapeInfo so we insert
764 // it conditionally instead.
765 auto S = ShapeMap.find(&Old);
766 if (S != ShapeMap.end()) {
767 ShapeMap.erase(S);
768 if (supportsShapeInfo(New))
769 ShapeMap.insert({New, S->second});
770 }
771 Old.replaceAllUsesWith(New);
772 }
773
774 /// Sink a top-level transpose inside matmuls and adds.
775 /// This creates and erases instructions as needed, and returns the newly
776 /// created instruction while updating the iterator to avoid invalidation. If
777 /// this returns nullptr, no new instruction was created.
778 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
779 BasicBlock &BB = *I.getParent();
780 IRBuilder<> IB(&I);
781 MatrixBuilder Builder(IB);
782
783 Value *TA, *TAMA, *TAMB;
784 ConstantInt *R, *K, *C;
785 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
786 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
787 return nullptr;
788
789 // Transpose of a transpose is a nop
790 Value *TATA;
791 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
792 updateShapeAndReplaceAllUsesWith(I, TATA);
793 eraseFromParentAndMove(&I, II, BB);
794 eraseFromParentAndMove(TA, II, BB);
795 return nullptr;
796 }
797
798 // k^T -> k
799 if (isSplat(TA)) {
800 updateShapeAndReplaceAllUsesWith(I, TA);
801 eraseFromParentAndMove(&I, II, BB);
802 return nullptr;
803 }
804
805 // (A * B)^t -> B^t * A^t
806 // RxK KxC CxK KxR
807 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
808 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
809 m_ConstantInt(K), m_ConstantInt(C)))) {
810 auto NewInst = distributeTransposes(
811 TAMB, {K, C}, TAMA, {R, K}, Builder,
812 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
813 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
814 Shape0.NumColumns,
815 Shape1.NumColumns, "mmul");
816 });
817 updateShapeAndReplaceAllUsesWith(I, NewInst);
818 eraseFromParentAndMove(&I, II, BB);
819 eraseFromParentAndMove(TA, II, BB);
820 return NewInst;
821 }
822
823 // Same as above, but with a mul, which occurs when multiplied
824 // with a scalar.
825 // (A * k)^t -> A^t * k
826 // R x C RxC
827 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
828 (isSplat(TAMA) || isSplat(TAMB))) {
829 IRBuilder<> LocalBuilder(&I);
830 // We know that the transposed operand is of shape RxC.
831 // And when multiplied with a scalar, the shape is preserved.
832 auto NewInst = distributeTransposes(
833 TAMA, {R, C}, TAMB, {R, C}, Builder,
834 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
835 bool IsFP = I.getType()->isFPOrFPVectorTy();
836 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
837 : LocalBuilder.CreateMul(T0, T1, "mmul");
838 auto *Result = cast<Instruction>(Mul);
839 setShapeInfo(Result, Shape0);
840 return Result;
841 });
842 updateShapeAndReplaceAllUsesWith(I, NewInst);
843 eraseFromParentAndMove(&I, II, BB);
844 eraseFromParentAndMove(TA, II, BB);
845 return NewInst;
846 }
847
848 // (A + B)^t -> A^t + B^t
849 // RxC RxC CxR CxR
850 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
851 IRBuilder<> LocalBuilder(&I);
852 auto NewInst = distributeTransposes(
853 TAMA, {R, C}, TAMB, {R, C}, Builder,
854 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
855 bool IsFP = I.getType()->isFPOrFPVectorTy();
856 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
857 : LocalBuilder.CreateAdd(T0, T1, "madd");
858
859 auto *Result = cast<Instruction>(Add);
860 setShapeInfo(Result, Shape0);
861 return Result;
862 });
863 updateShapeAndReplaceAllUsesWith(I, NewInst);
864 eraseFromParentAndMove(&I, II, BB);
865 eraseFromParentAndMove(TA, II, BB);
866 return NewInst;
867 }
868
869 return nullptr;
870 }
871
872 void liftTranspose(Instruction &I) {
873 // Erase dead Instructions after lifting transposes from binops.
874 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
875 if (T.use_empty())
876 T.eraseFromParent();
877 if (A->use_empty())
878 cast<Instruction>(A)->eraseFromParent();
879 if (A != B && B->use_empty())
880 cast<Instruction>(B)->eraseFromParent();
881 };
882
883 Value *A, *B, *AT, *BT;
884 ConstantInt *R, *K, *C;
885 // A^t * B ^t -> (B * A)^t
886 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
887 m_Value(A), m_Value(B), m_ConstantInt(R),
888 m_ConstantInt(K), m_ConstantInt(C))) &&
889 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
890 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
891 IRBuilder<> IB(&I);
892 MatrixBuilder Builder(IB);
893 Value *M = Builder.CreateMatrixMultiply(
894 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
895 setShapeInfo(M, {C, R});
896 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
897 R->getZExtValue());
898 updateShapeAndReplaceAllUsesWith(I, NewInst);
899 CleanupBinOp(I, A, B);
900 }
901 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
902 // the shape of the second transpose is different, there's a shape conflict
903 // which gets resolved by picking the shape of the first operand.
904 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
905 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
906 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
907 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
908 m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
909 IRBuilder<> Builder(&I);
910 auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
911 setShapeInfo(Add, {R, C});
912 MatrixBuilder MBuilder(Builder);
913 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
914 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
915 updateShapeAndReplaceAllUsesWith(I, NewInst);
916 assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
917 computeShapeInfoForInst(&I, ShapeMap) &&
918 "Shape of new instruction doesn't match original shape.");
919 CleanupBinOp(I, A, B);
920 assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
921 ShapeMap[Add] &&
922 "Shape of updated addition doesn't match cached shape.");
923 }
924 }
925
926 /// Try moving transposes in order to fold them away or into multiplies.
927 void optimizeTransposes() {
928 // First sink all transposes inside matmuls and adds, hoping that we end up
929 // with NN, NT or TN variants.
930 for (BasicBlock &BB : reverse(Func)) {
931 for (auto II = BB.rbegin(); II != BB.rend();) {
932 Instruction &I = *II;
933 // We may remove II. By default continue on the next/prev instruction.
934 ++II;
935 if (Instruction *NewInst = sinkTranspose(I, II))
936 II = std::next(BasicBlock::reverse_iterator(NewInst));
937 }
938 }
939
940 // If we have a TT matmul or a TT add, lift the transpose. We may be able
941 // to fold into consuming multiply or add.
942 for (BasicBlock &BB : Func) {
943 for (Instruction &I : llvm::make_early_inc_range(BB)) {
944 liftTranspose(I);
945 }
946 }
947 }
948
949 bool Visit() {
950 SmallVector<Instruction *, 32> WorkList;
951
952 // Initially only the shape of matrix intrinsics is known.
953 // Initialize the work list with ops carrying shape information.
954 for (BasicBlock &BB : Func)
955 for (Instruction &Inst : BB) {
956 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
957 if (!II)
958 continue;
959
960 switch (II->getIntrinsicID()) {
961 case Intrinsic::matrix_multiply:
962 case Intrinsic::matrix_transpose:
963 case Intrinsic::matrix_column_major_load:
964 case Intrinsic::matrix_column_major_store:
965 WorkList.push_back(&Inst);
966 break;
967 default:
968 break;
969 }
970 }
971
972 // Avoid unnecessary work if there are no matrix intrinsics in the function.
973 if (WorkList.empty())
974 return false;
975
976 // Propagate shapes until nothing changes any longer.
977 while (!WorkList.empty()) {
978 WorkList = propagateShapeForward(WorkList);
979 WorkList = propagateShapeBackward(WorkList);
980 }
981
982 if (!isMinimal()) {
983 optimizeTransposes();
984 if (PrintAfterTransposeOpt) {
985 dbgs() << "Dump after matrix transpose optimization:\n";
986 Func.print(dbgs());
987 }
988 }
989
990 bool Changed = false;
991 SmallVector<CallInst *, 16> MaybeFusableInsts;
992 SmallVector<Instruction *, 16> MatrixInsts;
993
994 // First, collect all instructions with shape information and candidates for
995 // fusion (currently only matrix multiplies).
996 ReversePostOrderTraversal<Function *> RPOT(&Func);
997 for (auto *BB : RPOT)
998 for (Instruction &I : *BB) {
999 if (ShapeMap.find(&I) == ShapeMap.end())
1000 continue;
1001 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
1002 MaybeFusableInsts.push_back(cast<CallInst>(&I));
1003 MatrixInsts.push_back(&I);
1004 }
1005
1006 // Second, try to lower any dot products
1007 SmallPtrSet<Instruction *, 16> FusedInsts;
1008 for (CallInst *CI : MaybeFusableInsts)
1009 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
1010
1011 // Third, try to fuse candidates.
1012 for (CallInst *CI : MaybeFusableInsts)
1013 LowerMatrixMultiplyFused(CI, FusedInsts);
1014
1015 Changed = !FusedInsts.empty();
1016
1017 // Fourth, lower remaining instructions with shape information.
1018 for (Instruction *Inst : MatrixInsts) {
1019 if (FusedInsts.count(Inst))
1020 continue;
1021
1022 IRBuilder<> Builder(Inst);
1023
1024 if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1025 Changed |= VisitCallInst(CInst);
1026
1027 Value *Op1;
1028 Value *Op2;
1029 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1030 Changed |= VisitBinaryOperator(BinOp);
1031 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1032 Changed |= VisitUnaryOperator(UnOp);
1033 if (match(Inst, m_Load(m_Value(Op1))))
1034 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1035 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1036 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1037 }
1038
1039 if (ORE) {
1040 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1041 RemarkGen.emitRemarks();
1042 }
1043
1044 // Delete the instructions backwards, as this reduces the number of def-use
1045 // and use-def chains that need to be updated.
1046 //
1047 // Because we add to ToRemove during fusion we can't guarantee that defs
1048 // are before uses. Change uses to poison temporarily as these should get
1049 // removed as well.
1050 //
1051 // For verification, we keep track of where we changed uses to poison in
1052 // PoisonedInsts and then check that we in fact remove them.
1053 SmallSet<Instruction *, 16> PoisonedInsts;
1054 for (auto *Inst : reverse(ToRemove)) {
1055 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1056 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1057 PoisonedInsts.insert(Poisoned);
1058 U.set(PoisonValue::get(Inst->getType()));
1059 }
1060 Inst->eraseFromParent();
1061 PoisonedInsts.erase(Inst);
1062 }
1063 if (!PoisonedInsts.empty()) {
1064 // If we didn't remove all poisoned instructions, it's a hard error.
1065 dbgs() << "Poisoned but present instructions:\n";
1066 for (auto *I : PoisonedInsts)
1067 dbgs() << *I << "\n";
1068 llvm_unreachable("Poisoned but instruction not removed");
1069 }
1070
1071 return Changed;
1072 }
1073
1074 /// Replace intrinsic calls
1075 bool VisitCallInst(CallInst *Inst) {
1076 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1077 return false;
1078
1079 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1080 case Intrinsic::matrix_multiply:
1081 LowerMultiply(Inst);
1082 break;
1083 case Intrinsic::matrix_transpose:
1084 LowerTranspose(Inst);
1085 break;
1086 case Intrinsic::matrix_column_major_load:
1087 LowerColumnMajorLoad(Inst);
1088 break;
1089 case Intrinsic::matrix_column_major_store:
1090 LowerColumnMajorStore(Inst);
1091 break;
1092 default:
1093 return false;
1094 }
1095 return true;
1096 }
1097
1098 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1099 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1100 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1101 /// non-ConstantInt strides, return the common alignment of the initial
1102 /// alignment and the element size in bytes.
1103 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1104 MaybeAlign A) const {
1105 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1106 if (Idx == 0)
1107 return InitialAlign;
1108
1109 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1110 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1111 uint64_t StrideInBytes =
1112 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1113 return commonAlignment(InitialAlign, Idx * StrideInBytes);
1114 }
1115 return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1116 }
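  // Worked example: with an initial alignment of 16, double elements and a
  // constant stride of 3 elements (24 bytes), the vector at Idx == 1 starts at
  // byte offset 24, so the result is commonAlignment(16, 24) == 8.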
1117
1118 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1119 /// vectors.
1120 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1121 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1122 auto *VType = cast<VectorType>(Ty);
1123 Type *EltTy = VType->getElementType();
1124 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1125 Value *EltPtr = Ptr;
1126 MatrixTy Result;
1127 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1128 Value *GEP = computeVectorAddr(
1129 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1130 Stride, Shape.getStride(), EltTy, Builder);
1131 Value *Vector = Builder.CreateAlignedLoad(
1132 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1133 IsVolatile, "col.load");
1134
1135 Result.addVector(Vector);
1136 }
1137 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1138 Result.getNumVectors());
1139 }
1140
1141 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1142 /// starting at \p MatrixPtr[I][J].
1143 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1144 ShapeInfo MatrixShape, Value *I, Value *J,
1145 ShapeInfo ResultShape, Type *EltTy,
1146 IRBuilder<> &Builder) {
1147
1148 Value *Offset = Builder.CreateAdd(
1149 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1150
1151 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1152 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1153 ResultShape.NumColumns);
1154
1155 return loadMatrix(TileTy, TileStart, Align,
1156 Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1157 ResultShape, Builder);
1158 }
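  // For example, loading a 2 x 2 tile at (I = 1, J = 2) from an 8 x 8
  // column-major matrix computes Offset = 2 * 8 + 1 = 17 and then loads two
  // 2-element vectors that start 8 elements (one full column) apart.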
1159
1160 /// Lower a load instruction with shape information.
1161 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1162 bool IsVolatile, ShapeInfo Shape) {
1163 IRBuilder<> Builder(Inst);
1164 finalizeLowering(Inst,
1165 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1166 Shape, Builder),
1167 Builder);
1168 }
1169
1170 /// Lowers llvm.matrix.column.major.load.
1171 ///
1172 /// The intrinsic loads a matrix from memory using a stride between columns.
1173 void LowerColumnMajorLoad(CallInst *Inst) {
1174 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1175 "Intrinsic only supports column-major layout!");
1176 Value *Ptr = Inst->getArgOperand(0);
1177 Value *Stride = Inst->getArgOperand(1);
1178 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1179 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1180 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1181 }
1182
1183 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1184 /// MatrixPtr[I][J].
1185 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1186 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1187 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1188 Value *Offset = Builder.CreateAdd(
1189 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1190
1191 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1192 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1193 StoreVal.getNumColumns());
1194
1195 storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1196 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1197 }
1198
1199 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1200 /// vectors.
1201 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1202 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1203 IRBuilder<> &Builder) {
1204 auto VType = cast<VectorType>(Ty);
1205 Value *EltPtr = Ptr;
1206 for (auto Vec : enumerate(StoreVal.vectors())) {
1207 Value *GEP = computeVectorAddr(
1208 EltPtr,
1209 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1210 Vec.index()),
1211 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1212 Builder.CreateAlignedStore(Vec.value(), GEP,
1213 getAlignForIndex(Vec.index(), Stride,
1214 VType->getElementType(),
1215 MAlign),
1216 IsVolatile);
1217 }
1218 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1219 StoreVal.getNumVectors());
1220 }
1221
1222 /// Lower a store instruction with shape information.
1223 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1224 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1225 IRBuilder<> Builder(Inst);
1226 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1227 finalizeLowering(Inst,
1228 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1229 IsVolatile, Builder),
1230 Builder);
1231 }
1232
1233 /// Lowers llvm.matrix.column.major.store.
1234 ///
1235 /// The intrinsic stores a matrix back to memory using a stride between columns.
1236 void LowerColumnMajorStore(CallInst *Inst) {
1237 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1238 "Intrinsic only supports column-major layout!");
1239 Value *Matrix = Inst->getArgOperand(0);
1240 Value *Ptr = Inst->getArgOperand(1);
1241 Value *Stride = Inst->getArgOperand(2);
1242 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1243 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1244 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1245 }
1246
1247 // Set elements I..I+NumElts-1 to Block
1248 Value *insertVector(Value *Col, unsigned I, Value *Block,
1249 IRBuilder<> &Builder) {
1250
1251 // First, bring Block to the same size as Col
1252 unsigned BlockNumElts =
1253 cast<FixedVectorType>(Block->getType())->getNumElements();
1254 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1255 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1256
1257 Block = Builder.CreateShuffleVector(
1258 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1259
1260 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1261 // 8, 4, 5, 6
1262 SmallVector<int, 16> Mask;
1263 unsigned i;
1264 for (i = 0; i < I; i++)
1265 Mask.push_back(i);
1266
1267 unsigned VecNumElts =
1268 cast<FixedVectorType>(Col->getType())->getNumElements();
1269 for (; i < I + BlockNumElts; i++)
1270 Mask.push_back(i - I + VecNumElts);
1271
1272 for (; i < VecNumElts; i++)
1273 Mask.push_back(i);
1274
1275 return Builder.CreateShuffleVector(Col, Block, Mask);
1276 }
1277
1278 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1279 IRBuilder<> &Builder, bool AllowContraction,
1280 unsigned &NumComputeOps) {
1281 NumComputeOps += getNumOps(A->getType());
1282 if (!Sum)
1283 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1284
1285 if (UseFPOp) {
1286 if (AllowContraction) {
1287 // Use fmuladd for floating point operations and let the backend decide
1288 // if that's profitable.
1289 Function *FMulAdd = Intrinsic::getDeclaration(
1290 Func.getParent(), Intrinsic::fmuladd, A->getType());
1291 return Builder.CreateCall(FMulAdd, {A, B, Sum});
1292 }
1293 NumComputeOps += getNumOps(A->getType());
1294 Value *Mul = Builder.CreateFMul(A, B);
1295 return Builder.CreateFAdd(Sum, Mul);
1296 }
1297
1298 NumComputeOps += getNumOps(A->getType());
1299 Value *Mul = Builder.CreateMul(A, B);
1300 return Builder.CreateAdd(Sum, Mul);
1301 }
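  // Sketch of the contracted form emitted when contraction is allowed (value
  // names invented for illustration):
  //   %sum.next = call <4 x double> @llvm.fmuladd.v4f64(<4 x double> %a,
  //                                                     <4 x double> %b,
  //                                                     <4 x double> %sum)
  // Without contraction, a separate fmul followed by fadd (or mul/add for
  // integer elements) is emitted instead.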
1302
1303 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1304 /// users with shape information, there's nothing to do: they will use the
1305 /// cached value when they are lowered. For other users, \p Matrix is
1306 /// flattened and the uses are updated to use it. Also marks \p Inst for
1307 /// deletion.
1308 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1309 IRBuilder<> &Builder) {
1310 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1311 (void)inserted;
1312 assert(inserted.second && "multiple matrix lowering mapping");
1313
1314 ToRemove.push_back(Inst);
1315 Value *Flattened = nullptr;
1316 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1317 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1318 if (!Flattened)
1319 Flattened = Matrix.embedInVector(Builder);
1320 U.set(Flattened);
1321 }
1322 }
1323 }
1324
1325 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1326 /// vectors. Lowers to a vector reduction add instead of sequential adds if
1327 /// reassociation is enabled.
1328 void lowerDotProduct(CallInst *MatMul,
1329 SmallPtrSet<Instruction *, 16> &FusedInsts,
1330 FastMathFlags FMF) {
1331 if (FusedInsts.contains(MatMul) ||
1332 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1333 return;
1334 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1335 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1336
1337 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1338 return;
1339
1340 Value *LHS = MatMul->getArgOperand(0);
1341 Value *RHS = MatMul->getArgOperand(1);
1342
1343 Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1344 bool IsIntVec = ElementType->isIntegerTy();
1345
1346 // Floating point reductions require reassociation.
1347 if (!IsIntVec && !FMF.allowReassoc())
1348 return;
1349
1350 auto CanBeFlattened = [](Value *Op) {
1351 if (match(Op, m_BinOp()))
1352 return true;
1353 return match(
1354 Op, m_OneUse(m_CombineOr(
1355 m_Load(m_Value()),
1356 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1357 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1358 m_Value(), m_SpecificInt(1))))));
1359 };
1360 // Returns the cost benefit of using \p Op with the dot product lowering. If
1361 // the returned cost is < 0, the argument is cheaper to use in the
1362 // dot-product lowering.
1363 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1364 if (ShapeMap.find(Op) == ShapeMap.end())
1365 return InstructionCost::getInvalid();
1366
1367 if (!isa<Instruction>(Op))
1368 return InstructionCost(0);
1369
1370 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1371 Type *EltTy = VecTy->getElementType();
1372
1373 if (!CanBeFlattened(Op)) {
1374 InstructionCost EmbedCost(0);
1375 // Roughly estimate the cost for embedding the columns into a vector.
1376 for (unsigned I = 1; I < N; ++I)
1377 EmbedCost +=
1378 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1379 std::nullopt, TTI::TCK_RecipThroughput);
1380 return EmbedCost;
1381 }
1382
1383 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1384 InstructionCost OriginalCost =
1385 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1386 EltTy) *
1387 N;
1388 InstructionCost NewCost = TTI.getArithmeticInstrCost(
1389 cast<Instruction>(Op)->getOpcode(), VecTy);
1390 return NewCost - OriginalCost;
1391 }
1392
1393 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1394 // The transpose can be skipped for the dot product lowering, roughly
1395 // estimate the savings as the cost of embedding the columns in a
1396 // vector.
1397 InstructionCost EmbedCost(0);
1398 for (unsigned I = 1; I < N; ++I)
1399 EmbedCost -=
1400 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1401 std::nullopt, TTI::TCK_RecipThroughput);
1402 return EmbedCost;
1403 }
1404
1405 // Costs for loads.
1406 if (N == 1)
1407 return InstructionCost(0);
1408
1409 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1410 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1411 };
1412
1413 // Iterate over LHS and operations feeding LHS and check if it is profitable
1414 // to flatten the visited ops. For each op, we compute the difference
1415 // between the flattened and matrix versions.
1416 SmallPtrSet<Value *, 4> Seen;
1417 SmallVector<Value *> WorkList;
1418 SmallVector<Value *> ToFlatten;
1419 WorkList.push_back(LHS);
1420 InstructionCost LHSCost(0);
1421 while (!WorkList.empty()) {
1422 Value *Op = WorkList.pop_back_val();
1423 if (!Seen.insert(Op).second)
1424 continue;
1425
1426 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1427 if (OpCost + LHSCost >= LHSCost)
1428 continue;
1429
1430 LHSCost += OpCost;
1431 ToFlatten.push_back(Op);
1432 if (auto *I = dyn_cast<Instruction>(Op))
1433 WorkList.append(I->op_begin(), I->op_end());
1434 }
1435
1436 // We compare the costs of a vector.reduce.add to sequential add.
1437 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1438 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1439 InstructionCost ReductionCost =
1440 TTI.getArithmeticReductionCost(
1441 AddOpCode, cast<VectorType>(LHS->getType()),
1442 IsIntVec ? std::nullopt : std::optional(FMF)) +
1443 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1444 InstructionCost SequentialAddCost =
1445 TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1446 (LShape.NumColumns - 1) +
1447 TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1448 (LShape.NumColumns);
1449 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1450 return;
1451
1452 FusedInsts.insert(MatMul);
1453 IRBuilder<> Builder(MatMul);
1454 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1455 this](Value *Op) {
1456 // Matmul must be the only user of loads because we don't use LowerLoad
1457 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1458 // instead of single vector load).
1459 if (!CanBeFlattened(Op))
1460 return;
1461
1462 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1463 ShapeMap[Op] = ShapeMap[Op].t();
1464 return;
1465 }
1466
1467 FusedInsts.insert(cast<Instruction>(Op));
1468 // If vector uses the builtin load, lower to a LoadInst
1469 Value *Arg;
1470 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1471 m_Value(Arg)))) {
1472 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1473 Op->replaceAllUsesWith(NewLoad);
1474 cast<Instruction>(Op)->eraseFromParent();
1475 return;
1476 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1477 m_Value(Arg)))) {
1478 ToRemove.push_back(cast<Instruction>(Op));
1479 Op->replaceAllUsesWith(Arg);
1480 return;
1481 }
1482 };
1483
1484 for (auto *V : ToFlatten)
1485 FlattenArg(V);
1486
1487 LHS = MatMul->getArgOperand(0);
1488
1489 // Insert mul/fmul and llvm.vector.reduce.fadd
1490 Value *Mul =
1491 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1492
1493 Value *Result;
1494 if (IsIntVec)
1495 Result = Builder.CreateAddReduce(Mul);
1496 else {
1497 Result = Builder.CreateFAddReduce(
1498 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
1499 0.0),
1500 Mul);
1501 cast<Instruction>(Result)->setFastMathFlags(FMF);
1502 }
1503
1504 // pack scalar back into a matrix and then replace matmul inst
1505 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1506 Result, uint64_t(0));
1507 MatMul->replaceAllUsesWith(Result);
1508 FusedInsts.insert(MatMul);
1509 ToRemove.push_back(MatMul);
1510 }
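  // Rough sketch of the lowered form for a 1x4 * 4x1 float dot product (value
  // names invented for illustration):
  //   %m = fmul <4 x float> %lhs, %rhs
  //   %s = call float @llvm.vector.reduce.fadd.v4f32(float 0.0, <4 x float> %m)
  //   %r = insertelement <1 x float> poison, float %s, i64 0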
1511
1512 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1513 /// addition.
1514 ///
1515 /// We can fold a transpose into the operand that is used to extract scalars:
1516 /// this is the first operand with row-major layout and the second with
1517 /// column-major layout. If \p IsScalarMatrixTransposed we assume the
1518 /// appropriate operand is transposed.
1519 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1520 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1521 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1522 const unsigned VF = std::max<unsigned>(
1523 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1524 .getFixedValue() /
1525 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1526 1U);
1527 unsigned R = Result.getNumRows();
1528 unsigned C = Result.getNumColumns();
1529 unsigned M = A.getNumColumns();
1530
1531 bool IsFP = Result.getElementType()->isFloatingPointTy();
1532 assert(A.isColumnMajor() == B.isColumnMajor() &&
1533 Result.isColumnMajor() == A.isColumnMajor() &&
1534 "operands must agree on matrix layout");
1535 unsigned NumComputeOps = 0;
1536
1537 Builder.setFastMathFlags(FMF);
1538
1539 if (A.isColumnMajor()) {
1540 // Multiply columns from the first operand with scalars from the second
1541 // operand. Then move along the K axes and accumulate the columns. With
1542 // this the adds can be vectorized without reassociation.
1543 for (unsigned J = 0; J < C; ++J) {
1544 unsigned BlockSize = VF;
1545 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1546 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1547
1548 for (unsigned I = 0; I < R; I += BlockSize) {
1549 // Gradually lower the vectorization factor to cover the remainder.
1550 while (I + BlockSize > R)
1551 BlockSize /= 2;
1552
1553 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1554 : nullptr;
1555 for (unsigned K = 0; K < M; ++K) {
1556 Value *L = A.extractVector(I, K, BlockSize, Builder);
1557 Value *RH = Builder.CreateExtractElement(
1558 B.getColumn(IsScalarMatrixTransposed ? K : J),
1559 IsScalarMatrixTransposed ? J : K);
1560 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1561 Sum =
1562 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1563 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1564 }
1565 Result.setVector(J,
1566 insertVector(Result.getVector(J), I, Sum, Builder));
1567 }
1568 }
1569 } else {
1570 // Multiply rows from the second operand with scalars from the first
1571 // operand. Then move along the K axes and accumulate the rows. With this
1572 // the adds can be vectorized without reassociation.
1573 for (unsigned I = 0; I < R; ++I) {
1574 unsigned BlockSize = VF;
1575 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1576 for (unsigned J = 0; J < C; J += BlockSize) {
1577 // Gradually lower the vectorization factor to cover the remainder.
1578 while (J + BlockSize > C)
1579 BlockSize /= 2;
1580
1581 Value *Sum = nullptr;
1582 for (unsigned K = 0; K < M; ++K) {
1583 Value *R = B.extractVector(K, J, BlockSize, Builder);
1584 Value *LH = Builder.CreateExtractElement(
1585 A.getVector(IsScalarMatrixTransposed ? K : I),
1586 IsScalarMatrixTransposed ? I : K);
1587 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1588 Sum =
1589 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1590 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1591 }
1592 Result.setVector(I,
1593 insertVector(Result.getVector(I), J, Sum, Builder));
1594 }
1595 }
1596 }
1597 Result.addNumComputeOps(NumComputeOps);
1598 }
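  // Worked example (column-major): for a 4x4 * 4x4 multiply with VF == 4, the
  // loops visit C == 4 result columns, one 4-wide row block per column and
  // M == 4 values of K, i.e. 16 multiply-accumulate steps on 4-element vectors.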
1599
1600 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1601 /// copying it to a new location. The new location, or otherwise the original
1602 /// location, is returned.
1603 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1604 CallInst *MatMul) {
1605 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1606 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1607
1608 // If we can statically determine noalias we're good.
1609 if (AA->isNoAlias(LoadLoc, StoreLoc))
1610 return Load->getPointerOperand();
1611
1612 // Create code to check if the memory locations of the Load and Store
1613 // overlap and if they do, copy Load's operand to a new buffer.
1614
1615 // First, create new blocks for the 2nd part of the check and the copy.
1616 BasicBlock *Check0 = MatMul->getParent();
1617 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1618 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
 1619 // as we adjust Check0 and Check1's branches.
 1620 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1621 for (BasicBlock *Succ : successors(Check0))
1622 DTUpdates.push_back({DT->Delete, Check0, Succ});
1623
1624 BasicBlock *Check1 =
1625 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1626 nullptr, "alias_cont");
1627 BasicBlock *Copy =
1628 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1629 nullptr, "copy");
1630 BasicBlock *Fusion =
1631 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1632 nullptr, "no_alias");
1633
1634 // Check if the loaded memory location begins before the end of the store
1635 // location. If the condition holds, they might overlap, otherwise they are
1636 // guaranteed to not overlap.
1637 IRBuilder<> Builder(MatMul);
1638 Check0->getTerminator()->eraseFromParent();
1639 Builder.SetInsertPoint(Check0);
1640 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1641 Value *StoreBegin = Builder.CreatePtrToInt(
1642 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1643 Value *StoreEnd = Builder.CreateAdd(
1644 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1645 "store.end", true, true);
1646 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1647 IntPtrTy, "load.begin");
1648 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1649 Fusion);
1650
1651 // Check if the store begins before the end of the load location. If the
1652 // condition holds, they alias, otherwise they are guaranteed to not
1653 // overlap.
1654 Check1->getTerminator()->eraseFromParent();
1655 Builder.SetInsertPoint(Check1, Check1->begin());
1656 Value *LoadEnd = Builder.CreateAdd(
1657 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1658 "load.end", true, true);
1659 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1660 Fusion);
1661
1662 // Copy load operand to new alloca.
1663 Builder.SetInsertPoint(Copy, Copy->begin());
1664 auto *VT = cast<FixedVectorType>(Load->getType());
1665 // Use an array type for the alloca, to avoid potentially huge alignment
1666 // requirements for large vector types.
1667 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1668 AllocaInst *Alloca =
1669 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1670
1671 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
1672 Load->getAlign(), LoadLoc.Size.getValue());
1673 Builder.SetInsertPoint(Fusion, Fusion->begin());
1674 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1675 PHI->addIncoming(Load->getPointerOperand(), Check0);
1676 PHI->addIncoming(Load->getPointerOperand(), Check1);
1677 PHI->addIncoming(Alloca, Copy);
1678
1679 // Adjust DT.
1680 DTUpdates.push_back({DT->Insert, Check0, Check1});
1681 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1682 DTUpdates.push_back({DT->Insert, Check1, Copy});
1683 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1684 DT->applyUpdates(DTUpdates);
1685 return PHI;
1686 }
1687
1688 bool isFusionProfitable(CallInst *MatMul) {
1689 if (ForceFusion)
1690 return true;
1691
1692 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1693 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1694
1695 const unsigned R = LShape.NumRows;
1696 const unsigned C = RShape.NumColumns;
1697 const unsigned M = LShape.NumColumns;
1698 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1699
 1700 const unsigned VF = std::max<unsigned>(
 1701 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
 1702 .getFixedValue() /
 1703 EltType->getPrimitiveSizeInBits().getFixedValue(),
 1704 1U);
1705
1706 // Cost model for tiling
1707 //
1708 // For tiling to be beneficial, we need reuse either along the R or
1709 // the C axis. We vectorize along the R axis so that means at least
1710 // 3 elements.
1711 // TODO: Also consider cost of copying if operands alias.
1712 if (R <= VF && C == 1)
1713 return false;
1714 // Then we need enough elements to exceed the number of vector
1715 // registers we have. Note that this is an oversimplification since
1716 // fusing also takes some extra loads which may exceed the number of
1717 // reloads necessary.
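 // Worked example with illustrative numbers: for VF = 4 and R = C = M = 8,
 // Op0Regs = ceil(8/4) * 8 = 16 and Op1Regs = ceil(8/4) * 8 = 16, so fusion
 // is considered profitable once 32 exceeds the target's vector register
 // count.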
1718 unsigned Op0Regs = (R + VF - 1) / VF * M;
1719 unsigned Op1Regs = (M + VF - 1) / VF * C;
 1720 return Op0Regs + Op1Regs >
 1721 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
 1722 }
1723
1724 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1725 MatrixTy Res;
1726 auto *ColumType = FixedVectorType::get(EltType, R);
1727 for (unsigned I = 0; I < C; ++I)
1728 Res.addVector(ConstantAggregateZero::get(ColumType));
1729 return Res;
1730 }
1731
1732 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1733 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1734 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1735
1736 // Create the main tiling loop nest.
1737 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1738 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1739 Instruction *InsertI = cast<Instruction>(MatMul);
1740 BasicBlock *Start = InsertI->getParent();
1741 BasicBlock *End =
1742 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1743 IRBuilder<> Builder(MatMul);
1744 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
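 // The generated nest is roughly (pseudo-code; see TileInfo::CreateTiledLoops
 // for the exact structure):
 //   for (Col = 0; Col < NumColumns; Col += TileSize)
 //     for (Row = 0; Row < NumRows; Row += TileSize)
 //       for (K = 0; K < Inner; K += TileSize)
 //         InnerBody: Res[Row..][Col..] += A[Row..][K..] * B[K..][Col..]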
1745
 1746 Type *TileVecTy =
 1747 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1748 MatrixTy TileResult;
1749 // Insert in the inner loop header.
1750 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1751 // Create PHI nodes for the result columns to accumulate across iterations.
1752 SmallVector<PHINode *, 4> ColumnPhis;
1753 for (unsigned I = 0; I < TileSize; I++) {
1754 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1755 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1756 TI.RowLoop.Header->getSingleSuccessor());
1757 TileResult.addVector(Phi);
1758 ColumnPhis.push_back(Phi);
1759 }
1760
1761 // Insert in the inner loop body, which computes
1762 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1763 Builder.SetInsertPoint(InnerBody->getTerminator());
1764 // Load tiles of the operands.
1765 MatrixTy A =
1766 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1767 {TileSize, TileSize}, EltType, Builder);
1768 MatrixTy B =
1769 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1770 {TileSize, TileSize}, EltType, Builder);
1771 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1772 getFastMathFlags(MatMul));
1773 // Store result after the inner loop is done.
1774 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1775 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1776 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1777 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1778
1779 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1780 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1781
1782 // Force unrolling of a few iterations of the inner loop, to make sure there
1783 // is enough work per iteration.
1784 // FIXME: The unroller should make this decision directly instead, but
1785 // currently the cost-model is not up to the task.
1786 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1787 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1788 "llvm.loop.unroll.count", InnerLoopUnrollCount);
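 // This attaches loop metadata along the lines of (illustrative):
 //   !{!"llvm.loop.unroll.count", i32 InnerLoopUnrollCount}
 // so the unroller later expands a few K iterations per tile.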
1789 }
1790
1791 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1792 StoreInst *Store,
1793 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1794 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1795 "Tiling only supported for column-major matrixes at the moment!");
1796 if (!isFusionProfitable(MatMul))
1797 return;
1798
1799 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1800 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1801
1802 const unsigned R = LShape.NumRows;
1803 const unsigned C = RShape.NumColumns;
1804 const unsigned M = LShape.NumColumns;
1805 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1806
1807 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1808 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1809 Value *CPtr = Store->getPointerOperand();
1810
1811 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1812 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1813 else {
1814 IRBuilder<> Builder(Store);
1815 for (unsigned J = 0; J < C; J += TileSize)
1816 for (unsigned I = 0; I < R; I += TileSize) {
1817 const unsigned TileR = std::min(R - I, unsigned(TileSize));
1818 const unsigned TileC = std::min(C - J, unsigned(TileSize));
1819 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1820
1821 for (unsigned K = 0; K < M; K += TileSize) {
1822 const unsigned TileM = std::min(M - K, unsigned(TileSize));
1823 MatrixTy A =
1824 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1825 LShape, Builder.getInt64(I), Builder.getInt64(K),
1826 {TileR, TileM}, EltType, Builder);
1827 MatrixTy B =
1828 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1829 RShape, Builder.getInt64(K), Builder.getInt64(J),
1830 {TileM, TileC}, EltType, Builder);
1831 emitMatrixMultiply(Res, A, B, Builder, true, false,
1832 getFastMathFlags(MatMul));
1833 }
1834 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1835 Builder.getInt64(I), Builder.getInt64(J), EltType,
1836 Builder);
1837 }
1838 }
1839
1840 // Mark eliminated instructions as fused and remove them.
1841 FusedInsts.insert(Store);
1842 FusedInsts.insert(MatMul);
1843 Store->eraseFromParent();
1844 MatMul->eraseFromParent();
1845 if (LoadOp0->hasNUses(0)) {
1846 FusedInsts.insert(LoadOp0);
1847 LoadOp0->eraseFromParent();
1848 }
1849 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1850 FusedInsts.insert(LoadOp1);
1851 LoadOp1->eraseFromParent();
1852 }
1853 }
1854
1855 /// Try to lower matrix multiply chains by fusing operations.
1856 ///
1857 /// Call finalizeLowering on lowered instructions. Instructions that are
1858 /// completely eliminated by fusion are added to \p FusedInsts.
1859 void LowerMatrixMultiplyFused(CallInst *MatMul,
1860 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1861 if (!FuseMatrix || !DT)
1862 return;
1863
1864 assert(AA && LI && "Analyses should be available");
1865
1866 Value *A = MatMul->getArgOperand(0);
1867 Value *B = MatMul->getArgOperand(1);
1868
1869 // We can fold the transpose into the operand that is used to fetch scalars.
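 // For example, with the default column-major layout a call like
 //   multiply(A, transpose(T))
 // can read its scalar operands directly from T's columns, so the transpose
 // of T never needs to be materialized.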
1870 Value *T;
1871 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1872 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1873 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1874 IRBuilder<> Builder(MatMul);
1875 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1876 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1877 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1878 const unsigned R = LShape.NumRows;
1879 const unsigned M = LShape.NumColumns;
1880 const unsigned C = RShape.NumColumns;
1881
1882 MatrixTy MA;
1883 MatrixTy MB;
1884
1885 Value *Transpose;
1886 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1887 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1888 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1889 Transpose = B;
1890 } else {
1891 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1892 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1893 Transpose = A;
1894 }
1895
1896 // Initialize the output
1897 MatrixTy Result(R, C, EltType);
1898
1899 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1900 getFastMathFlags(MatMul));
1901
1902 FusedInsts.insert(MatMul);
1903 if (Transpose->hasOneUse()) {
1904 FusedInsts.insert(cast<Instruction>(Transpose));
1905 ToRemove.push_back(cast<Instruction>(Transpose));
1906 // TODO: add a fake entry for the folded instruction so that this is
1907 // included in the expression in the remark.
1908 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1909 }
1910 finalizeLowering(MatMul, Result, Builder);
1911 return;
1912 }
1913
1914 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1915 return;
1916
1917 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1918 // since the single store user will be lowered as part of this.
1919 auto *LoadOp0 = dyn_cast<LoadInst>(A);
1920 auto *LoadOp1 = dyn_cast<LoadInst>(B);
1921 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1922 if (LoadOp0 && LoadOp1 && Store) {
1923 // The store address must dominate the MatMul instruction, otherwise
1924 // we create invalid IR.
1925 SetVector<Value *> WorkList;
 1926 WorkList.insert(Store->getOperand(1));
 1927 SmallVector<Instruction *> ToHoist;
1928 for (unsigned I = 0; I != WorkList.size(); ++I) {
1929 Value *Current = WorkList[I];
1930 auto *CurrI = dyn_cast<Instruction>(Current);
1931 if (!CurrI)
1932 continue;
1933 if (isa<PHINode>(CurrI))
1934 return;
1935 if (DT->dominates(CurrI, MatMul))
1936 continue;
1937 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1938 return;
1939 ToHoist.push_back(CurrI);
1940 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1941 }
1942
1943 sort(ToHoist, [this](Instruction *A, Instruction *B) {
1944 return DT->dominates(A, B);
1945 });
1946 for (Instruction *I : ToHoist)
1947 I->moveBefore(MatMul);
1948
1949 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1950 return;
1951 }
1952 }
1953
1954 /// Lowers llvm.matrix.multiply.
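 // The intrinsic roughly has the form below (see LangRef for the
 // authoritative signature), e.g. for a 2x3 * 3x2 multiply of doubles:
 //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v6f64.v6f64(
 //            <6 x double> %a, <6 x double> %b, i32 2, i32 3, i32 2)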
1955 void LowerMultiply(CallInst *MatMul) {
1956 IRBuilder<> Builder(MatMul);
1957 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1958 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1959 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1960
1961 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1962 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1963 assert(Lhs.getElementType() == Rhs.getElementType() &&
1964 "Matrix multiply argument element types do not match.");
1965
1966 const unsigned R = LShape.NumRows;
1967 const unsigned C = RShape.NumColumns;
1968 assert(LShape.NumColumns == RShape.NumRows);
1969
1970 // Initialize the output
1971 MatrixTy Result(R, C, EltType);
1972 assert(Lhs.getElementType() == Result.getElementType() &&
1973 "Matrix multiply result element type does not match arguments.");
1974
1975 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1976 getFastMathFlags(MatMul));
1977 finalizeLowering(MatMul, Result, Builder);
1978 }
1979
1980 /// Lowers llvm.matrix.transpose.
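 // The intrinsic roughly has the form below (see LangRef for the
 // authoritative signature), e.g. transposing a 2x3 matrix of floats:
 //   %t = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %in,
 //                                                      i32 2, i32 3)
 // It is lowered here into extractelement/insertelement sequences that build
 // the 3x2 result one vector at a time.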
1981 void LowerTranspose(CallInst *Inst) {
1982 MatrixTy Result;
1983 IRBuilder<> Builder(Inst);
1984 Value *InputVal = Inst->getArgOperand(0);
1985 VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1986 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1987 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1988
1989 const unsigned NewNumVecs =
1990 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1991 const unsigned NewNumElts =
1992 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1993
1994 for (unsigned I = 0; I < NewNumVecs; ++I) {
1995 // Build a single result vector. First initialize it.
1996 Value *ResultVector = PoisonValue::get(
1997 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
 1998 // Go through the old elements and insert them into the resulting vector.
1999 for (auto J : enumerate(InputMatrix.vectors())) {
2000 Value *Elt = Builder.CreateExtractElement(J.value(), I);
2001 // Row and column indices are transposed.
2002 ResultVector =
2003 Builder.CreateInsertElement(ResultVector, Elt, J.index());
2004 }
2005 Result.addVector(ResultVector);
2006 }
2007
2008 // TODO: Improve estimate of operations needed for transposes. Currently we
2009 // just count the insertelement/extractelement instructions, but do not
2010 // account for later simplifications/combines.
2011 finalizeLowering(
2012 Inst,
2013 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2014 .addNumExposedTransposes(1),
2015 Builder);
2016 }
2017
2018 /// Lower load instructions, if shape information is available.
2019 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
2020 auto I = ShapeMap.find(Inst);
2021 if (I == ShapeMap.end())
2022 return false;
2023
2024 LowerLoad(Inst, Ptr, Inst->getAlign(),
2025 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2026 I->second);
2027 return true;
2028 }
2029
2030 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
2031 IRBuilder<> &Builder) {
2032 auto I = ShapeMap.find(StoredVal);
2033 if (I == ShapeMap.end())
2034 return false;
2035
2036 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2037 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2038 I->second);
2039 return true;
2040 }
2041
2042 /// Lower binary operators, if shape information is available.
2043 bool VisitBinaryOperator(BinaryOperator *Inst) {
2044 auto I = ShapeMap.find(Inst);
2045 if (I == ShapeMap.end())
2046 return false;
2047
2048 Value *Lhs = Inst->getOperand(0);
2049 Value *Rhs = Inst->getOperand(1);
2050
2051 IRBuilder<> Builder(Inst);
2052 ShapeInfo &Shape = I->second;
2053
2054 MatrixTy Result;
2055 MatrixTy A = getMatrix(Lhs, Shape, Builder);
2056 MatrixTy B = getMatrix(Rhs, Shape, Builder);
2057 assert(A.isColumnMajor() == B.isColumnMajor() &&
2058 Result.isColumnMajor() == A.isColumnMajor() &&
2059 "operands must agree on matrix layout");
2060
2061 Builder.setFastMathFlags(getFastMathFlags(Inst));
2062
2063 // Helper to perform binary op on vectors.
2064 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2065 switch (Inst->getOpcode()) {
2066 case Instruction::Add:
2067 return Builder.CreateAdd(LHS, RHS);
2068 case Instruction::Mul:
2069 return Builder.CreateMul(LHS, RHS);
2070 case Instruction::Sub:
2071 return Builder.CreateSub(LHS, RHS);
2072 case Instruction::FAdd:
2073 return Builder.CreateFAdd(LHS, RHS);
2074 case Instruction::FMul:
2075 return Builder.CreateFMul(LHS, RHS);
2076 case Instruction::FSub:
2077 return Builder.CreateFSub(LHS, RHS);
2078 default:
2079 llvm_unreachable("Unsupported binary operator for matrix");
2080 }
2081 };
2082
2083 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2084 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
2085
2086 finalizeLowering(Inst,
2087 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2088 Result.getNumVectors()),
2089 Builder);
2090 return true;
2091 }
2092
2093 /// Lower unary operators, if shape information is available.
2094 bool VisitUnaryOperator(UnaryOperator *Inst) {
2095 auto I = ShapeMap.find(Inst);
2096 if (I == ShapeMap.end())
2097 return false;
2098
2099 Value *Op = Inst->getOperand(0);
2100
2101 IRBuilder<> Builder(Inst);
2102 ShapeInfo &Shape = I->second;
2103
2104 MatrixTy Result;
2105 MatrixTy M = getMatrix(Op, Shape, Builder);
2106
2107 Builder.setFastMathFlags(getFastMathFlags(Inst));
2108
2109 // Helper to perform unary op on vectors.
2110 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2111 switch (Inst->getOpcode()) {
2112 case Instruction::FNeg:
2113 return Builder.CreateFNeg(Op);
2114 default:
2115 llvm_unreachable("Unsupported unary operator for matrix");
2116 }
2117 };
2118
2119 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2120 Result.addVector(BuildVectorOp(M.getVector(I)));
2121
2122 finalizeLowering(Inst,
2123 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2124 Result.getNumVectors()),
2125 Builder);
2126 return true;
2127 }
2128
2129 /// Helper to linearize a matrix expression tree into a string. Currently
 2130 /// matrix expressions are linearized by starting at an expression leaf and
2131 /// linearizing bottom up.
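 // The produced string looks roughly like the following (illustrative; the
 // exact output depends on the expression being linearized):
 //   store(
 //    multiply.2x6.6x2.double(
 //     load(addr %A),
 //     load(addr %B)),
 //    addr %C)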
2132 struct ExprLinearizer {
2133 unsigned LengthToBreak = 100;
2134 std::string Str;
2135 raw_string_ostream Stream;
2136 unsigned LineLength = 0;
2137 const DataLayout &DL;
2138
2139 /// Mapping from instructions to matrixes. It is used to identify
2140 /// matrix instructions.
2141 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2142
2143 /// Mapping from values to the leaves of all expressions that the value is
2144 /// part of.
 2145 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
 2146
2147 /// Set of matrix expressions in the scope of a given DISubprogram.
2148 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2149
2150 /// Leaf node of the expression to linearize.
2151 Value *Leaf;
2152
2153 /// Used to keep track of sub-expressions that get reused while linearizing
2154 /// the expression. Re-used sub-expressions are marked as (reused).
2155 SmallPtrSet<Value *, 8> ReusedExprs;
2156
2157 ExprLinearizer(const DataLayout &DL,
2158 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2159 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2160 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2161 Value *Leaf)
2162 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2163 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2164
2165 void indent(unsigned N) {
2166 LineLength += N;
2167 for (unsigned i = 0; i < N; i++)
2168 Stream << " ";
2169 }
2170
2171 void lineBreak() {
2172 Stream << "\n";
2173 LineLength = 0;
2174 }
2175
2176 void maybeIndent(unsigned Indent) {
2177 if (LineLength >= LengthToBreak)
2178 lineBreak();
2179
2180 if (LineLength == 0)
2181 indent(Indent);
2182 }
2183
2184 void write(StringRef S) {
2185 LineLength += S.size();
2186 Stream << S;
2187 }
2188
2189 Value *getUnderlyingObjectThroughLoads(Value *V) {
2190 if (Value *Ptr = getPointerOperand(V))
2191 return getUnderlyingObjectThroughLoads(Ptr);
2192 else if (V->getType()->isPointerTy())
2193 return getUnderlyingObject(V);
2194 return V;
2195 }
2196
2197 /// Returns true if \p V is a matrix value in the given subprogram.
2198 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2199
2200 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2201 /// \p SS.
2202 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2203 auto M = Inst2Matrix.find(V);
2204 if (M == Inst2Matrix.end())
2205 SS << "unknown";
2206 else {
2207 SS << M->second.getNumRows();
2208 SS << "x";
2209 SS << M->second.getNumColumns();
2210 }
2211 }
2212
2213 /// Write the called function name. Handles calls to llvm.matrix.*
2214 /// specially: we write the name, followed by the dimensions of the input
2215 /// matrixes, followed by the scalar type name.
2216 void writeFnName(CallInst *CI) {
2217 if (!CI->getCalledFunction())
2218 write("<no called fn>");
 2219 else {
 2220 StringRef Name = CI->getCalledFunction()->getName();
2221 if (!Name.starts_with("llvm.matrix")) {
2222 write(Name);
2223 return;
2224 }
 2225 auto *II = cast<IntrinsicInst>(CI);
 2226 write(Intrinsic::getBaseName(II->getIntrinsicID())
 2227 .drop_front(StringRef("llvm.matrix.").size()));
2228 write(".");
 2229 std::string Tmp;
 2230 raw_string_ostream SS(Tmp);
 2231
2232 switch (II->getIntrinsicID()) {
2233 case Intrinsic::matrix_multiply:
2234 prettyPrintMatrixType(II->getOperand(0), SS);
2235 SS << ".";
2236 prettyPrintMatrixType(II->getOperand(1), SS);
2237 SS << "." << *II->getType()->getScalarType();
2238 break;
2239 case Intrinsic::matrix_transpose:
2240 prettyPrintMatrixType(II->getOperand(0), SS);
2241 SS << "." << *II->getType()->getScalarType();
2242 break;
2243 case Intrinsic::matrix_column_major_load:
2244 prettyPrintMatrixType(II, SS);
2245 SS << "." << *II->getType()->getScalarType();
2246 break;
2247 case Intrinsic::matrix_column_major_store:
2248 prettyPrintMatrixType(II->getOperand(0), SS);
2249 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2250 break;
2251 default:
2252 llvm_unreachable("Unhandled case");
2253 }
2254 SS.flush();
2255 write(Tmp);
2256 }
2257 }
2258
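 // The "shape arguments" are the trailing i32 dimension operands of the
 // matrix intrinsics (e.g. the Rows/Inner/Columns triple of
 // llvm.matrix.multiply); they are dropped when printing operand lists below.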
2259 unsigned getNumShapeArgs(CallInst *CI) const {
2260 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2261 switch (II->getIntrinsicID()) {
2262 case Intrinsic::matrix_multiply:
2263 return 3;
2264 case Intrinsic::matrix_transpose:
2265 return 2;
2266 case Intrinsic::matrix_column_major_load:
2267 case Intrinsic::matrix_column_major_store:
2268 return 3;
2269 default:
2270 return 0;
2271 }
2272 }
2273 return 0;
2274 }
2275
 2276 /// Special printing for values: for pointers, we print whether they refer to
 2277 /// an external (function) address or a stack address; for other values we
 2278 /// either print the constant or "scalar"/"matrix".
2279 void write(Value *V) {
2280 V = getUnderlyingObjectThroughLoads(V);
2281 if (V->getType()->isPointerTy()) {
2282 if (isa<AllocaInst>(V)) {
2283 Stream << "stack addr";
2284 LineLength += StringRef("stack addr").size();
2285 } else {
2286 Stream << "addr";
2287 LineLength += StringRef("addr").size();
2288 }
2289 if (!V->getName().empty()) {
2290 Stream << " %" << V->getName() << "";
2291 LineLength += V->getName().size() + 2;
2292 }
2293 return;
2294 }
2295
2296 std::string Tmp;
2297 raw_string_ostream TmpStream(Tmp);
2298
2299 if (auto *CI = dyn_cast<ConstantInt>(V))
2300 TmpStream << CI->getValue();
2301 else if (isa<Constant>(V))
2302 TmpStream << "constant";
2303 else {
2304 if (isMatrix(V))
2305 TmpStream << "matrix";
2306 else
2307 TmpStream << "scalar";
2308 }
2309 TmpStream.flush();
2310 Tmp = std::string(StringRef(Tmp).trim());
2311 LineLength += Tmp.size();
2312 Stream << Tmp;
2313 }
2314
2315 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2316 /// Expressions that are re-used multiple times are prefixed with (reused)
2317 /// at the re-used root instruction.
2318 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2319 bool ParentShared) {
2320 auto *I = cast<Instruction>(Expr);
2321 maybeIndent(Indent);
 2322 SmallVector<Value *, 8> Ops;
 2323
2324 // Is Expr shared with other expression leaves?
2325 bool ExprShared = false;
2326
2327 // Deal with shared subtrees. Mark them as shared, if required.
2328 if (!ParentShared) {
2329 auto SI = Shared.find(Expr);
2330 assert(SI != Shared.end() && SI->second.count(Leaf));
2331
2332 for (Value *S : SI->second) {
2333 if (S == Leaf)
2334 continue;
2335 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2336 write("shared with remark at line " + std::to_string(DL.getLine()) +
2337 " column " + std::to_string(DL.getCol()) + " (");
2338 }
2339 ExprShared = SI->second.size() > 1;
2340 }
2341
2342 bool Reused = !ReusedExprs.insert(Expr).second;
2343 if (Reused && !ParentReused)
2344 write("(reused) ");
2345
2346 if (auto *CI = dyn_cast<CallInst>(I)) {
2347 writeFnName(CI);
2348
2349 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2350 } else if (isa<BitCastInst>(Expr)) {
2351 // Special case bitcasts, which are used to materialize matrixes from
2352 // non-matrix ops.
2353 write("matrix");
2354 return;
2355 } else {
2356 Ops.append(I->value_op_begin(), I->value_op_end());
2357 write(std::string(I->getOpcodeName()));
2358 }
2359
2360 write(std::string("("));
2361
2362 unsigned NumOpsToBreak = 1;
2363 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2364 NumOpsToBreak = 2;
2365
2366 for (Value *Op : Ops) {
2367 if (Ops.size() > NumOpsToBreak)
2368 lineBreak();
2369
2370 maybeIndent(Indent + 1);
2371 if (isMatrix(Op))
2372 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2373 else
2374 write(Op);
2375 if (Op != Ops.back())
2376 write(", ");
2377 }
2378
2379 write(")");
2380 }
2381
2382 const std::string &getResult() {
2383 Stream.flush();
2384 return Str;
2385 }
2386 };
2387
2388 /// Generate remarks for matrix operations in a function. To generate remarks
2389 /// for matrix expressions, the following approach is used:
2390 /// 1. Use the inlined-at debug information to group matrix operations to the
2391 /// DISubprograms they are contained in.
2392 /// 2. Collect leaves of matrix expressions (done in
 2393 /// RemarkGenerator::getExpressionLeaves) for each subprogram-expression
 2394 /// mapping. Leaves are lowered matrix instructions without other matrix
 2395 /// users (like stores) in the current subprogram.
 2396 /// 3. For each leaf, create a remark containing a linearized version of the
2397 /// matrix expression. The expression is linearized by a recursive
2398 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2399 /// that multiple leaves can share sub-expressions. Shared subexpressions
2400 /// are explicitly marked as shared().
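 // A generated remark looks roughly like (illustrative):
 //   remark: foo.c:12:3: Lowered with 4 stores, 4 loads, 16 compute ops,
 //   0 exposed transposes
 // followed by the linearized expression produced by ExprLinearizer above.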
2401 struct RemarkGenerator {
 2402 const MapVector<Value *, MatrixTy> &Inst2Matrix;
 2403 OptimizationRemarkEmitter &ORE;
2404 Function &Func;
2405 const DataLayout &DL;
2406
 2407 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
 2408 OptimizationRemarkEmitter &ORE, Function &Func)
2409 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2410 DL(Func.getParent()->getDataLayout()) {}
2411
2412 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2413 /// instructions in Inst2Matrix returning void or without any users in
2414 /// \p ExprsInSubprogram. Currently that should only include stores.
 2415 SmallVector<Value *, 4>
 2416 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
 2417 SmallVector<Value *, 4> Leaves;
2418 for (auto *Expr : ExprsInSubprogram)
2419 if (Expr->getType()->isVoidTy() ||
2420 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2421 return ExprsInSubprogram.count(U);
2422 }))
2423 Leaves.push_back(Expr);
2424 return Leaves;
2425 }
2426
2427 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2428 /// to all visited expressions in \p Shared. Limit the matrix operations to
2429 /// the ones in \p ExprsInSubprogram.
2430 void collectSharedInfo(Value *Leaf, Value *V,
2431 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
 2432 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
 2433
2434 if (!ExprsInSubprogram.count(V))
2435 return;
2436
2437 auto I = Shared.insert({V, {}});
2438 I.first->second.insert(Leaf);
2439
2440 for (Value *Op : cast<Instruction>(V)->operand_values())
2441 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2442 }
2443
2444 /// Calculate the number of exclusive and shared op counts for expression
2445 /// starting at \p V. Expressions used multiple times are counted once.
2446 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2447 std::pair<OpInfoTy, OpInfoTy>
2448 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2449 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2450 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2451 if (!ExprsInSubprogram.count(Root))
2452 return {};
2453
2454 // Already counted this expression. Stop.
2455 if (!ReusedExprs.insert(Root).second)
2456 return {};
2457
2458 OpInfoTy SharedCount;
2459 OpInfoTy Count;
2460
2461 auto I = Shared.find(Root);
2462 auto CM = Inst2Matrix.find(Root);
2463 if (I->second.size() == 1)
2464 Count = CM->second.getOpInfo();
2465 else
2466 SharedCount = CM->second.getOpInfo();
2467
2468 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2469 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2470 Count += C.first;
2471 SharedCount += C.second;
2472 }
2473 return {Count, SharedCount};
2474 }
2475
 2476 void emitRemarks() {
 2477 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
 2478 return;
2479
 2480 // Map matrix operations to their containing subprograms, by traversing
2481 // the inlinedAt chain. If the function does not have a DISubprogram, we
 2482 // only map them to the containing function.
 2483 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2484 for (const auto &KV : Inst2Matrix) {
2485 if (Func.getSubprogram()) {
2486 auto *I = cast<Instruction>(KV.first);
2487 DILocation *Context = I->getDebugLoc();
2488 while (Context) {
2489 auto I =
2490 Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2491 I.first->second.push_back(KV.first);
2492 Context = DebugLoc(Context).getInlinedAt();
2493 }
2494 } else {
2495 auto I = Subprog2Exprs.insert({nullptr, {}});
2496 I.first->second.push_back(KV.first);
2497 }
2498 }
2499 for (auto &KV : Subprog2Exprs) {
2500 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2501 KV.second.end());
2502 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
 2503
 2504 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2505 for (Value *Leaf : Leaves)
2506 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2507
2508 // Generate remarks for each leaf.
2509 for (auto *L : Leaves) {
2510
2511 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2512 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2513 while (Context) {
2514 if (getSubprogram(Context->getScope()) == KV.first) {
2515 Loc = Context;
2516 break;
2517 }
2518 Context = DebugLoc(Context).getInlinedAt();
2519 }
2520
2521 SmallPtrSet<Value *, 8> ReusedExprs;
2522 OpInfoTy Counts, SharedCounts;
2523 std::tie(Counts, SharedCounts) =
2524 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2525
2526 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2527 cast<Instruction>(L)->getParent());
2528
2529 Rem << "Lowered with ";
2530 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2531 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2532 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2533 << " compute ops, "
2534 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2535 << " exposed transposes";
2536
2537 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2538 SharedCounts.NumComputeOps > 0) {
2539 Rem << ",\nadditionally "
2540 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2541 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2542 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2543 << " compute ops"
2544 << " are shared with other expressions";
2545 }
2546
2547 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2548 ORE.emit(Rem);
2549 }
2550 }
2551 }
2552
2553 std::string
2554 linearize(Value *L,
2555 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2556 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2557 const DataLayout &DL) {
2558 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2559 Lin.linearizeExpr(L, 0, false, false);
2560 return Lin.getResult();
2561 }
2562 };
2563};
2564} // namespace
 2565
 2566 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
 2567 FunctionAnalysisManager &AM) {
2568 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2569 OptimizationRemarkEmitter *ORE = nullptr;
2570 AAResults *AA = nullptr;
2571 DominatorTree *DT = nullptr;
2572 LoopInfo *LI = nullptr;
2573
 2574 if (!Minimal) {
 2575 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 2576 AA = &AM.getResult<AAManager>(F);
 2577 DT = &AM.getResult<DominatorTreeAnalysis>(F);
 2578 LI = &AM.getResult<LoopAnalysis>(F);
2579 }
2580
2581 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
 2582 if (LMT.Visit()) {
 2583 PreservedAnalyses PA;
2584 if (!Minimal) {
 2585 PA.preserve<LoopAnalysis>();
 2586 PA.preserve<DominatorTreeAnalysis>();
 2587 }
2588 return PA;
2589 }
2590 return PreservedAnalyses::all();
2591}
 2592
 2593 void LowerMatrixIntrinsicsPass::printPipeline(
 2594 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
 2595 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
 2596 OS, MapClassName2PassName);
2597 OS << '<';
2598 if (Minimal)
2599 OS << "minimal";
2600 OS << '>';
2601}
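// Usage note: the pass is exposed to the new pass manager as
// "lower-matrix-intrinsics" (matching DEBUG_TYPE above), so it can be run in
// isolation with, e.g., `opt -passes=lower-matrix-intrinsics -S in.ll`;
// printPipeline appends "<minimal>" when the minimal variant is configured.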
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Rewrite undef for PHI
ReachingDefAnalysis InstSet & ToRemove
static const Function * getParent(const Value *V)
BitTracker BT
Definition: BitTracker.cpp:73
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition: CommandLine.h:693
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
std::string Name
bool End
Definition: ELF_riscv.cpp:480
Hexagon Common GEP
hexagon Hexagon specific predictive commoning for HVX vectors
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
iv users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:531
Live Register Matrix
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, BasicBlock &BB)
Erase V from BB and move \II forward to avoid invalidating iterators.
static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
auto m_AnyAdd(const LTy &L, const RTy &R)
Match any add operation (fp or integer).
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
#define DEBUG_TYPE
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define T1
LLVMContext & Context
PowerPC Reduce CR logical Operation
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static unsigned getFastMathFlags(const MachineInstr &I)
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition: SROA.cpp:2551
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition: SROA.cpp:2573
raw_pwrite_stream & OS
This file defines the SmallSet class.
This file defines the SmallVector class.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
static const int BlockSize
Definition: TarWriter.cpp:33
This pass exposes codegen information to IR-level passes.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Value * RHS
Value * LHS
BinaryOperator * Mul
A manager for alias analyses.
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
an instruction to allocate memory on the stack
Definition: Instructions.h:59
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition: Instructions.h:132
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:348
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:500
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:429
reverse_iterator rbegin()
Definition: BasicBlock.h:445
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:166
reverse_iterator rend()
Definition: BasicBlock.h:447
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:220
BinaryOps getOpcode() const
Definition: InstrTypes.h:491
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1703
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1623
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:2053
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1648
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1629
This class represents a function call, abstracting a target machine's calling convention.
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1663
This is the shared class of boolean and integer constants.
Definition: Constants.h:79
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Debug location.
Base class for scope-like contexts.
Subprogram description.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
A debug info location.
Definition: DebugLoc.h:33
DILocation * getInlinedAt() const
Definition: DebugLoc.cpp:39
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:275
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
static constexpr UpdateKind Delete
static constexpr UpdateKind Insert
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
Convenience struct for specifying and reasoning about fast-math flags.
Definition: FMF.h:20
void setAllowContract(bool B=true)
Definition: FMF.h:91
bool allowReassoc() const
Flag queries.
Definition: FMF.h:65
bool allowContract() const
Definition: FMF.h:70
Class to represent fixed width SIMD vectors.
Definition: DerivedTypes.h:539
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:692
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition: Function.h:230
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition: Function.h:235
CallInst * CreateFAddReduce(Value *Acc, Value *Src)
Create a sequential vector fadd reduction intrinsic of the source vector.
Definition: IRBuilder.cpp:417
Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:2240
Value * CreateFSub(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1554
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2455
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1772
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2443
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition: IRBuilder.h:1806
Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1527
Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
Definition: IRBuilder.cpp:1212
CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
Definition: IRBuilder.cpp:433
IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)
Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...
Definition: IRBuilder.h:569
void setFastMathFlags(FastMathFlags NewFMF)
Set the fast-math flags to be used with generated fp-math operators.
Definition: IRBuilder.h:305
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition: IRBuilder.h:485
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition: IRBuilder.h:2380
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1338
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
Definition: IRBuilder.h:491
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition: IRBuilder.h:1114
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1789
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition: IRBuilder.h:2477
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1321
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2100
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:180
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
Definition: IRBuilder.h:1825
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args=std::nullopt, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:2395
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", bool IsInBounds=false)
Definition: IRBuilder.h:1865
Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1581
Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:1729
CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, MDNode *TBAATag=nullptr, MDNode *TBAAStructTag=nullptr, MDNode *ScopeTag=nullptr, MDNode *NoAliasTag=nullptr)
Create and insert a memcpy between the specified pointers.
Definition: IRBuilder.h:653
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1355
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2649
static InstructionCost getInvalid(CostType Val=0)
void setFastMathFlags(FastMathFlags FMF)
Convenience function for setting multiple fast-math flags on this instruction, which must be an opera...
const BasicBlock * getParent() const
Definition: Instruction.h:150
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:47
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:54
An instruction for reading from memory.
Definition: Instructions.h:184
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition: Instructions.h:230
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:236
TypeSize getValue() const
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:36
iterator end()
Definition: MapVector.h:71
iterator find(const KeyT &Key)
Definition: MapVector.h:167
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: MapVector.h:141
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
Representation for a specific memory location.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
const Value * Ptr
The address of the start of the location.
The optimization diagnostic interface.
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1827
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:109
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:115
void preserve()
Mark an analysis as preserved.
Definition: Analysis.h:129
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:321
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:360
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:342
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:366
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:427
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:135
bool empty() const
Definition: SmallSet.h:159
bool erase(const T &V)
Definition: SmallSet.h:207
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
An instruction for storing to memory.
Definition: Instructions.h:317
Align getAlign() const
Definition: Instructions.h:369
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:361
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition: StringRef.h:605
constexpr size_t size() const
size - Get the string size.
Definition: StringRef.h:137
Analysis pass providing the TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
TypeSize getRegisterBitWidth(RegisterKind K) const
InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
@ TCK_RecipThroughput
Reciprocal throughput.
InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args=ArrayRef< const Value * >(), const Instruction *CxtI=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *Tp, ArrayRef< int > Mask=std::nullopt, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args=std::nullopt) const
unsigned getNumberOfRegisters(unsigned ClassID) const
@ SK_Splice
Concatenates elements from the first input vector with elements of the second input vector.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:140
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:348
UnaryOps getOpcode() const
Definition: InstrTypes.h:203
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
Value * getOperand(unsigned i) const
Definition: User.h:169
See the file comment.
Definition: ValueMap.h:84
size_type count(const KeyT &Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: ValueMap.h:151
iterator find(const KeyT &Val)
Definition: ValueMap.h:155
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: ValueMap.h:172
iterator end()
Definition: ValueMap.h:135
bool erase(const KeyT &Val)
Definition: ValueMap.h:190
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
user_iterator user_begin()
Definition: Value.h:397
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:149
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
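Hedged sketch of the use-tracking APIs above: replace one value with another and report its users. The helper itself is hypothetical, not taken from the pass.

#include "llvm/IR/Instruction.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

static bool tryReplace(Value *Old, Value *New) {
  if (Old == New || Old->getType() != New->getType())
    return false; // replaceAllUsesWith requires identical types

  if (Old->hasOneUse()) {
    // With exactly one use, user_begin() is the only user.
    errs() << "single user: " << **Old->user_begin() << "\n";
  } else if (Old->hasNUses(0)) {
    return false; // dead already, nothing to rewrite
  }

  for (const Use &U : Old->uses())
    if (auto *I = dyn_cast<Instruction>(U.getUser()))
      errs() << "rewriting operand " << U.getOperandNo() << " of "
             << I->getName() << "\n";

  Old->replaceAllUsesWith(New); // all of the uses above now refer to New
  return true;
}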
Type * getElementType() const
Definition: DerivedTypes.h:436
constexpr ScalarTy getFixedValue() const
Definition: TypeSize.h:187
An efficient, type-erasing, non-owning reference to a callable.
A range adaptor for a pair of iterators.
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:660
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:121
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition: Function.cpp:1008
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1447
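Sketch of materializing an intrinsic declaration and calling it. llvm.fmuladd is overloaded on a single floating-point type, so one type suffices in the overload list; the wrapper itself is illustrative only.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"

using namespace llvm;

static Value *emitFMulAdd(IRBuilder<> &Builder, Value *A, Value *B, Value *Acc) {
  Module *M = Builder.GetInsertBlock()->getModule();
  // Intrinsic::getBaseName(Intrinsic::fmuladd) is "llvm.fmuladd"; the mangled
  // declaration created below carries a type suffix, e.g. llvm.fmuladd.f32.
  Function *FMulAdd =
      Intrinsic::getDeclaration(M, Intrinsic::fmuladd, {A->getType()});
  return Builder.CreateCall(FMulAdd, {A, B, Acc});
}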
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
specific_intval< false > m_SpecificInt(APInt V)
Match a specific integer value or vector with all elements equal to the value.
Definition: PatternMatch.h:862
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
Definition: PatternMatch.h:982
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:84
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:147
BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)
Definition: PatternMatch.h:988
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
OneUse_match< T > m_OneUse(const T &SubPattern)
Definition: PatternMatch.h:67
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:218
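Illustrative use of the PatternMatch helpers above (not code from the pass): recognize a floating-point multiply-add chain and a store fed by a load.

#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// Matches V == fadd(fmul(A, B), C) where the multiply has a single use, so a
// rewrite would not duplicate it.
static bool matchFPMulAdd(Value *V, Value *&A, Value *&B, Value *&C) {
  return match(V, m_FAdd(m_OneUse(m_FMul(m_Value(A), m_Value(B))),
                         m_Value(C)));
}

// Matches a store whose stored value is a single-use load; Ptr is bound to the
// store's pointer operand.
static bool matchStoreOfLoad(Value *V, Value *&Ptr) {
  return match(V, m_Store(m_OneUse(m_Load(m_Value())), m_Value(Ptr)));
}

Where both the integer and floating-point forms should be accepted, m_CombineOr combines the two matchers, and m_ConstantInt/m_SpecificInt constrain an operand to a (particular) constant integer.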
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
Definition: CommandLine.h:718
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:450
DiagnosticInfoOptimizationBase::Argument NV
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1689
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A,...
Definition: STLExtras.h:2386
auto successors(const MachineBasicBlock *BB)
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2036
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
std::string & operator+=(std::string &buffer, StringRef string)
Definition: StringRef.h:899
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments and pointer casts from the specified value,...
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:665
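A small sketch (helpers are hypothetical) of the range utilities above: make_early_inc_range permits erasing while iterating, any_of wraps std::any_of over a range, and enumerate pairs elements with their index.

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

static void dropDeadNonTerminators(BasicBlock &BB) {
  for (Instruction &I : make_early_inc_range(BB))
    if (I.use_empty() && !I.isTerminator() && !I.mayHaveSideEffects())
      I.eraseFromParent(); // safe: the iterator has already moved past I
}

static bool hasCallUser(const Value *V) {
  return any_of(V->users(), [](const User *U) { return isa<CallBase>(U); });
}

static void printBlockIndices(const Function &F) {
  for (const auto &En : enumerate(F))
    errs() << En.index() << ": " << En.value().getName() << "\n";
}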
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
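Sketch of a conservative no-alias pre-check in the spirit of the pointer helpers above; the helper name and the alloca-only rule are assumptions for illustration, not the pass's alias reasoning.

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

static bool basedOnDistinctAllocas(const Instruction *AccessA,
                                   const Instruction *AccessB) {
  const Value *PtrA = getPointerOperand(AccessA); // load/store/GEP pointer
  const Value *PtrB = getPointerOperand(AccessB);
  if (!PtrA || !PtrB)
    return false;

  // Strip GEPs and casts down to the underlying objects.
  const Value *ObjA = getUnderlyingObject(PtrA);
  const Value *ObjB = getUnderlyingObject(PtrB);

  // Two distinct allocas cannot overlap; anything else needs real alias analysis.
  return ObjA != ObjB && isa<AllocaInst>(ObjA) && isa<AllocaInst>(ObjB);
}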
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
Definition: LoopUtils.cpp:214
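Sketch of tagging a loop after transformation; the metadata name below is chosen purely for illustration.

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;

static void markTiled(Loop *L) {
  // Adds an MDString/value pair to the loop's llvm.loop metadata while
  // preserving whatever metadata is already attached. The name is illustrative.
  addStringMetadataToLoop(L, "llvm.loop.matrix.tiled", 1);
}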
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1738
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:428
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1656
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:156
raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
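Sketch of the diagnostic facilities above: LLVM_DEBUG(dbgs() << ...) traces only under -debug/-debug-only in builds with assertions, errs() always writes to stderr, and report_fatal_error aborts after running installed handlers. The helper and the messages are illustrative.

#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "lower-matrix-intrinsics"

using namespace llvm;

static void checkShape(unsigned Rows, unsigned Cols) {
  LLVM_DEBUG(dbgs() << "shape: " << Rows << " x " << Cols << "\n");
  if (Rows == 0 || Cols == 0) {
    errs() << "warning: degenerate matrix shape\n";
    report_fatal_error("matrix operand with zero dimension");
  }
}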
@ FMulAdd
Sum of float products, accumulated via llvm.fmuladd (i.e. a * b + sum).
@ Add
Sum of integers.
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition: Casting.h:565
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
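Sketch of splitting a block at an instruction, e.g. to carve out a region that will be guarded by a runtime check. The wrapper is hypothetical; SplitBlock itself updates the DominatorTree/LoopInfo it is given.

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

static BasicBlock *splitBefore(Instruction *I, DominatorTree *DT,
                               LoopInfo *LI) {
  BasicBlock *BB = I->getParent();
  // Everything from I (inclusive) to the terminator moves to the new block.
  return SplitBlock(BB, I->getIterator(), DT, LI, /*MSSAU=*/nullptr,
                    BB->getName() + ".split");
}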
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:212
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
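A column-major sketch using the vector helpers above: slice one column out of a flattened matrix with a sequential shuffle mask, compute the column's alignment from the matrix alignment and its byte offset, and stitch columns back together. Helper names are illustrative only.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Alignment.h"

using namespace llvm;

static Value *extractColumn(IRBuilderBase &Builder, Value *Flat,
                            unsigned NumRows, unsigned Col) {
  // Selects elements [Col * NumRows, Col * NumRows + NumRows) of Flat.
  SmallVector<int, 16> Mask =
      createSequentialMask(Col * NumRows, NumRows, /*NumUndefs=*/0);
  return Builder.CreateShuffleVector(Flat, Mask);
}

static Align columnAlignment(Align MatrixAlign, uint64_t ColByteOffset) {
  // The best provable alignment of an interior column is the alignment shared
  // by the matrix's alignment and the column's byte offset.
  return commonAlignment(MatrixAlign, ColByteOffset);
}

static Value *rebuildMatrix(IRBuilderBase &Builder, ArrayRef<Value *> Columns) {
  // Concatenate the per-column vectors back into one flattened vector.
  return concatenateVectors(Builder, Columns);
}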
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
A CRTP mix-in to automatically provide informational APIs needed for passes.
Definition: PassManager.h:91
A helper struct to create IR loop nests for tiling, of the following form: for ColumnLoop....
Definition: MatrixUtils.h:31