LLVM 17.0.0git
LowerMatrixIntrinsics.cpp
1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
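// As an illustrative sketch (example IR, not taken from this file): a call like
//   %t = call <4 x float> @llvm.matrix.transpose.v4f32(<4 x float> %a,
//                                                      i32 2, i32 2)
// is rewritten into shufflevector instructions operating on the <2 x float>
// column vectors of %a (assuming the default column-major layout).
//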
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different number of rows/columns
16//    for tiles, consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
19
30#include "llvm/IR/CFG.h"
31#include "llvm/IR/DataLayout.h"
33#include "llvm/IR/Function.h"
34#include "llvm/IR/IRBuilder.h"
40#include "llvm/Pass.h"
43#include "llvm/Support/Debug.h"
48
49#include <cmath>
50
51using namespace llvm;
52using namespace PatternMatch;
53
54#define DEBUG_TYPE "lower-matrix-intrinsics"
55
56static cl::opt<bool>
57 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
58 cl::desc("Enable/disable fusing matrix instructions."));
59// TODO: Allow and use non-square tiles.
60static cl::opt<unsigned> TileSize(
61    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
62    cl::desc(
63        "Tile size for matrix instruction fusion using square-shaped tiles."));
64static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
65                                  cl::Hidden,
66                                  cl::desc("Generate loop nest for tiling."));
67static cl::opt<bool> ForceFusion(
68    "force-fuse-matrix", cl::init(false), cl::Hidden,
69 cl::desc("Force matrix instruction fusion even if not profitable."));
70static cl::opt<bool> AllowContractEnabled(
71    "matrix-allow-contract", cl::init(false), cl::Hidden,
72 cl::desc("Allow the use of FMAs if available and profitable. This may "
73 "result in different results, due to less rounding error."));
74
75enum class MatrixLayoutTy { ColumnMajor, RowMajor };
76
77static cl::opt<MatrixLayoutTy> MatrixLayout(
78    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
79    cl::desc("Sets the default matrix layout"),
80    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
81                          "Use column-major layout"),
82               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
83                          "Use row-major layout")));
84
85static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
86 cl::init(false));
87
88/// Helper function to either return \p Scope, if it is a subprogram, or the
89/// attached subprogram for a local scope.
90static DISubprogram *getSubprogram(DIScope *Scope) {
91  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
92 return Subprogram;
93 return cast<DILocalScope>(Scope)->getSubprogram();
94}
95
96/// Erase \p V from \p BB and move \p II forward to avoid invalidating
97/// iterators.
98static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
99                                   BasicBlock &BB) {
100 auto *Inst = cast<Instruction>(V);
101 // Still used, don't erase.
102 if (!Inst->use_empty())
103 return;
104 if (II != BB.rend() && Inst == &*II)
105 ++II;
106 Inst->eraseFromParent();
107}
108
109/// Return true if V is a splat of a value (which is used when multiplying a
110/// matrix with a scalar).
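/// As an illustrative sketch (example IR, not taken from this file), the
/// canonical scalar broadcast
///   %ins = insertelement <4 x float> poison, float %s, i64 0
///   %splat = shufflevector <4 x float> %ins, <4 x float> poison,
///                          <4 x i32> zeroinitializer
/// is matched here via ShuffleVectorInst::isZeroEltSplat().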
111static bool isSplat(Value *V) {
112 if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
113 return SV->isZeroEltSplat();
114 return false;
115}
116
117/// Match any mul operation (fp or integer).
118template <typename LTy, typename RTy>
119auto m_AnyMul(const LTy &L, const RTy &R) {
120 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
121}
122
123/// Match any add operation (fp or integer).
124template <typename LTy, typename RTy>
125auto m_AnyAdd(const LTy &L, const RTy &R) {
126 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
127}
128
129namespace {
130
131// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
132// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
133// assuming \p Stride elements between the starts of two consecutive vectors.
134// \p Stride must be >= \p NumElements.
135// For column-major matrixes, the function computes the address of a column
136// vectors and \p NumElements must be set to the number of elements in a column
137// (= number of rows of the matrix). For row-major matrixes, the function
138// computes the address of a row vector and \p NumElements must be set to the
139// number of elements in a row (= number of columns of the matrix).
140//
141// Consider a 4x4 matrix in column-major layout like below
142//
143// 0 1 2 3
144// 0 v_0_0 v_0_1 v_0_2 v_0_3
145// 1 v_1_0 v_1_1 v_1_2 v_1_3
146// 2 v_2_0 v_2_1 v_2_2 v_2_3
147// 3 v_3_0 v_3_1 v_3_2 v_3_3
148
149// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
150// we need a pointer to the first element of the submatrix as base pointer.
151// Then we can use computeVectorAddr to compute the addresses for the columns
152// of the sub-matrix.
153//
154// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
155// -> just returns Base
156// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
157// -> returns Base + (1 * 4)
158// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
159// -> returns Base + (2 * 4)
160//
161// The graphic below illustrates the number of elements in a column (marked
162// with |) and the number of skipped elements (marked with {).
163//
164// v_0_0 v_0_1 {v_0_2 {v_0_3
165// Base Col 1 Col 2
166// | | |
167// v_1_0 |v_1_1 |v_1_2 |v_1_3
168// v_2_0 |v_2_1 |v_2_2 |v_2_3
169// v_3_0 {v_3_1 {v_3_2 v_3_3
170//
171Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
172 unsigned NumElements, Type *EltType,
173 IRBuilder<> &Builder) {
174
175 assert((!isa<ConstantInt>(Stride) ||
176 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
177 "Stride must be >= the number of elements in the result vector.");
178 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
179
180 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
181 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
182
183 // Get pointer to the start of the selected vector. Skip GEP creation,
184 // if we select vector 0.
185 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
186 VecStart = BasePtr;
187 else
188 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
189
190 // Cast elementwise vector start pointer to a pointer to a vector
191 // (EltType x NumElements)*.
192 auto *VecType = FixedVectorType::get(EltType, NumElements);
193 Type *VecPtrType = PointerType::get(VecType, AS);
194 return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
195}
196
197/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
198///
199/// Currently, the lowering for each matrix intrinsic is done as follows:
200/// 1. Propagate the shape information from intrinsics to connected
201/// instructions.
202/// 2. Lower instructions with shape information (assuming column-major layout).
203/// The lowering works similarly using row-major layout.
204/// 2.1. Get column vectors for each argument. If we already lowered the
205/// definition of an argument, use the produced column vectors directly.
206/// If not, split the operand vector containing an embedded matrix into
207/// a set of column vectors.
208/// 2.2. Lower the instruction in terms of column major operations, which
209/// yields a set of column vectors containing the result matrix. Note that we
210/// lower all instructions that have shape information. Besides the
211/// intrinsics, this includes stores for example.
212/// 2.3. Update uses of the lowered instruction. If we have shape information
213/// for a user, there is nothing to do, as we will look up the result
214/// column matrix when lowering the user. For other uses, we embed the
215/// result matrix in a flat vector and update the use.
216/// 2.4. Cache the result column matrix for the instruction we lowered.
217/// 3. After we lowered all instructions in a function, remove the now
218/// obsolete instructions.
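/// As an illustrative sketch (example IR, not taken from this file): given
///   %a = load <6 x double>, ptr %p
///   %t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a,
///                                                       i32 2, i32 3)
/// step 1 assigns shape 2 x 3 to %a and 3 x 2 to %t, and step 2 then lowers
/// the load to three <2 x double> column loads (assuming column-major layout).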
219///
220class LowerMatrixIntrinsics {
221 Function &Func;
222 const DataLayout &DL;
223  const TargetTransformInfo &TTI;
224  AliasAnalysis *AA;
225  DominatorTree *DT;
226  LoopInfo *LI;
227  OptimizationRemarkEmitter *ORE;
228
229 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
230 struct OpInfoTy {
231 /// Number of stores emitted to generate this matrix.
232 unsigned NumStores = 0;
233 /// Number of loads emitted to generate this matrix.
234 unsigned NumLoads = 0;
235 /// Number of compute operations emitted to generate this matrix.
236 unsigned NumComputeOps = 0;
237 /// Most of the time transposes can be fused with matrix multiplies or can
238 /// be folded away via algebraic simplifications. This is the number of
239 /// transposes that we failed to make "free" via such optimizations.
240 unsigned NumExposedTransposes = 0;
241
242 OpInfoTy &operator+=(const OpInfoTy &RHS) {
243 NumStores += RHS.NumStores;
244 NumLoads += RHS.NumLoads;
245 NumComputeOps += RHS.NumComputeOps;
246 NumExposedTransposes += RHS.NumExposedTransposes;
247 return *this;
248 }
249 };
250
251 /// Wrapper class representing a matrix as a set of vectors, either in row or
252 /// column major layout. All vectors must have the same vector type.
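  /// For example (illustrative): a column-major 2 x 3 matrix of double is
  /// held as three <2 x double> column vectors.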
253 class MatrixTy {
254    SmallVector<Value *, 8> Vectors;
255
256 OpInfoTy OpInfo;
257
258 bool IsColumnMajor = true;
259
260 public:
261 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
262 MatrixTy(ArrayRef<Value *> Vectors)
263 : Vectors(Vectors.begin(), Vectors.end()),
264 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
265 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
266 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
267
268 unsigned D = isColumnMajor() ? NumColumns : NumRows;
269 for (unsigned J = 0; J < D; ++J)
270        addVector(PoisonValue::get(FixedVectorType::get(
271            EltTy, isColumnMajor() ? NumRows : NumColumns)));
272 }
273
274 Value *getVector(unsigned i) const { return Vectors[i]; }
275 Value *getColumn(unsigned i) const {
276 assert(isColumnMajor() && "only supported for column-major matrixes");
277 return Vectors[i];
278 }
279 Value *getRow(unsigned i) const {
280 assert(!isColumnMajor() && "only supported for row-major matrixes");
281 return Vectors[i];
282 }
283
284 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
285
286 Type *getElementType() const { return getVectorTy()->getElementType(); }
287
288 unsigned getNumVectors() const {
289 if (isColumnMajor())
290 return getNumColumns();
291 return getNumRows();
292 }
293
294 unsigned getNumColumns() const {
295 if (isColumnMajor())
296 return Vectors.size();
297 else {
298        assert(Vectors.size() > 0 && "Cannot call getNumColumns without columns");
299 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
300 }
301 }
302 unsigned getNumRows() const {
303 if (isColumnMajor()) {
304 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
305 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
306 } else
307 return Vectors.size();
308 }
309
310 void addVector(Value *V) { Vectors.push_back(V); }
311 VectorType *getColumnTy() {
312 assert(isColumnMajor() && "only supported for column-major matrixes");
313 return getVectorTy();
314 }
315
316 VectorType *getVectorTy() const {
317 return cast<VectorType>(Vectors[0]->getType());
318 }
319
320    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
321      assert(isColumnMajor() &&
322 "columns() only supported for column-major matrixes");
323 return make_range(Vectors.begin(), Vectors.end());
324 }
325
326    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
327      return make_range(Vectors.begin(), Vectors.end());
328 }
329
330 /// Embed the vectors of the matrix into a flat vector by concatenating
331 /// them.
332 Value *embedInVector(IRBuilder<> &Builder) const {
333 return Vectors.size() == 1 ? Vectors[0]
334 : concatenateVectors(Builder, Vectors);
335 }
336
337 MatrixTy &addNumLoads(unsigned N) {
338 OpInfo.NumLoads += N;
339 return *this;
340 }
341
342 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
343
344 MatrixTy &addNumStores(unsigned N) {
345 OpInfo.NumStores += N;
346 return *this;
347 }
348
349 MatrixTy &addNumExposedTransposes(unsigned N) {
350 OpInfo.NumExposedTransposes += N;
351 return *this;
352 }
353
354 MatrixTy &addNumComputeOps(unsigned N) {
355 OpInfo.NumComputeOps += N;
356 return *this;
357 }
358
359 unsigned getNumStores() const { return OpInfo.NumStores; }
360 unsigned getNumLoads() const { return OpInfo.NumLoads; }
361 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
362
363 const OpInfoTy &getOpInfo() const { return OpInfo; }
364
365 bool isColumnMajor() const { return IsColumnMajor; }
366
367 unsigned getStride() const {
368 if (isColumnMajor())
369 return getNumRows();
370 return getNumColumns();
371 }
372
373 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
374 /// matrix is column-major, the result vector is extracted from a column
375 /// vector, otherwise from a row vector.
376 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
377 IRBuilder<> &Builder) const {
378 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
379 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
380 NumElts &&
381 "Extracted vector will contain poison values");
382 return Builder.CreateShuffleVector(
383 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
384 "block");
385 }
386 };
387
388 struct ShapeInfo {
389 unsigned NumRows;
390 unsigned NumColumns;
391
392 bool IsColumnMajor;
393
394 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
395 : NumRows(NumRows), NumColumns(NumColumns),
396 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
397
398 ShapeInfo(Value *NumRows, Value *NumColumns)
399 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
400 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
401
402 bool operator==(const ShapeInfo &other) {
403 return NumRows == other.NumRows && NumColumns == other.NumColumns;
404 }
405 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
406
407 /// Returns true if shape-information is defined, meaning both dimensions
408 /// are != 0.
409 operator bool() const {
410 assert(NumRows == 0 || NumColumns != 0);
411 return NumRows != 0;
412 }
413
414 unsigned getStride() const {
415 if (IsColumnMajor)
416 return NumRows;
417 return NumColumns;
418 }
419
420 unsigned getNumVectors() const {
421 if (IsColumnMajor)
422 return NumColumns;
423 return NumRows;
424 }
425
426 /// Returns the transposed shape.
427 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
428 };
429
430 /// Maps instructions to their shape information. The shape information
431 /// describes the shape to be used while lowering. This matches the shape of
432 /// the result value of the instruction, with the only exceptions being store
433 /// instructions and the matrix_column_major_store intrinsics. For those, the
434 /// shape information indicates that those instructions should be lowered
435 /// using shape information as well. A ValueMap is used so that when
436  /// sub-passes like optimizeTransposes perform RAUW, the map stays
437  /// up-to-date.
438  ValueMap<Value *, ShapeInfo> ShapeMap;
439
440  /// List of instructions to remove. While lowering, we do not replace all
441  /// users of a lowered instruction if shape information is available; those
442  /// users need to be removed after we have finished lowering.
443  SmallVector<Instruction *, 16> ToRemove;
444
445 /// Map from instructions to their produced column matrix.
446 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
447
448private:
449 static FastMathFlags getFastMathFlags(Instruction *Inst) {
450 FastMathFlags FMF;
451
452 if (isa<FPMathOperator>(*Inst))
453 FMF = Inst->getFastMathFlags();
454
455    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
456
457 return FMF;
458 }
459
460public:
461 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
462                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
463                        OptimizationRemarkEmitter *ORE)
464      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
465 LI(LI), ORE(ORE) {}
466
467 unsigned getNumOps(Type *VT) {
468 assert(isa<VectorType>(VT) && "Expected vector type");
469 return getNumOps(VT->getScalarType(),
470 cast<FixedVectorType>(VT)->getNumElements());
471 }
472
473  /// Is this the minimal version executed in the backend pipelines?
474 bool isMinimal() const {
475 return !DT;
476 }
477
478 /// Return the estimated number of vector ops required for an operation on
479 /// \p VT * N.
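  /// For example (illustrative, assuming 128-bit vector registers): for
  /// ST = double and N = 4 this returns ceil((64 * 4) / 128) = 2.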
480 unsigned getNumOps(Type *ST, unsigned N) {
481 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
482                     double(TTI.getRegisterBitWidth(
483                                TargetTransformInfo::RGK_FixedWidthVector)
484                                .getFixedValue()));
485 }
486
487 /// Return the set of vectors that a matrix value is lowered to.
488 ///
489  /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
490 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
491 /// into vectors.
492 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
493 IRBuilder<> &Builder) {
494 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
495 assert(VType && "MatrixVal must be a vector type");
496 assert(cast<FixedVectorType>(VType)->getNumElements() ==
497 SI.NumRows * SI.NumColumns &&
498 "The vector size must match the number of matrix elements");
499
500 // Check if we lowered MatrixVal using shape information. In that case,
501 // return the existing matrix, if it matches the requested shape
502 // information. If there is a mis-match, embed the result in a flat
503 // vector and split it later.
504 auto Found = Inst2ColumnMatrix.find(MatrixVal);
505 if (Found != Inst2ColumnMatrix.end()) {
506 MatrixTy &M = Found->second;
507 // Return the found matrix, if its shape matches the requested shape
508 // information
509 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
510 return M;
511
512 MatrixVal = M.embedInVector(Builder);
513 }
514
515 // Otherwise split MatrixVal.
516 SmallVector<Value *, 16> SplitVecs;
517 for (unsigned MaskStart = 0;
518 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
519 MaskStart += SI.getStride()) {
520 Value *V = Builder.CreateShuffleVector(
521 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
522 "split");
523 SplitVecs.push_back(V);
524 }
525
526 return {SplitVecs};
527 }
528
529 /// If \p V already has a known shape return false. Otherwise set the shape
530 /// for instructions that support it.
531 bool setShapeInfo(Value *V, ShapeInfo Shape) {
532 assert(Shape && "Shape not set");
533 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
534 return false;
535
536 auto SIter = ShapeMap.find(V);
537 if (SIter != ShapeMap.end()) {
538 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
539 << SIter->second.NumRows << " "
540 << SIter->second.NumColumns << " for " << *V << "\n");
541 return false;
542 }
543
544 ShapeMap.insert({V, Shape});
545 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
546 << " for " << *V << "\n");
547 return true;
548 }
549
550 bool isUniformShape(Value *V) {
551 Instruction *I = dyn_cast<Instruction>(V);
552 if (!I)
553 return true;
554
555 switch (I->getOpcode()) {
556 case Instruction::FAdd:
557 case Instruction::FSub:
558 case Instruction::FMul: // Scalar multiply.
559 case Instruction::FNeg:
560 case Instruction::Add:
561 case Instruction::Mul:
562 case Instruction::Sub:
563 return true;
564 default:
565 return false;
566 }
567 }
568
569 /// Returns true if shape information can be used for \p V. The supported
570 /// instructions must match the instructions that can be lowered by this pass.
571 bool supportsShapeInfo(Value *V) {
572 Instruction *Inst = dyn_cast<Instruction>(V);
573 if (!Inst)
574 return false;
575
576 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
577 if (II)
578 switch (II->getIntrinsicID()) {
579 case Intrinsic::matrix_multiply:
580 case Intrinsic::matrix_transpose:
581 case Intrinsic::matrix_column_major_load:
582 case Intrinsic::matrix_column_major_store:
583 return true;
584 default:
585 return false;
586 }
587 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
588 }
589
590 /// Propagate the shape information of instructions to their users.
591 /// The work list contains instructions for which we can compute the shape,
592 /// either based on the information provided by matrix intrinsics or known
593 /// shapes of operands.
594  SmallVector<Instruction *, 32>
595  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
596    SmallVector<Instruction *, 32> NewWorkList;
597    // Pop an element for which we are guaranteed to have at least one of the
598 // operand shapes. Add the shape for this and then add users to the work
599 // list.
600 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
601 while (!WorkList.empty()) {
602 Instruction *Inst = WorkList.pop_back_val();
603
604 // New entry, set the value and insert operands
605 bool Propagate = false;
606
607 Value *MatrixA;
608 Value *MatrixB;
609 Value *M;
610 Value *N;
611 Value *K;
612 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
613 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
614 m_Value(N), m_Value(K)))) {
615 Propagate = setShapeInfo(Inst, {M, K});
616 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
617 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
618 // Flip dimensions.
619 Propagate = setShapeInfo(Inst, {N, M});
620 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
621 m_Value(MatrixA), m_Value(), m_Value(),
622 m_Value(), m_Value(M), m_Value(N)))) {
623 Propagate = setShapeInfo(Inst, {N, M});
624 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
625 m_Value(), m_Value(), m_Value(), m_Value(M),
626 m_Value(N)))) {
627 Propagate = setShapeInfo(Inst, {M, N});
628 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
629 auto OpShape = ShapeMap.find(MatrixA);
630 if (OpShape != ShapeMap.end())
631 setShapeInfo(Inst, OpShape->second);
632 continue;
633 } else if (isUniformShape(Inst)) {
634 // Find the first operand that has a known shape and use that.
635 for (auto &Op : Inst->operands()) {
636 auto OpShape = ShapeMap.find(Op.get());
637 if (OpShape != ShapeMap.end()) {
638 Propagate |= setShapeInfo(Inst, OpShape->second);
639 break;
640 }
641 }
642 }
643
644 if (Propagate) {
645 NewWorkList.push_back(Inst);
646 for (auto *User : Inst->users())
647 if (ShapeMap.count(User) == 0)
648 WorkList.push_back(cast<Instruction>(User));
649 }
650 }
651
652 return NewWorkList;
653 }
654
655 /// Propagate the shape to operands of instructions with shape information.
656 /// \p Worklist contains the instruction for which we already know the shape.
657  SmallVector<Instruction *, 32>
658  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
659    SmallVector<Instruction *, 32> NewWorkList;
660
661    auto pushInstruction = [](Value *V,
662                              SmallVectorImpl<Instruction *> &WorkList) {
663      Instruction *I = dyn_cast<Instruction>(V);
664 if (I)
665 WorkList.push_back(I);
666 };
667 // Pop an element with known shape. Traverse the operands, if their shape
668    // derives from the result shape and is unknown, set it and add them to the
669 // worklist.
670 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
671 while (!WorkList.empty()) {
672 Value *V = WorkList.pop_back_val();
673
674 size_t BeforeProcessingV = WorkList.size();
675 if (!isa<Instruction>(V))
676 continue;
677
678 Value *MatrixA;
679 Value *MatrixB;
680 Value *M;
681 Value *N;
682 Value *K;
683 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
684 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
685 m_Value(N), m_Value(K)))) {
686 if (setShapeInfo(MatrixA, {M, N}))
687 pushInstruction(MatrixA, WorkList);
688
689 if (setShapeInfo(MatrixB, {N, K}))
690 pushInstruction(MatrixB, WorkList);
691
692 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
693 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
694 // Flip dimensions.
695 if (setShapeInfo(MatrixA, {M, N}))
696 pushInstruction(MatrixA, WorkList);
697 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
698 m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
699 m_Value(M), m_Value(N)))) {
700 if (setShapeInfo(MatrixA, {M, N})) {
701 pushInstruction(MatrixA, WorkList);
702 }
703 } else if (isa<LoadInst>(V) ||
704 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
705 // Nothing to do, no matrix input.
706 } else if (isa<StoreInst>(V)) {
707 // Nothing to do. We forward-propagated to this so we would just
708 // backward propagate to an instruction with an already known shape.
709 } else if (isUniformShape(V)) {
710 // Propagate to all operands.
711 ShapeInfo Shape = ShapeMap[V];
712 for (Use &U : cast<Instruction>(V)->operands()) {
713 if (setShapeInfo(U.get(), Shape))
714 pushInstruction(U.get(), WorkList);
715 }
716 }
717 // After we discovered new shape info for new instructions in the
718 // worklist, we use their users as seeds for the next round of forward
719 // propagation.
720 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
721 for (User *U : WorkList[I]->users())
722 if (isa<Instruction>(U) && V != U)
723 NewWorkList.push_back(cast<Instruction>(U));
724 }
725 return NewWorkList;
726 }
727
728 /// (Op0 op Op1)^T -> Op0^T op Op1^T
729 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
730 /// them on both sides of \p Operation.
731 Instruction *distributeTransposes(
732 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
733 MatrixBuilder &Builder,
734 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
735 Operation) {
736 Value *T0 = Builder.CreateMatrixTranspose(
737 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
738 // We are being run after shape prop, add shape for newly created
739 // instructions so that we lower them later.
740 setShapeInfo(T0, Shape0.t());
741 Value *T1 = Builder.CreateMatrixTranspose(
742 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
743 setShapeInfo(T1, Shape1.t());
744 return Operation(T0, Shape0.t(), T1, Shape1.t());
745 }
746
747 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
748 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
749    // with New. We should only add New if it supports shape info, so we
750    // insert it conditionally instead.
751 auto S = ShapeMap.find(&Old);
752 if (S != ShapeMap.end()) {
753 ShapeMap.erase(S);
754 if (supportsShapeInfo(New))
755 ShapeMap.insert({New, S->second});
756 }
757 Old.replaceAllUsesWith(New);
758 }
759
760 /// Sink a top-level transpose inside matmuls and adds.
761 /// This creates and erases instructions as needed, and returns the newly
762 /// created instruction while updating the iterator to avoid invalidation. If
763 /// this returns nullptr, no new instruction was created.
764  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
765    BasicBlock &BB = *I.getParent();
766    IRBuilder<> IB(&I);
767    MatrixBuilder Builder(IB);
768
769 Value *TA, *TAMA, *TAMB;
770 ConstantInt *R, *K, *C;
771 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
772                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
773      return nullptr;
774
775 // Transpose of a transpose is a nop
776 Value *TATA;
777 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
778 updateShapeAndReplaceAllUsesWith(I, TATA);
779 eraseFromParentAndMove(&I, II, BB);
780 eraseFromParentAndMove(TA, II, BB);
781 return nullptr;
782 }
783
784 // k^T -> k
785 if (isSplat(TA)) {
786 updateShapeAndReplaceAllUsesWith(I, TA);
787 eraseFromParentAndMove(&I, II, BB);
788 return nullptr;
789 }
790
791 // (A * B)^t -> B^t * A^t
792 // RxK KxC CxK KxR
793 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
794 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
795                      m_ConstantInt(K), m_ConstantInt(C)))) {
796      auto NewInst = distributeTransposes(
797 TAMB, {K, C}, TAMA, {R, K}, Builder,
798 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
799 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
800 Shape0.NumColumns,
801 Shape1.NumColumns, "mmul");
802 });
803 updateShapeAndReplaceAllUsesWith(I, NewInst);
804 eraseFromParentAndMove(&I, II, BB);
805 eraseFromParentAndMove(TA, II, BB);
806 return NewInst;
807 }
808
809 // Same as above, but with a mul, which occurs when multiplied
810 // with a scalar.
811 // (A * k)^t -> A^t * k
812 // R x C RxC
813 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
814 (isSplat(TAMA) || isSplat(TAMB))) {
815 IRBuilder<> LocalBuilder(&I);
816 // We know that the transposed operand is of shape RxC.
817      // And when multiplied with a scalar, the shape is preserved.
818 auto NewInst = distributeTransposes(
819 TAMA, {R, C}, TAMB, {R, C}, Builder,
820 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
821 bool IsFP = I.getType()->isFPOrFPVectorTy();
822 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
823 : LocalBuilder.CreateMul(T0, T1, "mmul");
824 auto *Result = cast<Instruction>(Mul);
825 setShapeInfo(Result, Shape0);
826 return Result;
827 });
828 updateShapeAndReplaceAllUsesWith(I, NewInst);
829 eraseFromParentAndMove(&I, II, BB);
830 eraseFromParentAndMove(TA, II, BB);
831 return NewInst;
832 }
833
834 // (A + B)^t -> A^t + B^t
835 // RxC RxC CxR CxR
836 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
837 IRBuilder<> LocalBuilder(&I);
838 auto NewInst = distributeTransposes(
839 TAMA, {R, C}, TAMB, {R, C}, Builder,
840 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
841 auto *FAdd =
842 cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd"));
843 setShapeInfo(FAdd, Shape0);
844 return FAdd;
845 });
846 updateShapeAndReplaceAllUsesWith(I, NewInst);
847 eraseFromParentAndMove(&I, II, BB);
848 eraseFromParentAndMove(TA, II, BB);
849 return NewInst;
850 }
851
852 return nullptr;
853 }
854
855 void liftTranspose(Instruction &I) {
856 // Erase dead Instructions after lifting transposes from binops.
857 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
858 if (T.use_empty())
859 T.eraseFromParent();
860 if (A->use_empty())
861 cast<Instruction>(A)->eraseFromParent();
862 if (A != B && B->use_empty())
863 cast<Instruction>(B)->eraseFromParent();
864 };
865
866 Value *A, *B, *AT, *BT;
867 ConstantInt *R, *K, *C;
868 // A^t * B ^t -> (B * A)^t
869 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
870                      m_Value(A), m_Value(B), m_ConstantInt(R),
871                      m_ConstantInt(K), m_ConstantInt(C))) &&
872        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
873 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
874 IRBuilder<> IB(&I);
875      MatrixBuilder Builder(IB);
876      Value *M = Builder.CreateMatrixMultiply(
877 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
878 setShapeInfo(M, {C, R});
879 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
880 R->getZExtValue());
881 updateShapeAndReplaceAllUsesWith(I, NewInst);
882 CleanupBinOp(I, A, B);
883 }
884 // A^t + B ^t -> (A + B)^t
885 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
886 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
887 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
888 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
889                          m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
890      IRBuilder<> Builder(&I);
891      Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
892 setShapeInfo(Add, {C, R});
893 MatrixBuilder MBuilder(Builder);
894 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
895 Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
896 updateShapeAndReplaceAllUsesWith(I, NewInst);
897 CleanupBinOp(I, A, B);
898 }
899 }
900
901 /// Try moving transposes in order to fold them away or into multiplies.
902 void optimizeTransposes() {
903 // First sink all transposes inside matmuls and adds, hoping that we end up
904 // with NN, NT or TN variants.
905 for (BasicBlock &BB : reverse(Func)) {
906 for (auto II = BB.rbegin(); II != BB.rend();) {
907 Instruction &I = *II;
908 // We may remove II. By default continue on the next/prev instruction.
909 ++II;
910 if (Instruction *NewInst = sinkTranspose(I, II))
911 II = std::next(BasicBlock::reverse_iterator(NewInst));
912 }
913 }
914
915 // If we have a TT matmul or a TT add, lift the transpose. We may be able
916 // to fold into consuming multiply or add.
917 for (BasicBlock &BB : Func) {
918      for (Instruction &I : llvm::make_early_inc_range(BB)) {
919        liftTranspose(I);
920 }
921 }
922 }
923
924 bool Visit() {
925    SmallVector<Instruction *, 32> WorkList;
926
927 // Initially only the shape of matrix intrinsics is known.
928 // Initialize the work list with ops carrying shape information.
929 for (BasicBlock &BB : Func)
930 for (Instruction &Inst : BB) {
931 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
932 if (!II)
933 continue;
934
935 switch (II->getIntrinsicID()) {
936 case Intrinsic::matrix_multiply:
937 case Intrinsic::matrix_transpose:
938 case Intrinsic::matrix_column_major_load:
939 case Intrinsic::matrix_column_major_store:
940 WorkList.push_back(&Inst);
941 break;
942 default:
943 break;
944 }
945 }
946
947 // Avoid unnecessary work if there are no matrix intrinsics in the function.
948 if (WorkList.empty())
949 return false;
950
951 // Propagate shapes until nothing changes any longer.
952 while (!WorkList.empty()) {
953 WorkList = propagateShapeForward(WorkList);
954 WorkList = propagateShapeBackward(WorkList);
955 }
956
957 if (!isMinimal()) {
958 optimizeTransposes();
959      if (PrintAfterTransposeOpt) {
960        dbgs() << "Dump after matrix transpose optimization:\n";
961 Func.print(dbgs());
962 }
963 }
964
965 bool Changed = false;
966 SmallVector<CallInst *, 16> MaybeFusableInsts;
967    SmallVector<Instruction *, 16> MatrixInsts;
968
969 // First, collect all instructions with shape information and candidates for
970 // fusion (currently only matrix multiplies).
971    ReversePostOrderTraversal<Function *> RPOT(&Func);
972    for (auto *BB : RPOT)
973 for (Instruction &I : *BB) {
974 if (ShapeMap.find(&I) == ShapeMap.end())
975 continue;
976 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
977 MaybeFusableInsts.push_back(cast<CallInst>(&I));
978 MatrixInsts.push_back(&I);
979 }
980
981 // Second, try to fuse candidates.
982    SmallPtrSet<Instruction *, 16> FusedInsts;
983    for (CallInst *CI : MaybeFusableInsts)
984 LowerMatrixMultiplyFused(CI, FusedInsts);
985 Changed = !FusedInsts.empty();
986
987 // Third, lower remaining instructions with shape information.
988 for (Instruction *Inst : MatrixInsts) {
989 if (FusedInsts.count(Inst))
990 continue;
991
992 IRBuilder<> Builder(Inst);
993
994 if (CallInst *CInst = dyn_cast<CallInst>(Inst))
995 Changed |= VisitCallInst(CInst);
996
997 Value *Op1;
998 Value *Op2;
999 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1000 Changed |= VisitBinaryOperator(BinOp);
1001 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1002 Changed |= VisitUnaryOperator(UnOp);
1003 if (match(Inst, m_Load(m_Value(Op1))))
1004 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1005 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1006 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1007 }
1008
1009 if (ORE) {
1010 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1011 RemarkGen.emitRemarks();
1012 }
1013
1014    // Delete the instructions backwards, as this reduces the likelihood of
1015    // having to update as many def-use and use-def chains.
1016 //
1017 // Because we add to ToRemove during fusion we can't guarantee that defs
1018 // are before uses. Change uses to poison temporarily as these should get
1019 // removed as well.
1020 //
1021 // For verification, we keep track of where we changed uses to poison in
1022 // PoisonedInsts and then check that we in fact remove them.
1023 SmallSet<Instruction *, 16> PoisonedInsts;
1024 for (auto *Inst : reverse(ToRemove)) {
1025 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1026 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1027 PoisonedInsts.insert(Poisoned);
1028 U.set(PoisonValue::get(Inst->getType()));
1029 }
1030 Inst->eraseFromParent();
1031 PoisonedInsts.erase(Inst);
1032 }
1033 if (!PoisonedInsts.empty()) {
1034 // If we didn't remove all poisoned instructions, it's a hard error.
1035 dbgs() << "Poisoned but present instructions:\n";
1036 for (auto *I : PoisonedInsts)
1037 dbgs() << *I << "\n";
1038 llvm_unreachable("Poisoned but instruction not removed");
1039 }
1040
1041 return Changed;
1042 }
1043
1044 /// Turns \p BasePtr into an elementwise pointer to \p EltType.
1045 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
1046 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
1047 Type *EltPtrType = PointerType::get(EltType, AS);
1048 return Builder.CreatePointerCast(BasePtr, EltPtrType);
1049 }
1050
1051 /// Replace intrinsic calls
1052 bool VisitCallInst(CallInst *Inst) {
1053 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1054 return false;
1055
1056 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1057 case Intrinsic::matrix_multiply:
1058 LowerMultiply(Inst);
1059 break;
1060 case Intrinsic::matrix_transpose:
1061 LowerTranspose(Inst);
1062 break;
1063 case Intrinsic::matrix_column_major_load:
1064 LowerColumnMajorLoad(Inst);
1065 break;
1066 case Intrinsic::matrix_column_major_store:
1067 LowerColumnMajorStore(Inst);
1068 break;
1069 default:
1070 return false;
1071 }
1072 return true;
1073 }
1074
1075 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1076 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1077 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1078 /// non-ConstantInt strides, return the common alignment of the initial
1079 /// alignment and the element size in bytes.
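  /// For example (illustrative): with an initial alignment of 16, a double
  /// element type and a constant stride of 3 elements, \p Idx == 1 starts at
  /// byte offset 24, so the returned alignment is commonAlignment(16, 24) = 8.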
1080 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1081 MaybeAlign A) const {
1082 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1083 if (Idx == 0)
1084 return InitialAlign;
1085
1086 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1087 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1088 uint64_t StrideInBytes =
1089 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1090 return commonAlignment(InitialAlign, Idx * StrideInBytes);
1091 }
1092 return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1093 }
1094
1095 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1096 /// vectors.
1097 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1098 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1099 auto *VType = cast<VectorType>(Ty);
1100 Type *EltTy = VType->getElementType();
1101 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1102 Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
1103 MatrixTy Result;
1104 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1105 Value *GEP = computeVectorAddr(
1106 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1107 Stride, Shape.getStride(), EltTy, Builder);
1108 Value *Vector = Builder.CreateAlignedLoad(
1109 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1110 IsVolatile, "col.load");
1111
1112 Result.addVector(Vector);
1113 }
1114 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1115 Result.getNumVectors());
1116 }
1117
1118 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1119 /// starting at \p MatrixPtr[I][J].
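  /// For example (illustrative): for a column-major 8 x 8 matrix, the tile at
  /// I = 2, J = 1 starts at element offset 1 * 8 + 2 = 10 from \p MatrixPtr.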
1120 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1121 ShapeInfo MatrixShape, Value *I, Value *J,
1122 ShapeInfo ResultShape, Type *EltTy,
1123 IRBuilder<> &Builder) {
1124
1125 Value *Offset = Builder.CreateAdd(
1126 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1127
1128 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1129 Value *EltPtr =
1130 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1131 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1132 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1133 ResultShape.NumColumns);
1134 Type *TilePtrTy = PointerType::get(TileTy, AS);
1135 Value *TilePtr =
1136 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1137
1138 return loadMatrix(TileTy, TilePtr, Align,
1139 Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1140 ResultShape, Builder);
1141 }
1142
1143 /// Lower a load instruction with shape information.
1144 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1145 bool IsVolatile, ShapeInfo Shape) {
1146 IRBuilder<> Builder(Inst);
1147 finalizeLowering(Inst,
1148 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1149 Shape, Builder),
1150 Builder);
1151 }
1152
1153 /// Lowers llvm.matrix.column.major.load.
1154 ///
1155 /// The intrinsic loads a matrix from memory using a stride between columns.
1156 void LowerColumnMajorLoad(CallInst *Inst) {
1157 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1158 "Intrinsic only supports column-major layout!");
1159 Value *Ptr = Inst->getArgOperand(0);
1160 Value *Stride = Inst->getArgOperand(1);
1161 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1162 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1163 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1164 }
1165
1166 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1167 /// MatrixPtr[I][J].
1168 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1169 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1170 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1171 Value *Offset = Builder.CreateAdd(
1172 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1173
1174 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1175 Value *EltPtr =
1176 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1177 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1178 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1179 StoreVal.getNumColumns());
1180 Type *TilePtrTy = PointerType::get(TileTy, AS);
1181 Value *TilePtr =
1182 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1183
1184 storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
1185 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1186 }
1187
1188 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1189 /// vectors.
1190 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1191 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1192 IRBuilder<> &Builder) {
1193 auto VType = cast<VectorType>(Ty);
1194 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
1195 for (auto Vec : enumerate(StoreVal.vectors())) {
1196 Value *GEP = computeVectorAddr(
1197 EltPtr,
1198 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1199 Vec.index()),
1200 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1201 Builder.CreateAlignedStore(Vec.value(), GEP,
1202 getAlignForIndex(Vec.index(), Stride,
1203 VType->getElementType(),
1204 MAlign),
1205 IsVolatile);
1206 }
1207 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1208 StoreVal.getNumVectors());
1209 }
1210
1211 /// Lower a store instruction with shape information.
1212  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1213                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1214 IRBuilder<> Builder(Inst);
1215 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1216 finalizeLowering(Inst,
1217 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1218 IsVolatile, Builder),
1219 Builder);
1220 }
1221
1222 /// Lowers llvm.matrix.column.major.store.
1223 ///
1224  /// The intrinsic stores a matrix back to memory using a stride between columns.
1225 void LowerColumnMajorStore(CallInst *Inst) {
1226 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1227 "Intrinsic only supports column-major layout!");
1228 Value *Matrix = Inst->getArgOperand(0);
1229 Value *Ptr = Inst->getArgOperand(1);
1230 Value *Stride = Inst->getArgOperand(2);
1231 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1232 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1233 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1234 }
1235
1236 // Set elements I..I+NumElts-1 to Block
1237 Value *insertVector(Value *Col, unsigned I, Value *Block,
1238 IRBuilder<> &Builder) {
1239
1240 // First, bring Block to the same size as Col
1241 unsigned BlockNumElts =
1242 cast<FixedVectorType>(Block->getType())->getNumElements();
1243 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1244 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1245
1246 Block = Builder.CreateShuffleVector(
1247 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1248
1249 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1250 // 8, 4, 5, 6
1251    SmallVector<int, 16> Mask;
1252    unsigned i;
1253 for (i = 0; i < I; i++)
1254 Mask.push_back(i);
1255
1256 unsigned VecNumElts =
1257 cast<FixedVectorType>(Col->getType())->getNumElements();
1258 for (; i < I + BlockNumElts; i++)
1259 Mask.push_back(i - I + VecNumElts);
1260
1261 for (; i < VecNumElts; i++)
1262 Mask.push_back(i);
1263
1264 return Builder.CreateShuffleVector(Col, Block, Mask);
1265 }
1266
1267 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1268 IRBuilder<> &Builder, bool AllowContraction,
1269 unsigned &NumComputeOps) {
1270 NumComputeOps += getNumOps(A->getType());
1271 if (!Sum)
1272 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1273
1274 if (UseFPOp) {
1275 if (AllowContraction) {
1276 // Use fmuladd for floating point operations and let the backend decide
1277 // if that's profitable.
1278        Function *FMulAdd = Intrinsic::getDeclaration(
1279            Func.getParent(), Intrinsic::fmuladd, A->getType());
1280 return Builder.CreateCall(FMulAdd, {A, B, Sum});
1281 }
1282 NumComputeOps += getNumOps(A->getType());
1283 Value *Mul = Builder.CreateFMul(A, B);
1284 return Builder.CreateFAdd(Sum, Mul);
1285 }
1286
1287 NumComputeOps += getNumOps(A->getType());
1288 Value *Mul = Builder.CreateMul(A, B);
1289 return Builder.CreateAdd(Sum, Mul);
1290 }
1291
1292 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1293 /// users with shape information, there's nothing to do: they will use the
1294 /// cached value when they are lowered. For other users, \p Matrix is
1295 /// flattened and the uses are updated to use it. Also marks \p Inst for
1296 /// deletion.
1297 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1298 IRBuilder<> &Builder) {
1299 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1300 (void)inserted;
1301 assert(inserted.second && "multiple matrix lowering mapping");
1302
1303 ToRemove.push_back(Inst);
1304 Value *Flattened = nullptr;
1305 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1306 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1307 if (!Flattened)
1308 Flattened = Matrix.embedInVector(Builder);
1309 U.set(Flattened);
1310 }
1311 }
1312 }
1313
1314 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1315 /// addition.
1316 ///
1317 /// We can fold a transpose into the operand that is used to extract scalars.
1318  /// This is the first operand with row-major layout and the second with
1319 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1320 /// operand is transposed.
1321 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1322 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1323 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1324 const unsigned VF = std::max<unsigned>(
1325        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1326            .getFixedValue() /
1327 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1328 1U);
1329 unsigned R = Result.getNumRows();
1330 unsigned C = Result.getNumColumns();
1331 unsigned M = A.getNumColumns();
1332
1333 bool IsFP = Result.getElementType()->isFloatingPointTy();
1334 assert(A.isColumnMajor() == B.isColumnMajor() &&
1335 Result.isColumnMajor() == A.isColumnMajor() &&
1336 "operands must agree on matrix layout");
1337 unsigned NumComputeOps = 0;
1338
1339 Builder.setFastMathFlags(FMF);
1340
1341 if (A.isColumnMajor()) {
1342 // Multiply columns from the first operand with scalars from the second
1343      // operand. Then move along the K axis and accumulate the columns. With
1344 // this the adds can be vectorized without reassociation.
1345 for (unsigned J = 0; J < C; ++J) {
1346 unsigned BlockSize = VF;
1347 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1348 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1349
1350 for (unsigned I = 0; I < R; I += BlockSize) {
1351 // Gradually lower the vectorization factor to cover the remainder.
1352 while (I + BlockSize > R)
1353 BlockSize /= 2;
1354
1355 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1356 : nullptr;
1357 for (unsigned K = 0; K < M; ++K) {
1358 Value *L = A.extractVector(I, K, BlockSize, Builder);
1359 Value *RH = Builder.CreateExtractElement(
1360 B.getColumn(IsScalarMatrixTransposed ? K : J),
1361 IsScalarMatrixTransposed ? J : K);
1362 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1363 Sum =
1364 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1365 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1366 }
1367 Result.setVector(J,
1368 insertVector(Result.getVector(J), I, Sum, Builder));
1369 }
1370 }
1371 } else {
1372 // Multiply rows from the second operand with scalars from the first
1373      // operand. Then move along the K axis and accumulate the rows. With this
1374 // the adds can be vectorized without reassociation.
1375 for (unsigned I = 0; I < R; ++I) {
1376 unsigned BlockSize = VF;
1377 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1378 for (unsigned J = 0; J < C; J += BlockSize) {
1379 // Gradually lower the vectorization factor to cover the remainder.
1380 while (J + BlockSize > C)
1381 BlockSize /= 2;
1382
1383 Value *Sum = nullptr;
1384 for (unsigned K = 0; K < M; ++K) {
1385 Value *R = B.extractVector(K, J, BlockSize, Builder);
1386 Value *LH = Builder.CreateExtractElement(
1387 A.getVector(IsScalarMatrixTransposed ? K : I),
1388 IsScalarMatrixTransposed ? I : K);
1389 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1390 Sum =
1391 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1392 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1393 }
1394 Result.setVector(I,
1395 insertVector(Result.getVector(I), J, Sum, Builder));
1396 }
1397 }
1398 }
1399 Result.addNumComputeOps(NumComputeOps);
1400 }
1401
1402 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1403  /// copying it to a new location. The new location, or otherwise the original
1404  /// one, is returned.
1405 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1406 CallInst *MatMul) {
1407 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1408 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1409
1410 // If we can statically determine noalias we're good.
1411 if (AA->isNoAlias(LoadLoc, StoreLoc))
1412 return Load->getPointerOperand();
1413
1414 // Create code to check if the memory locations of the Load and Store
1415 // overlap and if they do, copy Load's operand to a new buffer.
1416
1417    // First, create new blocks for the 2nd part of the check and the copy.
1418 BasicBlock *Check0 = MatMul->getParent();
1419 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1420 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1421 // as we adjust Check0 and Check1's branches.
1422    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1423    for (BasicBlock *Succ : successors(Check0))
1424 DTUpdates.push_back({DT->Delete, Check0, Succ});
1425
1426 BasicBlock *Check1 =
1427 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1428 nullptr, "alias_cont");
1429 BasicBlock *Copy =
1430 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1431 nullptr, "copy");
1432 BasicBlock *Fusion =
1433 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1434 nullptr, "no_alias");
1435
1436 // Check if the loaded memory location begins before the end of the store
1437 // location. If the condition holds, they might overlap, otherwise they are
1438 // guaranteed to not overlap.
1439 IRBuilder<> Builder(MatMul);
1440 Check0->getTerminator()->eraseFromParent();
1441 Builder.SetInsertPoint(Check0);
1442 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1443 Value *StoreBegin = Builder.CreatePtrToInt(
1444 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1445 Value *StoreEnd = Builder.CreateAdd(
1446 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1447 "store.end", true, true);
1448 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1449 IntPtrTy, "load.begin");
1450 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1451 Fusion);
1452
1453 // Check if the store begins before the end of the load location. If the
1454 // condition holds, they alias, otherwise they are guaranteed to not
1455 // overlap.
1456 Check1->getTerminator()->eraseFromParent();
1457 Builder.SetInsertPoint(Check1, Check1->begin());
1458 Value *LoadEnd = Builder.CreateAdd(
1459 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1460 "load.end", true, true);
1461 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1462 Fusion);
1463
1464 // Copy load operand to new alloca.
1465 Builder.SetInsertPoint(Copy, Copy->begin());
1466 auto *VT = cast<FixedVectorType>(Load->getType());
1467 // Use an array type for the alloca, to avoid potentially huge alignment
1468 // requirements for large vector types.
1469 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1470 AllocaInst *Alloca =
1471 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1472 Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
1473
1474 Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
1475 Load->getAlign(), LoadLoc.Size.getValue());
1476 Builder.SetInsertPoint(Fusion, Fusion->begin());
1477 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1478 PHI->addIncoming(Load->getPointerOperand(), Check0);
1479 PHI->addIncoming(Load->getPointerOperand(), Check1);
1480 PHI->addIncoming(BC, Copy);
1481
1482 // Adjust DT.
1483 DTUpdates.push_back({DT->Insert, Check0, Check1});
1484 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1485 DTUpdates.push_back({DT->Insert, Check1, Copy});
1486 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1487 DT->applyUpdates(DTUpdates);
1488 return PHI;
1489 }
1490
1491 bool isFusionProfitable(CallInst *MatMul) {
1492 if (ForceFusion)
1493 return true;
1494
1495 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1496 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1497
1498 const unsigned R = LShape.NumRows;
1499 const unsigned C = RShape.NumColumns;
1500 const unsigned M = LShape.NumColumns;
1501 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1502
1503 const unsigned VF = std::max<unsigned>(
1504        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1505            .getFixedValue() /
1506            EltType->getPrimitiveSizeInBits().getFixedValue(),
1507        1U);
1508
1509 // Cost model for tiling
1510 //
1511 // For tiling to be beneficial, we need reuse either along the R or
1512 // the C axis. We vectorize along the R axis so that means at least
1513 // 3 elements.
1514 // TODO: Also consider cost of copying if operands alias.
1515 if (R <= VF && C == 1)
1516 return false;
1517 // Then we need enough elements to exceed the number of vector
1518 // registers we have. Note that this is an oversimplification since
1519 // fusing also takes some extra loads which may exceed the number of
1520 // reloads necessary.
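    // For example (illustrative): with VF = 4, fusing a 16x16 * 16x16 multiply
    // keeps (16 + 3) / 4 * 16 = 64 vector registers live per operand, 128 in
    // total, which exceeds any realistic register file, so fusion is considered
    // profitable.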
1521 unsigned Op0Regs = (R + VF - 1) / VF * M;
1522 unsigned Op1Regs = (M + VF - 1) / VF * C;
1523 return Op0Regs + Op1Regs >
1524           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1525  }
1526
1527 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1528 MatrixTy Res;
1529 auto *ColumType = FixedVectorType::get(EltType, R);
1530 for (unsigned I = 0; I < C; ++I)
1531 Res.addVector(ConstantAggregateZero::get(ColumType));
1532 return Res;
1533 }
1534
1535 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1536 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1537 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1538
1539 // Create the main tiling loop nest.
1540 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1541 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1542 Instruction *InsertI = cast<Instruction>(MatMul);
1543 BasicBlock *Start = InsertI->getParent();
1544 BasicBlock *End =
1545 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1546 IRBuilder<> Builder(MatMul);
1547 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1548
1549 Type *TileVecTy =
1550        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1551    MatrixTy TileResult;
1552 // Insert in the inner loop header.
1553 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1554 // Create PHI nodes for the result columns to accumulate across iterations.
1555 SmallVector<PHINode *, 4> ColumnPhis;
1556 for (unsigned I = 0; I < TileSize; I++) {
1557 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1558 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1559 TI.RowLoop.Header->getSingleSuccessor());
1560 TileResult.addVector(Phi);
1561 ColumnPhis.push_back(Phi);
1562 }
1563
1564 // Insert in the inner loop body, which computes
1565 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1566 Builder.SetInsertPoint(InnerBody->getTerminator());
1567 // Load tiles of the operands.
1568 MatrixTy A =
1569 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1570 {TileSize, TileSize}, EltType, Builder);
1571 MatrixTy B =
1572 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1573 {TileSize, TileSize}, EltType, Builder);
1574 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1575 getFastMathFlags(MatMul));
1576 // Store result after the inner loop is done.
1577 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1578 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1579 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1580 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1581
1582 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1583 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1584
1585 // Force unrolling of a few iterations of the inner loop, to make sure there
1586 // is enough work per iteration.
1587 // FIXME: The unroller should make this decision directly instead, but
1588 // currently the cost-model is not up to the task.
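// E.g. with LShape.NumColumns == 64 and TileSize == 4 (values illustrative),
// this requests an unroll count of std::min(10u, 64 / 4) == 10.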
1589 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1590 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1591 "llvm.loop.unroll.count", InnerLoopUnrollCount);
1592 }
1593
1594 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1595 StoreInst *Store,
1596 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1597 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1598 "Tiling only supported for column-major matrixes at the moment!");
1599 if (!isFusionProfitable(MatMul))
1600 return;
1601
1602 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1603 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1604
1605 const unsigned R = LShape.NumRows;
1606 const unsigned C = RShape.NumColumns;
1607 const unsigned M = LShape.NumColumns;
1608 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1609
1610 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1611 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1612 Value *CPtr = Store->getPointerOperand();
1613
1614 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1615 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1616 else {
1617 IRBuilder<> Builder(Store);
1618 for (unsigned J = 0; J < C; J += TileSize)
1619 for (unsigned I = 0; I < R; I += TileSize) {
1620 const unsigned TileR = std::min(R - I, unsigned(TileSize));
1621 const unsigned TileC = std::min(C - J, unsigned(TileSize));
1622 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1623
1624 for (unsigned K = 0; K < M; K += TileSize) {
1625 const unsigned TileM = std::min(M - K, unsigned(TileSize));
1626 MatrixTy A =
1627 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1628 LShape, Builder.getInt64(I), Builder.getInt64(K),
1629 {TileR, TileM}, EltType, Builder);
1630 MatrixTy B =
1631 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1632 RShape, Builder.getInt64(K), Builder.getInt64(J),
1633 {TileM, TileC}, EltType, Builder);
1634 emitMatrixMultiply(Res, A, B, Builder, true, false,
1635 getFastMathFlags(MatMul));
1636 }
1637 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1638 Builder.getInt64(I), Builder.getInt64(J), EltType,
1639 Builder);
1640 }
1641 }
1642
1643 // Mark eliminated instructions as fused and remove them.
1644 FusedInsts.insert(Store);
1645 FusedInsts.insert(MatMul);
1646 Store->eraseFromParent();
1647 MatMul->eraseFromParent();
1648 if (LoadOp0->hasNUses(0)) {
1649 FusedInsts.insert(LoadOp0);
1650 LoadOp0->eraseFromParent();
1651 }
1652 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1653 FusedInsts.insert(LoadOp1);
1654 LoadOp1->eraseFromParent();
1655 }
1656 }
1657
1658 /// Try to lower matrix multiply chains by fusing operations.
1659 ///
1660 /// Call finalizeLowering on lowered instructions. Instructions that are
1661 /// completely eliminated by fusion are added to \p FusedInsts.
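///
/// For example, a chain of the form below (types, shapes and value names are
/// illustrative only) is a candidate for the tiled lowering:
///   %a = load <16 x double>, ptr %A
///   %b = load <16 x double>, ptr %B
///   %c = call <16 x double> @llvm.matrix.multiply.v16f64.v16f64.v16f64(
///            <16 x double> %a, <16 x double> %b, i32 4, i32 4, i32 4)
///   store <16 x double> %c, ptr %C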
1662 void LowerMatrixMultiplyFused(CallInst *MatMul,
1663 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1664 if (!FuseMatrix || !DT)
1665 return;
1666
1667 assert(AA && LI && "Analyses should be available");
1668
1669 Value *A = MatMul->getArgOperand(0);
1670 Value *B = MatMul->getArgOperand(1);
1671
1672 // We can fold the transpose into the operand that is used to fetch scalars.
1673 Value *T;
1674 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1675 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1676 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1677 IRBuilder<> Builder(MatMul);
1678 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1679 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1680 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1681 const unsigned R = LShape.NumRows;
1682 const unsigned M = LShape.NumColumns;
1683 const unsigned C = RShape.NumColumns;
1684
1685 MatrixTy MA;
1686 MatrixTy MB;
1687
1688 Value *Transpose;
1689 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1690 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1691 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1692 Transpose = B;
1693 } else {
1694 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1695 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1696 Transpose = A;
1697 }
1698
1699 // Initialize the output
1700 MatrixTy Result(R, C, EltType);
1701
1702 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1703 getFastMathFlags(MatMul));
1704
1705 FusedInsts.insert(MatMul);
1706 if (Transpose->hasOneUse()) {
1707 FusedInsts.insert(cast<Instruction>(Transpose));
1708 ToRemove.push_back(cast<Instruction>(Transpose));
1709 // TODO: add a fake entry for the folded instruction so that this is
1710 // included in the expression in the remark.
1711 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1712 }
1713 finalizeLowering(MatMul, Result, Builder);
1714 return;
1715 }
1716
1717 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1718 return;
1719
1720 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1721 // since the single store user will be lowered as part of this.
1722 auto *LoadOp0 = dyn_cast<LoadInst>(A);
1723 auto *LoadOp1 = dyn_cast<LoadInst>(B);
1724 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1725 if (LoadOp0 && LoadOp1 && Store) {
1726 // The store address must dominate the MatMul instruction, otherwise
1727 // we create invalid IR.
1728 SetVector<Value *> WorkList;
1729 WorkList.insert(Store->getOperand(1));
1730 SmallVector<Instruction *> ToHoist;
1731 for (unsigned I = 0; I != WorkList.size(); ++I) {
1732 Value *Current = WorkList[I];
1733 auto *CurrI = dyn_cast<Instruction>(Current);
1734 if (!CurrI)
1735 continue;
1736 if (isa<PHINode>(CurrI))
1737 return;
1738 if (DT->dominates(CurrI, MatMul))
1739 continue;
1740 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1741 return;
1742 ToHoist.push_back(CurrI);
1743 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1744 }
1745
1746 sort(ToHoist, [this](Instruction *A, Instruction *B) {
1747 return DT->dominates(A, B);
1748 });
1749 for (Instruction *I : ToHoist)
1750 I->moveBefore(MatMul);
1751
1752 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1753 return;
1754 }
1755 }
1756
1757 /// Lowers llvm.matrix.multiply.
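///
/// The intrinsic's first two operands are the flattened matrix values; the
/// trailing i32 immediates (argument operands 2, 3 and 4 below) give the LHS
/// row count, the shared inner dimension and the RHS column count.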
1758 void LowerMultiply(CallInst *MatMul) {
1759 IRBuilder<> Builder(MatMul);
1760 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1761 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1762 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1763
1764 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1765 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1766 assert(Lhs.getElementType() == Rhs.getElementType() &&
1767 "Matrix multiply argument element types do not match.");
1768
1769 const unsigned R = LShape.NumRows;
1770 const unsigned C = RShape.NumColumns;
1771 assert(LShape.NumColumns == RShape.NumRows);
1772
1773 // Initialize the output
1774 MatrixTy Result(R, C, EltType);
1775 assert(Lhs.getElementType() == Result.getElementType() &&
1776 "Matrix multiply result element type does not match arguments.");
1777
1778 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1779 getFastMathFlags(MatMul));
1780 finalizeLowering(MatMul, Result, Builder);
1781 }
1782
1783 /// Lowers llvm.matrix.transpose.
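///
/// For the default column-major layout this rebuilds the matrix as
/// ArgShape.NumRows vectors of ArgShape.NumColumns elements each, moving the
/// element at (Row, Column) of the input to (Column, Row) of the result.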
1784 void LowerTranspose(CallInst *Inst) {
1785 MatrixTy Result;
1786 IRBuilder<> Builder(Inst);
1787 Value *InputVal = Inst->getArgOperand(0);
1788 VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1789 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1790 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1791
1792 const unsigned NewNumVecs =
1793 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1794 const unsigned NewNumElts =
1795 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1796
1797 for (unsigned I = 0; I < NewNumVecs; ++I) {
1798 // Build a single result vector. First initialize it.
1799 Value *ResultVector = PoisonValue::get(
1800 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
1801 // Go through the old elements and insert it into the resulting vector.
1802 for (auto J : enumerate(InputMatrix.vectors())) {
1803 Value *Elt = Builder.CreateExtractElement(J.value(), I);
1804 // Row and column indices are transposed.
1805 ResultVector =
1806 Builder.CreateInsertElement(ResultVector, Elt, J.index());
1807 }
1808 Result.addVector(ResultVector);
1809 }
1810
1811 // TODO: Improve estimate of operations needed for transposes. Currently we
1812 // just count the insertelement/extractelement instructions, but do not
1813 // account for later simplifications/combines.
1814 finalizeLowering(
1815 Inst,
1816 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
1817 .addNumExposedTransposes(1),
1818 Builder);
1819 }
1820
1821 /// Lower load instructions, if shape information is available.
1822 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
1823 auto I = ShapeMap.find(Inst);
1824 if (I == ShapeMap.end())
1825 return false;
1826
1827 LowerLoad(Inst, Ptr, Inst->getAlign(),
1828 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1829 I->second);
1830 return true;
1831 }
1832
1833 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
1834 IRBuilder<> &Builder) {
1835 auto I = ShapeMap.find(StoredVal);
1836 if (I == ShapeMap.end())
1837 return false;
1838
1839 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
1840 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1841 I->second);
1842 return true;
1843 }
1844
1845 /// Lower binary operators, if shape information is available.
1846 bool VisitBinaryOperator(BinaryOperator *Inst) {
1847 auto I = ShapeMap.find(Inst);
1848 if (I == ShapeMap.end())
1849 return false;
1850
1851 Value *Lhs = Inst->getOperand(0);
1852 Value *Rhs = Inst->getOperand(1);
1853
1854 IRBuilder<> Builder(Inst);
1855 ShapeInfo &Shape = I->second;
1856
1857 MatrixTy Result;
1858 MatrixTy A = getMatrix(Lhs, Shape, Builder);
1859 MatrixTy B = getMatrix(Rhs, Shape, Builder);
1860 assert(A.isColumnMajor() == B.isColumnMajor() &&
1861 Result.isColumnMajor() == A.isColumnMajor() &&
1862 "operands must agree on matrix layout");
1863
1864 Builder.setFastMathFlags(getFastMathFlags(Inst));
1865
1866 // Helper to perform binary op on vectors.
1867 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
1868 switch (Inst->getOpcode()) {
1869 case Instruction::Add:
1870 return Builder.CreateAdd(LHS, RHS);
1871 case Instruction::Mul:
1872 return Builder.CreateMul(LHS, RHS);
1873 case Instruction::Sub:
1874 return Builder.CreateSub(LHS, RHS);
1875 case Instruction::FAdd:
1876 return Builder.CreateFAdd(LHS, RHS);
1877 case Instruction::FMul:
1878 return Builder.CreateFMul(LHS, RHS);
1879 case Instruction::FSub:
1880 return Builder.CreateFSub(LHS, RHS);
1881 default:
1882 llvm_unreachable("Unsupported binary operator for matrix");
1883 }
1884 };
1885
1886 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1887 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
1888
1889 finalizeLowering(Inst,
1890 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1891 Result.getNumVectors()),
1892 Builder);
1893 return true;
1894 }
1895
1896 /// Lower unary operators, if shape information is available.
1897 bool VisitUnaryOperator(UnaryOperator *Inst) {
1898 auto I = ShapeMap.find(Inst);
1899 if (I == ShapeMap.end())
1900 return false;
1901
1902 Value *Op = Inst->getOperand(0);
1903
1904 IRBuilder<> Builder(Inst);
1905 ShapeInfo &Shape = I->second;
1906
1907 MatrixTy Result;
1908 MatrixTy M = getMatrix(Op, Shape, Builder);
1909
1910 Builder.setFastMathFlags(getFastMathFlags(Inst));
1911
1912 // Helper to perform unary op on vectors.
1913 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
1914 switch (Inst->getOpcode()) {
1915 case Instruction::FNeg:
1916 return Builder.CreateFNeg(Op);
1917 default:
1918 llvm_unreachable("Unsupported unary operator for matrix");
1919 }
1920 };
1921
1922 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1923 Result.addVector(BuildVectorOp(M.getVector(I)));
1924
1925 finalizeLowering(Inst,
1926 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1927 Result.getNumVectors()),
1928 Builder);
1929 return true;
1930 }
1931
1932 /// Helper to linearize a matrix expression tree into a string. Currently
1933 /// matrix expressions are linearized by starting at an expression leaf and
1934 /// linearizing bottom up.
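///
/// A multiply of two loaded matrixes might, for instance, be rendered roughly
/// as (names and shapes illustrative):
///   multiply.4x4.4x4.double(
///    load(addr %A),
///    load(addr %B))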
1935 struct ExprLinearizer {
1936 unsigned LengthToBreak = 100;
1937 std::string Str;
1938 raw_string_ostream Stream;
1939 unsigned LineLength = 0;
1940 const DataLayout &DL;
1941
1942 /// Mapping from instructions to matrixes. It is used to identify
1943 /// matrix instructions.
1944 const MapVector<Value *, MatrixTy> &Inst2Matrix;
1945
1946 /// Mapping from values to the leaves of all expressions that the value is
1947 /// part of.
1948 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
1949
1950 /// Set of matrix expressions in the scope of a given DISubprogram.
1951 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
1952
1953 /// Leaf node of the expression to linearize.
1954 Value *Leaf;
1955
1956 /// Used to keep track of sub-expressions that get reused while linearizing
1957 /// the expression. Re-used sub-expressions are marked as (reused).
1958 SmallPtrSet<Value *, 8> ReusedExprs;
1959
1960 ExprLinearizer(const DataLayout &DL,
1961 const MapVector<Value *, MatrixTy> &Inst2Matrix,
1962 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
1963 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
1964 Value *Leaf)
1965 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
1966 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
1967
1968 void indent(unsigned N) {
1969 LineLength += N;
1970 for (unsigned i = 0; i < N; i++)
1971 Stream << " ";
1972 }
1973
1974 void lineBreak() {
1975 Stream << "\n";
1976 LineLength = 0;
1977 }
1978
1979 void maybeIndent(unsigned Indent) {
1980 if (LineLength >= LengthToBreak)
1981 lineBreak();
1982
1983 if (LineLength == 0)
1984 indent(Indent);
1985 }
1986
1987 void write(StringRef S) {
1988 LineLength += S.size();
1989 Stream << S;
1990 }
1991
1992 Value *getUnderlyingObjectThroughLoads(Value *V) {
1993 if (Value *Ptr = getPointerOperand(V))
1994 return getUnderlyingObjectThroughLoads(Ptr);
1995 else if (V->getType()->isPointerTy())
1996 return getUnderlyingObject(V);
1997 return V;
1998 }
1999
2000 /// Returns true if \p V is a matrix value in the given subprogram.
2001 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2002
2003 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2004 /// \p SS.
2005 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2006 auto M = Inst2Matrix.find(V);
2007 if (M == Inst2Matrix.end())
2008 SS << "unknown";
2009 else {
2010 SS << M->second.getNumRows();
2011 SS << "x";
2012 SS << M->second.getNumColumns();
2013 }
2014 }
2015
2016 /// Write the called function name. Handles calls to llvm.matrix.*
2017 /// specially: we write the name, followed by the dimensions of the input
2018 /// matrixes, followed by the scalar type name.
2019 void writeFnName(CallInst *CI) {
2020 if (!CI->getCalledFunction())
2021 write("<no called fn>");
2022 else {
2023 StringRef Name = CI->getCalledFunction()->getName();
2024 if (!Name.startswith("llvm.matrix")) {
2025 write(Name);
2026 return;
2027 }
2028 auto *II = cast<IntrinsicInst>(CI);
2029 write(Intrinsic::getBaseName(II->getIntrinsicID())
2030 .drop_front(StringRef("llvm.matrix.").size()));
2031 write(".");
2032 std::string Tmp;
2033 raw_string_ostream SS(Tmp);
2034
2035 switch (II->getIntrinsicID()) {
2036 case Intrinsic::matrix_multiply:
2037 prettyPrintMatrixType(II->getOperand(0), SS);
2038 SS << ".";
2039 prettyPrintMatrixType(II->getOperand(1), SS);
2040 SS << "." << *II->getType()->getScalarType();
2041 break;
2042 case Intrinsic::matrix_transpose:
2043 prettyPrintMatrixType(II->getOperand(0), SS);
2044 SS << "." << *II->getType()->getScalarType();
2045 break;
2046 case Intrinsic::matrix_column_major_load:
2047 prettyPrintMatrixType(II, SS);
2048 SS << "." << *II->getType()->getScalarType();
2049 break;
2050 case Intrinsic::matrix_column_major_store:
2051 prettyPrintMatrixType(II->getOperand(0), SS);
2052 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2053 break;
2054 default:
2055 llvm_unreachable("Unhandled case");
2056 }
2057 SS.flush();
2058 write(Tmp);
2059 }
2060 }
2061
2062 unsigned getNumShapeArgs(CallInst *CI) const {
2063 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2064 switch (II->getIntrinsicID()) {
2065 case Intrinsic::matrix_multiply:
2066 return 3;
2067 case Intrinsic::matrix_transpose:
2068 return 2;
2069 case Intrinsic::matrix_column_major_load:
2070 case Intrinsic::matrix_column_major_store:
2071 return 3;
2072 default:
2073 return 0;
2074 }
2075 }
2076 return 0;
2077 }
2078
2079 /// Special printing for values: for pointers, we print if they refer to an
2080 /// (function) external address or a stack address, for other values we
2081 /// either print the constant or "scalar"/"matrix".
2082 void write(Value *V) {
2083 V = getUnderlyingObjectThroughLoads(V);
2084 if (V->getType()->isPointerTy()) {
2085 if (isa<AllocaInst>(V)) {
2086 Stream << "stack addr";
2087 LineLength += StringRef("stack addr").size();
2088 } else {
2089 Stream << "addr";
2090 LineLength += StringRef("addr").size();
2091 }
2092 if (!V->getName().empty()) {
2093 Stream << " %" << V->getName() << "";
2094 LineLength += V->getName().size() + 2;
2095 }
2096 return;
2097 }
2098
2099 std::string Tmp;
2100 raw_string_ostream TmpStream(Tmp);
2101
2102 if (auto *CI = dyn_cast<ConstantInt>(V))
2103 TmpStream << CI->getValue();
2104 else if (isa<Constant>(V))
2105 TmpStream << "constant";
2106 else {
2107 if (isMatrix(V))
2108 TmpStream << "matrix";
2109 else
2110 TmpStream << "scalar";
2111 }
2112 TmpStream.flush();
2113 Tmp = std::string(StringRef(Tmp).trim());
2114 LineLength += Tmp.size();
2115 Stream << Tmp;
2116 }
2117
2118 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2119 /// Expressions that are re-used multiple times are prefixed with (reused)
2120 /// at the re-used root instruction.
2121 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2122 bool ParentShared) {
2123 auto *I = cast<Instruction>(Expr);
2124 maybeIndent(Indent);
2125 SmallVector<Value *, 8> Ops;
2126
2127 // Is Expr shared with other expression leaves?
2128 bool ExprShared = false;
2129
2130 // Deal with shared subtrees. Mark them as shared, if required.
2131 if (!ParentShared) {
2132 auto SI = Shared.find(Expr);
2133 assert(SI != Shared.end() && SI->second.count(Leaf));
2134
2135 for (Value *S : SI->second) {
2136 if (S == Leaf)
2137 continue;
2138 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2139 write("shared with remark at line " + std::to_string(DL.getLine()) +
2140 " column " + std::to_string(DL.getCol()) + " (");
2141 }
2142 ExprShared = SI->second.size() > 1;
2143 }
2144
2145 bool Reused = !ReusedExprs.insert(Expr).second;
2146 if (Reused && !ParentReused)
2147 write("(reused) ");
2148
2149 if (auto *CI = dyn_cast<CallInst>(I)) {
2150 writeFnName(CI);
2151
2152 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2153 } else if (isa<BitCastInst>(Expr)) {
2154 // Special case bitcasts, which are used to materialize matrixes from
2155 // non-matrix ops.
2156 write("matrix");
2157 return;
2158 } else {
2159 Ops.append(I->value_op_begin(), I->value_op_end());
2160 write(std::string(I->getOpcodeName()));
2161 }
2162
2163 write(std::string("("));
2164
2165 unsigned NumOpsToBreak = 1;
2166 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2167 NumOpsToBreak = 2;
2168
2169 for (Value *Op : Ops) {
2170 if (Ops.size() > NumOpsToBreak)
2171 lineBreak();
2172
2173 maybeIndent(Indent + 1);
2174 if (isMatrix(Op))
2175 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2176 else
2177 write(Op);
2178 if (Op != Ops.back())
2179 write(", ");
2180 }
2181
2182 write(")");
2183 }
2184
2185 const std::string &getResult() {
2186 Stream.flush();
2187 return Str;
2188 }
2189 };
2190
2191 /// Generate remarks for matrix operations in a function. To generate remarks
2192 /// for matrix expressions, the following approach is used:
2193 /// 1. Use the inlined-at debug information to group matrix operations to the
2194 /// DISubprograms they are contained in.
2195 /// 2. Collect leaves of matrix expressions (done in
2196 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2197 /// mapping. Leaves are lowered matrix instructions without other matrix
2198 /// users (like stores) in the current subprogram.
2199 /// 3. For each leaf, create a remark containing a linearizied version of the
2200 /// matrix expression. The expression is linearized by a recursive
2201 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2202 /// that multiple leaves can share sub-expressions. Shared subexpressions
2203 /// are explicitly marked as shared().
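///
/// The emitted remark message then reads roughly like (counts illustrative):
///   Lowered with 4 stores, 8 loads, 48 compute ops, 0 exposed transposes
/// followed by the linearized expression produced by ExprLinearizer.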
2204 struct RemarkGenerator {
2205 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2206 OptimizationRemarkEmitter &ORE;
2207 Function &Func;
2208 const DataLayout &DL;
2209
2210 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2211 OptimizationRemarkEmitter &ORE, Function &Func)
2212 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2213 DL(Func.getParent()->getDataLayout()) {}
2214
2215 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2216 /// instructions in Inst2Matrix returning void or without any users in
2217 /// \p ExprsInSubprogram. Currently that should only include stores.
2218 SmallVector<Value *, 4>
2219 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2220 SmallVector<Value *, 4> Leaves;
2221 for (auto *Expr : ExprsInSubprogram)
2222 if (Expr->getType()->isVoidTy() ||
2223 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2224 return ExprsInSubprogram.count(U);
2225 }))
2226 Leaves.push_back(Expr);
2227 return Leaves;
2228 }
2229
2230 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2231 /// to all visited expressions in \p Shared. Limit the matrix operations to
2232 /// the ones in \p ExprsInSubprogram.
2233 void collectSharedInfo(Value *Leaf, Value *V,
2234 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2235 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2236
2237 if (!ExprsInSubprogram.count(V))
2238 return;
2239
2240 auto I = Shared.insert({V, {}});
2241 I.first->second.insert(Leaf);
2242
2243 for (Value *Op : cast<Instruction>(V)->operand_values())
2244 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2245 }
2246
2247 /// Calculate the number of exclusive and shared op counts for expression
2248 /// starting at \p V. Expressions used multiple times are counted once.
2249 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2250 std::pair<OpInfoTy, OpInfoTy>
2251 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2252 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2253 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2254 if (!ExprsInSubprogram.count(Root))
2255 return {};
2256
2257 // Already counted this expression. Stop.
2258 if (!ReusedExprs.insert(Root).second)
2259 return {};
2260
2261 OpInfoTy SharedCount;
2262 OpInfoTy Count;
2263
2264 auto I = Shared.find(Root);
2265 auto CM = Inst2Matrix.find(Root);
2266 if (I->second.size() == 1)
2267 Count = CM->second.getOpInfo();
2268 else
2269 SharedCount = CM->second.getOpInfo();
2270
2271 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2272 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2273 Count += C.first;
2274 SharedCount += C.second;
2275 }
2276 return {Count, SharedCount};
2277 }
2278
2279 void emitRemarks() {
2280 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2281 return;
2282
2283 // Map matrix operations to their containing subprograms, by traversing
2284 // the inlinedAt chain. If the function does not have a DISubprogram, we
2285 // only map them to the containing function.
2286 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2287 for (const auto &KV : Inst2Matrix) {
2288 if (Func.getSubprogram()) {
2289 auto *I = cast<Instruction>(KV.first);
2290 DILocation *Context = I->getDebugLoc();
2291 while (Context) {
2292 auto I =
2293 Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2294 I.first->second.push_back(KV.first);
2295 Context = DebugLoc(Context).getInlinedAt();
2296 }
2297 } else {
2298 auto I = Subprog2Exprs.insert({nullptr, {}});
2299 I.first->second.push_back(KV.first);
2300 }
2301 }
2302 for (auto &KV : Subprog2Exprs) {
2303 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2304 KV.second.end());
2305 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2306
2307 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2308 for (Value *Leaf : Leaves)
2309 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2310
2311 // Generate remarks for each leaf.
2312 for (auto *L : Leaves) {
2313
2314 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2315 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2316 while (Context) {
2317 if (getSubprogram(Context->getScope()) == KV.first) {
2318 Loc = Context;
2319 break;
2320 }
2321 Context = DebugLoc(Context).getInlinedAt();
2322 }
2323
2324 SmallPtrSet<Value *, 8> ReusedExprs;
2325 OpInfoTy Counts, SharedCounts;
2326 std::tie(Counts, SharedCounts) =
2327 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2328
2329 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2330 cast<Instruction>(L)->getParent());
2331
2332 Rem << "Lowered with ";
2333 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2334 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2335 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2336 << " compute ops, "
2337 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2338 << " exposed transposes";
2339
2340 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2341 SharedCounts.NumComputeOps > 0) {
2342 Rem << ",\nadditionally "
2343 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2344 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2345 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2346 << " compute ops"
2347 << " are shared with other expressions";
2348 }
2349
2350 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2351 ORE.emit(Rem);
2352 }
2353 }
2354 }
2355
2356 std::string
2357 linearize(Value *L,
2358 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2359 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2360 const DataLayout &DL) {
2361 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2362 Lin.linearizeExpr(L, 0, false, false);
2363 return Lin.getResult();
2364 }
2365 };
2366};
2367} // namespace
2368
2369 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2370 FunctionAnalysisManager &AM) {
2371 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2372 OptimizationRemarkEmitter *ORE = nullptr;
2373 AAResults *AA = nullptr;
2374 DominatorTree *DT = nullptr;
2375 LoopInfo *LI = nullptr;
2376
2377 if (!Minimal) {
2378 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
2379 AA = &AM.getResult<AAManager>(F);
2380 DT = &AM.getResult<DominatorTreeAnalysis>(F);
2381 LI = &AM.getResult<LoopAnalysis>(F);
2382 }
2383
2384 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2385 if (LMT.Visit()) {
2386 PreservedAnalyses PA = PreservedAnalyses::none();
2387 if (!Minimal) {
2388 PA.preserve<LoopAnalysis>();
2389 PA.preserve<DominatorTreeAnalysis>();
2390 }
2391 return PA;
2392 }
2393 return PreservedAnalyses::all();
2394}
2395
2396 void LowerMatrixIntrinsicsPass::printPipeline(
2397 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2398 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2399 OS, MapClassName2PassName);
2400 OS << '<';
2401 if (Minimal)
2402 OS << "minimal";
2403 OS << '>';
2404}
2405
2406namespace {
2407
2408class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
2409public:
2410 static char ID;
2411
2412 LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
2413 initializeLowerMatrixIntrinsicsLegacyPassPass(
2414 *PassRegistry::getPassRegistry());
2415 }
2416
2417 bool runOnFunction(Function &F) override {
2418 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2419 auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
2420 auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
2421 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2422 auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
2423 LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
2424 bool C = LMT.Visit();
2425 return C;
2426 }
2427
2428 void getAnalysisUsage(AnalysisUsage &AU) const override {
2429 AU.addRequired<TargetTransformInfoWrapperPass>();
2430 AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
2431 AU.addRequired<AAResultsWrapperPass>();
2432 AU.addRequired<DominatorTreeWrapperPass>();
2433 AU.addPreserved<DominatorTreeWrapperPass>();
2434 AU.addRequired<LoopInfoWrapperPass>();
2435 AU.addPreserved<LoopInfoWrapperPass>();
2436 }
2437};
2438} // namespace
2439
2440static const char pass_name[] = "Lower the matrix intrinsics";
2441char LowerMatrixIntrinsicsLegacyPass::ID = 0;
2442INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2443 false, false)
2444 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
2445 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
2446 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
2447 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
2448 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2449 false, false)
2450
2451 Pass *llvm::createLowerMatrixIntrinsicsPass() {
2452 return new LowerMatrixIntrinsicsLegacyPass();
2452 return new LowerMatrixIntrinsicsLegacyPass();
2453}
2454
2455namespace {
2456
2457/// A lightweight version of the matrix lowering pass that only requires TTI.
2458/// Advanced features that require DT, AA or ORE like tiling are disabled. This
2459/// is used to lower matrix intrinsics if the main lowering pass is not run, for
2460/// example with -O0.
2461class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
2462public:
2463 static char ID;
2464
2465 LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
2466 initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
2467 *PassRegistry::getPassRegistry());
2468 }
2469
2470 bool runOnFunction(Function &F) override {
2471 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2472 LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
2473 bool C = LMT.Visit();
2474 return C;
2475 }
2476
2477 void getAnalysisUsage(AnalysisUsage &AU) const override {
2478 AU.addRequired<TargetTransformInfoWrapperPass>();
2479 AU.setPreservesCFG();
2480 }
2481};
2482} // namespace
2483
2484static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
2485char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
2486INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
2487 "lower-matrix-intrinsics-minimal", pass_name_minimal,
2488 false, false)
2489 INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
2490 "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
2491 false)
2492
2493 Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
2494 return new LowerMatrixIntrinsicsMinimalLegacyPass();
2495}