1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different number of rows/columns
16// for tiles, consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
19
30#include "llvm/IR/CFG.h"
31#include "llvm/IR/DataLayout.h"
32#include "llvm/IR/DebugInfoMetadata.h"
33#include "llvm/IR/Function.h"
34#include "llvm/IR/IRBuilder.h"
35#include "llvm/IR/Instructions.h"
36#include "llvm/IR/IntrinsicInst.h"
37#include "llvm/IR/MatrixBuilder.h"
38#include "llvm/IR/PatternMatch.h"
39#include "llvm/Support/Alignment.h"
40#include "llvm/Support/CommandLine.h"
41#include "llvm/Support/Debug.h"
42#include "llvm/Transforms/Utils/BasicBlockUtils.h"
43#include "llvm/Transforms/Utils/LoopUtils.h"
44#include "llvm/Transforms/Utils/MatrixUtils.h"
45
46#include <cmath>
47
48using namespace llvm;
49using namespace PatternMatch;
50
51#define DEBUG_TYPE "lower-matrix-intrinsics"
52
53static cl::opt<bool>
54 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
55 cl::desc("Enable/disable fusing matrix instructions."));
56// TODO: Allow and use non-square tiles.
57static cl::opt<unsigned> TileSize(
58 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
59 cl::desc(
60 "Tile size for matrix instruction fusion using square-shaped tiles."));
61static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
62 cl::Hidden,
63 cl::desc("Generate loop nest for tiling."));
64static cl::opt<bool> ForceFusion(
65 "force-fuse-matrix", cl::init(false), cl::Hidden,
66 cl::desc("Force matrix instruction fusion even if not profitable."));
67static cl::opt<bool> AllowContractEnabled(
68 "matrix-allow-contract", cl::init(false), cl::Hidden,
69 cl::desc("Allow the use of FMAs if available and profitable. This may "
70 "result in different results, due to less rounding error."));
71
72static cl::opt<bool>
73 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
74 cl::desc("Enable/disable matrix shape verification."),
75 cl::init(false));
76
77enum class MatrixLayoutTy { ColumnMajor, RowMajor };
78
79static cl::opt<MatrixLayoutTy> MatrixLayout(
80 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
81 cl::desc("Sets the default matrix layout"),
82 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
83 "Use column-major layout"),
84 clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
85 "Use row-major layout")));
86
87static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
88 cl::init(false));
89
90/// Helper function to either return \p Scope, if it is a subprogram, or the
91/// attached subprogram for a local scope.
92static DISubprogram *getSubprogram(DIScope *Scope) {
93 if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
94 return Subprogram;
95 return cast<DILocalScope>(Scope)->getSubprogram();
96}
97
98/// Erase \p V from \p BB and move \p II forward to avoid invalidating
99/// iterators.
100static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
101 BasicBlock &BB) {
102 auto *Inst = cast<Instruction>(V);
103 // Still used, don't erase.
104 if (!Inst->use_empty())
105 return;
106 if (II != BB.rend() && Inst == &*II)
107 ++II;
108 Inst->eraseFromParent();
109}
110
111/// Return true if V is a splat of a value (which is used when multiplying a
112/// matrix with a scalar).
113static bool isSplat(Value *V) {
114 if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
115 return SV->isZeroEltSplat();
116 return false;
117}
118
119/// Match any mul operation (fp or integer).
120template <typename LTy, typename RTy>
121auto m_AnyMul(const LTy &L, const RTy &R) {
122 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
123}
124
125/// Match any add operation (fp or integer).
126template <typename LTy, typename RTy>
127auto m_AnyAdd(const LTy &L, const RTy &R) {
128 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
129}
130
131namespace {
132
133// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
134// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
135// assuming \p Stride elements between the starts of two consecutive vectors.
136// \p Stride must be >= \p NumElements.
137// For column-major matrixes, the function computes the address of a column
138// vector and \p NumElements must be set to the number of elements in a column
139// (= number of rows of the matrix). For row-major matrixes, the function
140// computes the address of a row vector and \p NumElements must be set to the
141// number of elements in a row (= number of columns of the matrix).
142//
143// Consider a 4x4 matrix in column-major layout like below
144//
145// 0 1 2 3
146// 0 v_0_0 v_0_1 v_0_2 v_0_3
147// 1 v_1_0 v_1_1 v_1_2 v_1_3
148// 2 v_2_0 v_2_1 v_2_2 v_2_3
149// 3 v_3_0 v_3_1 v_3_2 v_3_3
150
151// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
152// we need a pointer to the first element of the submatrix as base pointer.
153// Then we can use computeVectorAddr to compute the addresses for the columns
154// of the sub-matrix.
155//
156// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
157// -> just returns Base
158// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
159// -> returns Base + (1 * 4)
160// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
161// -> returns Base + (2 * 4)
162//
163// The graphic below illustrates the number of elements in a column (marked
164// with |) and the number of skipped elements (marked with }).
165//
166// v_0_0 v_0_1 {v_0_2 {v_0_3
167// Base Col 1 Col 2
168// | | |
169// v_1_0 |v_1_1 |v_1_2 |v_1_3
170// v_2_0 |v_2_1 |v_2_2 |v_2_3
171// v_3_0 {v_3_1 {v_3_2 v_3_3
172//
173Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
174 unsigned NumElements, Type *EltType,
175 IRBuilder<> &Builder) {
176
177 assert((!isa<ConstantInt>(Stride) ||
178 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
179 "Stride must be >= the number of elements in the result vector.");
180
181 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
182 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
183
184 // Get pointer to the start of the selected vector. Skip GEP creation,
185 // if we select vector 0.
186 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
187 VecStart = BasePtr;
188 else
189 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
190
191 return VecStart;
192}
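//
// For illustration, assuming a column-major 4 x 4 matrix of double with base
// pointer %base (names made up for this sketch), requesting vector index 2
// with a stride of 4 conceptually emits
//   %vec.start = mul i64 2, 4
//   %vec.gep   = getelementptr double, ptr %base, i64 %vec.start
// i.e. the column starts 8 elements past %base; for vector index 0 the base
// pointer is returned directly and no GEP is created.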
193
194/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
195///
196/// Currently, the lowering for each matrix intrinsic is done as follows:
197/// 1. Propagate the shape information from intrinsics to connected
198/// instructions.
199/// 2. Lower instructions with shape information (assuming column-major layout).
200/// The lowering works similarly using row-major layout.
201/// 2.1. Get column vectors for each argument. If we already lowered the
202/// definition of an argument, use the produced column vectors directly.
203/// If not, split the operand vector containing an embedded matrix into
204/// a set of column vectors.
205/// 2.2. Lower the instruction in terms of column major operations, which
206/// yields a set of column vectors containing the result matrix. Note that we
207/// lower all instructions that have shape information. Besides the
208/// intrinsics, this includes stores for example.
209/// 2.3. Update uses of the lowered instruction. If we have shape information
210/// for a user, there is nothing to do, as we will look up the result
211/// column matrix when lowering the user. For other uses, we embed the
212/// result matrix in a flat vector and update the use.
213/// 2.4. Cache the result column matrix for the instruction we lowered.
214/// 3. After we lowered all instructions in a function, remove the now
215/// obsolete instructions.
216///
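///
/// As a hedged illustration (exact IR depends on types and fast-math flags),
/// a call such as
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
/// is lowered, with the default column-major layout, by splitting %a and %b
/// into 2-element column vectors, combining them with fmul/fadd (or fmuladd)
/// per result column, and concatenating the result columns back into a flat
/// <4 x double> for any user without shape information.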
217class LowerMatrixIntrinsics {
218 Function &Func;
219 const DataLayout &DL;
220 const TargetTransformInfo &TTI;
221 AliasAnalysis *AA;
222 DominatorTree *DT;
223 LoopInfo *LI;
224 OptimizationRemarkEmitter *ORE;
225
226 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
227 struct OpInfoTy {
228 /// Number of stores emitted to generate this matrix.
229 unsigned NumStores = 0;
230 /// Number of loads emitted to generate this matrix.
231 unsigned NumLoads = 0;
232 /// Number of compute operations emitted to generate this matrix.
233 unsigned NumComputeOps = 0;
234 /// Most of the time transposes can be fused with matrix multiplies or can
235 /// be folded away via algebraic simplifications. This is the number of
236 /// transposes that we failed to make "free" via such optimizations.
237 unsigned NumExposedTransposes = 0;
238
239 OpInfoTy &operator+=(const OpInfoTy &RHS) {
240 NumStores += RHS.NumStores;
241 NumLoads += RHS.NumLoads;
242 NumComputeOps += RHS.NumComputeOps;
243 NumExposedTransposes += RHS.NumExposedTransposes;
244 return *this;
245 }
246 };
247
248 /// Wrapper class representing a matrix as a set of vectors, either in row or
249 /// column major layout. All vectors must have the same vector type.
250 class MatrixTy {
251 SmallVector<Value *, 16> Vectors;
252
253 OpInfoTy OpInfo;
254
255 bool IsColumnMajor = true;
256
257 public:
258 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
259 MatrixTy(ArrayRef<Value *> Vectors)
260 : Vectors(Vectors.begin(), Vectors.end()),
261 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
262 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
263 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
264
265 unsigned D = isColumnMajor() ? NumColumns : NumRows;
266 for (unsigned J = 0; J < D; ++J)
267 Vectors.push_back(PoisonValue::get(FixedVectorType::get(
268 EltTy, isColumnMajor() ? NumRows : NumColumns)));
269 }
270
271 Value *getVector(unsigned i) const { return Vectors[i]; }
272 Value *getColumn(unsigned i) const {
273 assert(isColumnMajor() && "only supported for column-major matrixes");
274 return Vectors[i];
275 }
276 Value *getRow(unsigned i) const {
277 assert(!isColumnMajor() && "only supported for row-major matrixes");
278 return Vectors[i];
279 }
280
281 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
282
283 Type *getElementType() const { return getVectorTy()->getElementType(); }
284
285 unsigned getNumVectors() const {
286 if (isColumnMajor())
287 return getNumColumns();
288 return getNumRows();
289 }
290
291 unsigned getNumColumns() const {
292 if (isColumnMajor())
293 return Vectors.size();
294 else {
295 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
296 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
297 }
298 }
299 unsigned getNumRows() const {
300 if (isColumnMajor()) {
301 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
302 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
303 } else
304 return Vectors.size();
305 }
306
307 void addVector(Value *V) { Vectors.push_back(V); }
308 VectorType *getColumnTy() {
309 assert(isColumnMajor() && "only supported for column-major matrixes");
310 return getVectorTy();
311 }
312
313 VectorType *getVectorTy() const {
314 return cast<VectorType>(Vectors[0]->getType());
315 }
316
317 auto columns() const {
318 assert(isColumnMajor() &&
319 "columns() only supported for column-major matrixes");
320 return make_range(Vectors.begin(), Vectors.end());
321 }
322
323 auto vectors() const {
324 return make_range(Vectors.begin(), Vectors.end());
325 }
326
327 /// Embed the vectors of the matrix into a flat vector by concatenating
328 /// them.
329 Value *embedInVector(IRBuilder<> &Builder) const {
330 return Vectors.size() == 1 ? Vectors[0]
331 : concatenateVectors(Builder, Vectors);
332 }
333
334 MatrixTy &addNumLoads(unsigned N) {
335 OpInfo.NumLoads += N;
336 return *this;
337 }
338
339 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
340
341 MatrixTy &addNumStores(unsigned N) {
342 OpInfo.NumStores += N;
343 return *this;
344 }
345
346 MatrixTy &addNumExposedTransposes(unsigned N) {
347 OpInfo.NumExposedTransposes += N;
348 return *this;
349 }
350
351 MatrixTy &addNumComputeOps(unsigned N) {
352 OpInfo.NumComputeOps += N;
353 return *this;
354 }
355
356 unsigned getNumStores() const { return OpInfo.NumStores; }
357 unsigned getNumLoads() const { return OpInfo.NumLoads; }
358 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
359
360 const OpInfoTy &getOpInfo() const { return OpInfo; }
361
362 bool isColumnMajor() const { return IsColumnMajor; }
363
364 unsigned getStride() const {
365 if (isColumnMajor())
366 return getNumRows();
367 return getNumColumns();
368 }
369
370 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
371 /// matrix is column-major, the result vector is extracted from a column
372 /// vector, otherwise from a row vector.
373 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
374 IRBuilder<> &Builder) const {
375 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
376 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
377 NumElts &&
378 "Extracted vector will contain poison values");
379 return Builder.CreateShuffleVector(
380 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
381 "block");
382 }
383 };
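  // For illustration (a sketch assuming the default column-major layout):
  // MatrixTy(3, 2, DoubleTy) holds two <3 x double> column vectors, so
  // getNumVectors() == 2, getStride() == 3, and embedInVector() concatenates
  // the columns into a flat <6 x double>. With row-major layout the roles of
  // rows and columns are swapped.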
384
385 struct ShapeInfo {
386 unsigned NumRows;
387 unsigned NumColumns;
388
389 bool IsColumnMajor;
390
391 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
392 : NumRows(NumRows), NumColumns(NumColumns),
393 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
394
395 ShapeInfo(Value *NumRows, Value *NumColumns)
396 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
397 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
398
399 bool operator==(const ShapeInfo &other) {
400 return NumRows == other.NumRows && NumColumns == other.NumColumns;
401 }
402 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
403
404 /// Returns true if shape-information is defined, meaning both dimensions
405 /// are != 0.
406 operator bool() const {
407 assert(NumRows == 0 || NumColumns != 0);
408 return NumRows != 0;
409 }
410
411 unsigned getStride() const {
412 if (IsColumnMajor)
413 return NumRows;
414 return NumColumns;
415 }
416
417 unsigned getNumVectors() const {
418 if (IsColumnMajor)
419 return NumColumns;
420 return NumRows;
421 }
422
423 /// Returns the transposed shape.
424 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
425 };
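  // Worked example (column-major default): ShapeInfo(4, 2) describes a 4 x 2
  // matrix, so getStride() == 4 (elements per column vector),
  // getNumVectors() == 2 (number of columns) and t() yields the 2 x 4 shape.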
426
427 /// Maps instructions to their shape information. The shape information
428 /// describes the shape to be used while lowering. This matches the shape of
429 /// the result value of the instruction, with the only exceptions being store
430 /// instructions and the matrix_column_major_store intrinsics. For those, the
431 /// shape information indicates that those instructions should be lowered
432 /// using shape information as well. A ValueMap is used so that when
433 /// sub-passes like optimizeTransposes perform RAUW, the map stays
434 /// up-to-date.
435 ValueMap<Value *, ShapeInfo> ShapeMap;
436
437 /// List of instructions to remove. While lowering, we do not replace all
438 /// users of a lowered instruction, if shape information is available, and
439 /// those instructions need to be removed after we finish lowering.
440 SmallVector<Instruction *, 16> ToRemove;
441
442 /// Map from instructions to their produced column matrix.
443 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
444
445private:
446 static FastMathFlags getFastMathFlags(Instruction *Inst) {
447 FastMathFlags FMF;
448
449 if (isa<FPMathOperator>(*Inst))
450 FMF = Inst->getFastMathFlags();
451
452 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
453
454 return FMF;
455 }
456
457public:
458 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
459 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
460 OptimizationRemarkEmitter *ORE)
461 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
462 LI(LI), ORE(ORE) {}
463
464 unsigned getNumOps(Type *VT) {
465 assert(isa<VectorType>(VT) && "Expected vector type");
466 return getNumOps(VT->getScalarType(),
467 cast<FixedVectorType>(VT)->getNumElements());
468 }
469
470 /// Is this the minimal version executed in the backend pipelines.
471 bool isMinimal() const {
472 return !DT;
473 }
474
475 /// Return the estimated number of vector ops required for an operation on
476 /// \p VT * N.
477 unsigned getNumOps(Type *ST, unsigned N) {
478 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
479 double(TTI.getRegisterBitWidth(
480 TargetTransformInfo::RGK_FixedWidthVector)
481 .getFixedValue()));
482 }
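  // For example, assuming a target with 128-bit fixed vector registers, a
  // <4 x float> operand (4 * 32 = 128 bits) counts as ceil(128 / 128) = 1
  // vector op, while 8 double elements (512 bits) count as ceil(512 / 128) =
  // 4 vector ops.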
483
484 /// Return the set of vectors that a matrix value is lowered to.
485 ///
486 /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
487 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
488 /// into vectors.
489 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
490 IRBuilder<> &Builder) {
491 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
492 assert(VType && "MatrixVal must be a vector type");
493 assert(cast<FixedVectorType>(VType)->getNumElements() ==
494 SI.NumRows * SI.NumColumns &&
495 "The vector size must match the number of matrix elements");
496
497 // Check if we lowered MatrixVal using shape information. In that case,
498 // return the existing matrix, if it matches the requested shape
499 // information. If there is a mis-match, embed the result in a flat
500 // vector and split it later.
501 auto Found = Inst2ColumnMatrix.find(MatrixVal);
502 if (Found != Inst2ColumnMatrix.end()) {
503 MatrixTy &M = Found->second;
504 // Return the found matrix, if its shape matches the requested shape
505 // information
506 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
507 return M;
508
509 MatrixVal = M.embedInVector(Builder);
510 }
511
512 // Otherwise split MatrixVal.
513 SmallVector<Value *, 16> SplitVecs;
514 for (unsigned MaskStart = 0;
515 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
516 MaskStart += SI.getStride()) {
517 Value *V = Builder.CreateShuffleVector(
518 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
519 "split");
520 SplitVecs.push_back(V);
521 }
522
523 return {SplitVecs};
524 }
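  // For illustration, splitting a flat <6 x double> %m with shape 3 x 2
  // (column-major) emits two shufflevectors selecting elements 0-2 and 3-5,
  // roughly (value names chosen by the IRBuilder):
  //   %split  = shufflevector <6 x double> %m, <6 x double> poison,
  //                           <3 x i32> <i32 0, i32 1, i32 2>
  //   %split1 = shufflevector <6 x double> %m, <6 x double> poison,
  //                           <3 x i32> <i32 3, i32 4, i32 5>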
525
526 /// If \p V already has a known shape return false. Otherwise set the shape
527 /// for instructions that support it.
528 bool setShapeInfo(Value *V, ShapeInfo Shape) {
529 assert(Shape && "Shape not set");
530 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
531 return false;
532
533 auto SIter = ShapeMap.find(V);
534 if (SIter != ShapeMap.end()) {
535 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
536 SIter->second.NumColumns != Shape.NumColumns)) {
537 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
538 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
539 << Shape.NumColumns << ") for " << *V << "\n";
540 report_fatal_error(
541 "Matrix shape verification failed, compilation aborted!");
542 }
543
544 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
545 << SIter->second.NumRows << " "
546 << SIter->second.NumColumns << " for " << *V << "\n");
547 return false;
548 }
549
550 ShapeMap.insert({V, Shape});
551 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
552 << " for " << *V << "\n");
553 return true;
554 }
555
556 bool isUniformShape(Value *V) {
557 Instruction *I = dyn_cast<Instruction>(V);
558 if (!I)
559 return true;
560
561 switch (I->getOpcode()) {
562 case Instruction::FAdd:
563 case Instruction::FSub:
564 case Instruction::FMul: // Scalar multiply.
565 case Instruction::FNeg:
566 case Instruction::Add:
567 case Instruction::Mul:
568 case Instruction::Sub:
569 return true;
570 default:
571 return false;
572 }
573 }
574
575 /// Returns true if shape information can be used for \p V. The supported
576 /// instructions must match the instructions that can be lowered by this pass.
577 bool supportsShapeInfo(Value *V) {
578 Instruction *Inst = dyn_cast<Instruction>(V);
579 if (!Inst)
580 return false;
581
582 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
583 if (II)
584 switch (II->getIntrinsicID()) {
585 case Intrinsic::matrix_multiply:
586 case Intrinsic::matrix_transpose:
587 case Intrinsic::matrix_column_major_load:
588 case Intrinsic::matrix_column_major_store:
589 return true;
590 default:
591 return false;
592 }
593 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
594 }
595
596 /// Propagate the shape information of instructions to their users.
597 /// The work list contains instructions for which we can compute the shape,
598 /// either based on the information provided by matrix intrinsics or known
599 /// shapes of operands.
600 SmallVector<Instruction *, 32>
601 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
602 SmallVector<Instruction *, 32> NewWorkList;
603 // Pop an element for which we are guaranteed to have at least one of the
604 // operand shapes. Add the shape for this and then add users to the work
605 // list.
606 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
607 while (!WorkList.empty()) {
608 Instruction *Inst = WorkList.pop_back_val();
609
610 // New entry, set the value and insert operands
611 bool Propagate = false;
612
613 Value *MatrixA;
614 Value *MatrixB;
615 Value *M;
616 Value *N;
617 Value *K;
618 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
619 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
620 m_Value(N), m_Value(K)))) {
621 Propagate = setShapeInfo(Inst, {M, K});
622 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
623 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
624 // Flip dimensions.
625 Propagate = setShapeInfo(Inst, {N, M});
626 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
627 m_Value(MatrixA), m_Value(), m_Value(),
628 m_Value(), m_Value(M), m_Value(N)))) {
629 Propagate = setShapeInfo(Inst, {N, M});
630 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
631 m_Value(), m_Value(), m_Value(), m_Value(M),
632 m_Value(N)))) {
633 Propagate = setShapeInfo(Inst, {M, N});
634 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
635 auto OpShape = ShapeMap.find(MatrixA);
636 if (OpShape != ShapeMap.end())
637 setShapeInfo(Inst, OpShape->second);
638 continue;
639 } else if (isUniformShape(Inst)) {
640 // Find the first operand that has a known shape and use that.
641 for (auto &Op : Inst->operands()) {
642 auto OpShape = ShapeMap.find(Op.get());
643 if (OpShape != ShapeMap.end()) {
644 Propagate |= setShapeInfo(Inst, OpShape->second);
645 break;
646 }
647 }
648 }
649
650 if (Propagate) {
651 NewWorkList.push_back(Inst);
652 for (auto *User : Inst->users())
653 if (ShapeMap.count(User) == 0)
654 WorkList.push_back(cast<Instruction>(User));
655 }
656 }
657
658 return NewWorkList;
659 }
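  // For illustration (a sketch, not taken from a specific test): given
  //   %t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a,
  //                                                       i32 2, i32 3)
  //   store <6 x double> %t, ptr %p
  // forward propagation records shape 3 x 2 for %t (dimensions flipped) and,
  // once %t's shape is known, the same shape for the store, so both are later
  // lowered vector-by-vector instead of as flat <6 x double> operations.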
660
661 /// Propagate the shape to operands of instructions with shape information.
662 /// \p Worklist contains the instruction for which we already know the shape.
663 SmallVector<Instruction *, 32>
664 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
665 SmallVector<Instruction *, 32> NewWorkList;
666
667 auto pushInstruction = [](Value *V,
668 SmallVectorImpl<Instruction *> &WorkList) {
669 Instruction *I = dyn_cast<Instruction>(V);
670 if (I)
671 WorkList.push_back(I);
672 };
673 // Pop an element with known shape. Traverse the operands, if their shape
674 // derives from the result shape and is unknown, add it and add them to the
675 // worklist.
676 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
677 while (!WorkList.empty()) {
678 Value *V = WorkList.pop_back_val();
679
680 size_t BeforeProcessingV = WorkList.size();
681 if (!isa<Instruction>(V))
682 continue;
683
684 Value *MatrixA;
685 Value *MatrixB;
686 Value *M;
687 Value *N;
688 Value *K;
689 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
690 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
691 m_Value(N), m_Value(K)))) {
692 if (setShapeInfo(MatrixA, {M, N}))
693 pushInstruction(MatrixA, WorkList);
694
695 if (setShapeInfo(MatrixB, {N, K}))
696 pushInstruction(MatrixB, WorkList);
697
698 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
699 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
700 // Flip dimensions.
701 if (setShapeInfo(MatrixA, {M, N}))
702 pushInstruction(MatrixA, WorkList);
703 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
704 m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
705 m_Value(M), m_Value(N)))) {
706 if (setShapeInfo(MatrixA, {M, N})) {
707 pushInstruction(MatrixA, WorkList);
708 }
709 } else if (isa<LoadInst>(V) ||
710 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
711 // Nothing to do, no matrix input.
712 } else if (isa<StoreInst>(V)) {
713 // Nothing to do. We forward-propagated to this so we would just
714 // backward propagate to an instruction with an already known shape.
715 } else if (isUniformShape(V)) {
716 // Propagate to all operands.
717 ShapeInfo Shape = ShapeMap[V];
718 for (Use &U : cast<Instruction>(V)->operands()) {
719 if (setShapeInfo(U.get(), Shape))
720 pushInstruction(U.get(), WorkList);
721 }
722 }
723 // After we discovered new shape info for new instructions in the
724 // worklist, we use their users as seeds for the next round of forward
725 // propagation.
726 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
727 for (User *U : WorkList[I]->users())
728 if (isa<Instruction>(U) && V != U)
729 NewWorkList.push_back(cast<Instruction>(U));
730 }
731 return NewWorkList;
732 }
733
734 /// (Op0 op Op1)^T -> Op0^T op Op1^T
735 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
736 /// them on both sides of \p Operation.
737 Instruction *distributeTransposes(
738 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
739 MatrixBuilder &Builder,
740 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
741 Operation) {
742 Value *T0 = Builder.CreateMatrixTranspose(
743 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
744 // We are being run after shape prop, add shape for newly created
745 // instructions so that we lower them later.
746 setShapeInfo(T0, Shape0.t());
747 Value *T1 = Builder.CreateMatrixTranspose(
748 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
749 setShapeInfo(T1, Shape1.t());
750 return Operation(T0, Shape0.t(), T1, Shape1.t());
751 }
752
753 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
754 // We need to remove Old from the ShapeMap, otherwise RAUW will replace it
755 // with New. We should only add New if it supportsShapeInfo, so we insert
756 // it conditionally instead.
757 auto S = ShapeMap.find(&Old);
758 if (S != ShapeMap.end()) {
759 ShapeMap.erase(S);
760 if (supportsShapeInfo(New))
761 ShapeMap.insert({New, S->second});
762 }
763 Old.replaceAllUsesWith(New);
764 }
765
766 /// Sink a top-level transpose inside matmuls and adds.
767 /// This creates and erases instructions as needed, and returns the newly
768 /// created instruction while updating the iterator to avoid invalidation. If
769 /// this returns nullptr, no new instruction was created.
770 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
771 BasicBlock &BB = *I.getParent();
772 IRBuilder<> IB(&I);
773 MatrixBuilder Builder(IB);
774
775 Value *TA, *TAMA, *TAMB;
776 ConstantInt *R, *K, *C;
777 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
778 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
779 return nullptr;
780
781 // Transpose of a transpose is a nop
782 Value *TATA;
783 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
784 updateShapeAndReplaceAllUsesWith(I, TATA);
785 eraseFromParentAndMove(&I, II, BB);
786 eraseFromParentAndMove(TA, II, BB);
787 return nullptr;
788 }
789
790 // k^T -> k
791 if (isSplat(TA)) {
792 updateShapeAndReplaceAllUsesWith(I, TA);
793 eraseFromParentAndMove(&I, II, BB);
794 return nullptr;
795 }
796
797 // (A * B)^t -> B^t * A^t
798 // RxK KxC CxK KxR
799 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
800 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
801 m_ConstantInt(K), m_ConstantInt(C)))) {
802 auto NewInst = distributeTransposes(
803 TAMB, {K, C}, TAMA, {R, K}, Builder,
804 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
805 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
806 Shape0.NumColumns,
807 Shape1.NumColumns, "mmul");
808 });
809 updateShapeAndReplaceAllUsesWith(I, NewInst);
810 eraseFromParentAndMove(&I, II, BB);
811 eraseFromParentAndMove(TA, II, BB);
812 return NewInst;
813 }
814
815 // Same as above, but with a mul, which occurs when multiplied
816 // with a scalar.
817 // (A * k)^t -> A^t * k
818 // R x C RxC
819 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
820 (isSplat(TAMA) || isSplat(TAMB))) {
821 IRBuilder<> LocalBuilder(&I);
822 // We know that the transposed operand is of shape RxC.
823 // And when multiplied with a scalar, the shape is preserved.
824 auto NewInst = distributeTransposes(
825 TAMA, {R, C}, TAMB, {R, C}, Builder,
826 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
827 bool IsFP = I.getType()->isFPOrFPVectorTy();
828 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
829 : LocalBuilder.CreateMul(T0, T1, "mmul");
830 auto *Result = cast<Instruction>(Mul);
831 setShapeInfo(Result, Shape0);
832 return Result;
833 });
834 updateShapeAndReplaceAllUsesWith(I, NewInst);
835 eraseFromParentAndMove(&I, II, BB);
836 eraseFromParentAndMove(TA, II, BB);
837 return NewInst;
838 }
839
840 // (A + B)^t -> A^t + B^t
841 // RxC RxC CxR CxR
842 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
843 IRBuilder<> LocalBuilder(&I);
844 auto NewInst = distributeTransposes(
845 TAMA, {R, C}, TAMB, {R, C}, Builder,
846 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
847 bool IsFP = I.getType()->isFPOrFPVectorTy();
848 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
849 : LocalBuilder.CreateAdd(T0, T1, "madd");
850
851 auto *Result = cast<Instruction>(Add);
852 setShapeInfo(Result, Shape0);
853 return Result;
854 });
855 updateShapeAndReplaceAllUsesWith(I, NewInst);
856 eraseFromParentAndMove(&I, II, BB);
857 eraseFromParentAndMove(TA, II, BB);
858 return NewInst;
859 }
860
861 return nullptr;
862 }
863
864 void liftTranspose(Instruction &I) {
865 // Erase dead Instructions after lifting transposes from binops.
866 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
867 if (T.use_empty())
868 T.eraseFromParent();
869 if (A->use_empty())
870 cast<Instruction>(A)->eraseFromParent();
871 if (A != B && B->use_empty())
872 cast<Instruction>(B)->eraseFromParent();
873 };
874
875 Value *A, *B, *AT, *BT;
876 ConstantInt *R, *K, *C;
877 // A^t * B ^t -> (B * A)^t
878 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
879 m_Value(A), m_Value(B), m_ConstantInt(R),
880 m_ConstantInt(K), m_ConstantInt(C))) &&
881 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
882 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
883 IRBuilder<> IB(&I);
884 MatrixBuilder Builder(IB);
885 Value *M = Builder.CreateMatrixMultiply(
886 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
887 setShapeInfo(M, {C, R});
888 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
889 R->getZExtValue());
890 updateShapeAndReplaceAllUsesWith(I, NewInst);
891 CleanupBinOp(I, A, B);
892 }
893 // A^t + B ^t -> (A + B)^t
894 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
895 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
896 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
897 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
898 m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
899 IRBuilder<> Builder(&I);
900 Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
901 setShapeInfo(Add, {C, R});
902 MatrixBuilder MBuilder(Builder);
903 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
904 Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
905 updateShapeAndReplaceAllUsesWith(I, NewInst);
906 CleanupBinOp(I, A, B);
907 }
908 }
909
910 /// Try moving transposes in order to fold them away or into multiplies.
911 void optimizeTransposes() {
912 // First sink all transposes inside matmuls and adds, hoping that we end up
913 // with NN, NT or TN variants.
914 for (BasicBlock &BB : reverse(Func)) {
915 for (auto II = BB.rbegin(); II != BB.rend();) {
916 Instruction &I = *II;
917 // We may remove II. By default continue on the next/prev instruction.
918 ++II;
919 if (Instruction *NewInst = sinkTranspose(I, II))
920 II = std::next(BasicBlock::reverse_iterator(NewInst));
921 }
922 }
923
924 // If we have a TT matmul or a TT add, lift the transpose. We may be able
925 // to fold into consuming multiply or add.
926 for (BasicBlock &BB : Func) {
927 for (Instruction &I : llvm::make_early_inc_range(BB)) {
928 liftTranspose(I);
929 }
930 }
931 }
932
933 bool Visit() {
934 SmallVector<Instruction *, 32> WorkList;
935
936 // Initially only the shape of matrix intrinsics is known.
937 // Initialize the work list with ops carrying shape information.
938 for (BasicBlock &BB : Func)
939 for (Instruction &Inst : BB) {
940 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
941 if (!II)
942 continue;
943
944 switch (II->getIntrinsicID()) {
945 case Intrinsic::matrix_multiply:
946 case Intrinsic::matrix_transpose:
947 case Intrinsic::matrix_column_major_load:
948 case Intrinsic::matrix_column_major_store:
949 WorkList.push_back(&Inst);
950 break;
951 default:
952 break;
953 }
954 }
955
956 // Avoid unnecessary work if there are no matrix intrinsics in the function.
957 if (WorkList.empty())
958 return false;
959
960 // Propagate shapes until nothing changes any longer.
961 while (!WorkList.empty()) {
962 WorkList = propagateShapeForward(WorkList);
963 WorkList = propagateShapeBackward(WorkList);
964 }
965
966 if (!isMinimal()) {
967 optimizeTransposes();
968 if (PrintAfterTransposeOpt) {
969 dbgs() << "Dump after matrix transpose optimization:\n";
970 Func.print(dbgs());
971 }
972 }
973
974 bool Changed = false;
975 SmallVector<CallInst *, 16> MaybeFusableInsts;
976 SmallVector<Instruction *, 16> MatrixInsts;
977
978 // First, collect all instructions with shape information and candidates for
979 // fusion (currently only matrix multiplies).
980 ReversePostOrderTraversal<Function *> RPOT(&Func);
981 for (auto *BB : RPOT)
982 for (Instruction &I : *BB) {
983 if (ShapeMap.find(&I) == ShapeMap.end())
984 continue;
985 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
986 MaybeFusableInsts.push_back(cast<CallInst>(&I));
987 MatrixInsts.push_back(&I);
988 }
989
990 // Second, try to lower any dot products
991 SmallPtrSet<Instruction *, 16> FusedInsts;
992 for (CallInst *CI : MaybeFusableInsts)
993 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
994
995 // Third, try to fuse candidates.
996 for (CallInst *CI : MaybeFusableInsts)
997 LowerMatrixMultiplyFused(CI, FusedInsts);
998
999 Changed = !FusedInsts.empty();
1000
1001 // Fourth, lower remaining instructions with shape information.
1002 for (Instruction *Inst : MatrixInsts) {
1003 if (FusedInsts.count(Inst))
1004 continue;
1005
1006 IRBuilder<> Builder(Inst);
1007
1008 if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1009 Changed |= VisitCallInst(CInst);
1010
1011 Value *Op1;
1012 Value *Op2;
1013 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1014 Changed |= VisitBinaryOperator(BinOp);
1015 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1016 Changed |= VisitUnaryOperator(UnOp);
1017 if (match(Inst, m_Load(m_Value(Op1))))
1018 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1019 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1020 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1021 }
1022
1023 if (ORE) {
1024 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1025 RemarkGen.emitRemarks();
1026 }
1027
1028 // Delete the instructions backwards, as it has a reduced likelihood of
1029 // having to update as many def-use and use-def chains.
1030 //
1031 // Because we add to ToRemove during fusion we can't guarantee that defs
1032 // are before uses. Change uses to poison temporarily as these should get
1033 // removed as well.
1034 //
1035 // For verification, we keep track of where we changed uses to poison in
1036 // PoisonedInsts and then check that we in fact remove them.
1037 SmallSet<Instruction *, 16> PoisonedInsts;
1038 for (auto *Inst : reverse(ToRemove)) {
1039 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1040 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1041 PoisonedInsts.insert(Poisoned);
1042 U.set(PoisonValue::get(Inst->getType()));
1043 }
1044 Inst->eraseFromParent();
1045 PoisonedInsts.erase(Inst);
1046 }
1047 if (!PoisonedInsts.empty()) {
1048 // If we didn't remove all poisoned instructions, it's a hard error.
1049 dbgs() << "Poisoned but present instructions:\n";
1050 for (auto *I : PoisonedInsts)
1051 dbgs() << *I << "\n";
1052 llvm_unreachable("Poisoned but instruction not removed");
1053 }
1054
1055 return Changed;
1056 }
1057
1058 /// Replace intrinsic calls
1059 bool VisitCallInst(CallInst *Inst) {
1060 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1061 return false;
1062
1063 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1064 case Intrinsic::matrix_multiply:
1065 LowerMultiply(Inst);
1066 break;
1067 case Intrinsic::matrix_transpose:
1068 LowerTranspose(Inst);
1069 break;
1070 case Intrinsic::matrix_column_major_load:
1071 LowerColumnMajorLoad(Inst);
1072 break;
1073 case Intrinsic::matrix_column_major_store:
1074 LowerColumnMajorStore(Inst);
1075 break;
1076 default:
1077 return false;
1078 }
1079 return true;
1080 }
1081
1082 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1083 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1084 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1085 /// non-ConstantInt strides, return the common alignment of the initial
1086 /// alignment and the element size in bytes.
1087 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1088 MaybeAlign A) const {
1089 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1090 if (Idx == 0)
1091 return InitialAlign;
1092
1093 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1094 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1095 uint64_t StrideInBytes =
1096 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1097 return commonAlignment(InitialAlign, Idx * StrideInBytes);
1098 }
1099 return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1100 }
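  // Worked example: for double elements (8 bytes), an initial alignment of 16
  // and a constant stride of 3, column 1 starts at byte offset 24, so the
  // result is commonAlignment(16, 24) == 8; with a non-constant stride the
  // conservative result is commonAlignment(16, 8) == 8.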
1101
1102 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1103 /// vectors.
1104 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1105 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1106 auto *VType = cast<VectorType>(Ty);
1107 Type *EltTy = VType->getElementType();
1108 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1109 Value *EltPtr = Ptr;
1110 MatrixTy Result;
1111 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1112 Value *GEP = computeVectorAddr(
1113 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1114 Stride, Shape.getStride(), EltTy, Builder);
1115 Value *Vector = Builder.CreateAlignedLoad(
1116 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1117 IsVolatile, "col.load");
1118
1119 Result.addVector(Vector);
1120 }
1121 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1122 Result.getNumVectors());
1123 }
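  // For illustration, loading a 4 x 2 column-major matrix of float with a
  // stride of 4 emits two <4 x float> loads, the second through a GEP at
  // element offset 1 * 4 from the base pointer, roughly (names illustrative):
  //   %col.load  = load <4 x float>, ptr %base
  //   %vec.gep   = getelementptr float, ptr %base, i64 4
  //   %col.load1 = load <4 x float>, ptr %vec.gep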
1124
1125 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1126 /// starting at \p MatrixPtr[I][J].
1127 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1128 ShapeInfo MatrixShape, Value *I, Value *J,
1129 ShapeInfo ResultShape, Type *EltTy,
1130 IRBuilder<> &Builder) {
1131
1132 Value *Offset = Builder.CreateAdd(
1133 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1134
1135 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1136 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1137 ResultShape.NumColumns);
1138
1139 return loadMatrix(TileTy, TileStart, Align,
1140 Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1141 ResultShape, Builder);
1142 }
1143
1144 /// Lower a load instruction with shape information.
1145 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1146 bool IsVolatile, ShapeInfo Shape) {
1147 IRBuilder<> Builder(Inst);
1148 finalizeLowering(Inst,
1149 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1150 Shape, Builder),
1151 Builder);
1152 }
1153
1154 /// Lowers llvm.matrix.column.major.load.
1155 ///
1156 /// The intrinsic loads a matrix from memory using a stride between columns.
1157 void LowerColumnMajorLoad(CallInst *Inst) {
1158 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1159 "Intrinsic only supports column-major layout!");
1160 Value *Ptr = Inst->getArgOperand(0);
1161 Value *Stride = Inst->getArgOperand(1);
1162 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1163 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1164 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1165 }
1166
1167 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1168 /// MatrixPtr[I][J].
1169 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1170 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1171 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1172 Value *Offset = Builder.CreateAdd(
1173 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1174
1175 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1176 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1177 StoreVal.getNumColumns());
1178
1179 storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1180 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1181 }
1182
1183 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1184 /// vectors.
1185 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1186 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1187 IRBuilder<> &Builder) {
1188 auto VType = cast<VectorType>(Ty);
1189 Value *EltPtr = Ptr;
1190 for (auto Vec : enumerate(StoreVal.vectors())) {
1191 Value *GEP = computeVectorAddr(
1192 EltPtr,
1193 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1194 Vec.index()),
1195 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1196 Builder.CreateAlignedStore(Vec.value(), GEP,
1197 getAlignForIndex(Vec.index(), Stride,
1198 VType->getElementType(),
1199 MAlign),
1200 IsVolatile);
1201 }
1202 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1203 StoreVal.getNumVectors());
1204 }
1205
1206 /// Lower a store instruction with shape information.
1207 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1208 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1209 IRBuilder<> Builder(Inst);
1210 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1211 finalizeLowering(Inst,
1212 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1213 IsVolatile, Builder),
1214 Builder);
1215 }
1216
1217 /// Lowers llvm.matrix.column.major.store.
1218 ///
1219 /// The intrinsic stores a matrix back to memory using a stride between columns.
1220 void LowerColumnMajorStore(CallInst *Inst) {
1221 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1222 "Intrinsic only supports column-major layout!");
1223 Value *Matrix = Inst->getArgOperand(0);
1224 Value *Ptr = Inst->getArgOperand(1);
1225 Value *Stride = Inst->getArgOperand(2);
1226 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1227 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1228 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1229 }
1230
1231 // Set elements I..I+NumElts-1 to Block
1232 Value *insertVector(Value *Col, unsigned I, Value *Block,
1233 IRBuilder<> &Builder) {
1234
1235 // First, bring Block to the same size as Col
1236 unsigned BlockNumElts =
1237 cast<FixedVectorType>(Block->getType())->getNumElements();
1238 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1239 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1240
1241 Block = Builder.CreateShuffleVector(
1242 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1243
1244 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1245 // 8, 4, 5, 6
1246 SmallVector<int, 16> Mask;
1247 unsigned i;
1248 for (i = 0; i < I; i++)
1249 Mask.push_back(i);
1250
1251 unsigned VecNumElts =
1252 cast<FixedVectorType>(Col->getType())->getNumElements();
1253 for (; i < I + BlockNumElts; i++)
1254 Mask.push_back(i - I + VecNumElts);
1255
1256 for (; i < VecNumElts; i++)
1257 Mask.push_back(i);
1258
1259 return Builder.CreateShuffleVector(Col, Block, Mask);
1260 }
1261
1262 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1263 IRBuilder<> &Builder, bool AllowContraction,
1264 unsigned &NumComputeOps) {
1265 NumComputeOps += getNumOps(A->getType());
1266 if (!Sum)
1267 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1268
1269 if (UseFPOp) {
1270 if (AllowContraction) {
1271 // Use fmuladd for floating point operations and let the backend decide
1272 // if that's profitable.
1273 Function *FMulAdd = Intrinsic::getDeclaration(
1274 Func.getParent(), Intrinsic::fmuladd, A->getType());
1275 return Builder.CreateCall(FMulAdd, {A, B, Sum});
1276 }
1277 NumComputeOps += getNumOps(A->getType());
1278 Value *Mul = Builder.CreateFMul(A, B);
1279 return Builder.CreateFAdd(Sum, Mul);
1280 }
1281
1282 NumComputeOps += getNumOps(A->getType());
1283 Value *Mul = Builder.CreateMul(A, B);
1284 return Builder.CreateAdd(Sum, Mul);
1285 }
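  // For example, with contraction allowed and <4 x float> blocks this emits a
  // single call per accumulation step, roughly
  //   %sum = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a,
  //                                               <4 x float> %b,
  //                                               <4 x float> %acc)
  // while without contraction it emits a separate fmul followed by an fadd
  // (an illustrative sketch; operand names are made up).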
1286
1287 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1288 /// users with shape information, there's nothing to do: they will use the
1289 /// cached value when they are lowered. For other users, \p Matrix is
1290 /// flattened and the uses are updated to use it. Also marks \p Inst for
1291 /// deletion.
1292 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1293 IRBuilder<> &Builder) {
1294 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1295 (void)inserted;
1296 assert(inserted.second && "multiple matrix lowering mapping");
1297
1298 ToRemove.push_back(Inst);
1299 Value *Flattened = nullptr;
1300 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1301 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1302 if (!Flattened)
1303 Flattened = Matrix.embedInVector(Builder);
1304 U.set(Flattened);
1305 }
1306 }
1307 }
1308
1309 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1310 /// vectors. Lowers to a vector reduction add instead of sequential adds if
1311 /// reassociation is enabled.
1312 void lowerDotProduct(CallInst *MatMul,
1313 SmallPtrSet<Instruction *, 16> &FusedInsts,
1314 FastMathFlags FMF) {
1315 if (FusedInsts.contains(MatMul) ||
1316 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1317 return;
1318 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1319 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1320
1321 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1322 return;
1323
1324 Value *LHS = MatMul->getArgOperand(0);
1325 Value *RHS = MatMul->getArgOperand(1);
1326
1327 Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1328 bool IsIntVec = ElementType->isIntegerTy();
1329
1330 // Floating point reductions require reassociation.
1331 if (!IsIntVec && !FMF.allowReassoc())
1332 return;
1333
1334 auto CanBeFlattened = [this](Value *Op) {
1335 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
1336 return true;
1337 return match(
1338 Op, m_OneUse(m_CombineOr(
1339 m_Load(m_Value()),
1340 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1341 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1342 m_Value(), m_SpecificInt(1))))));
1343 };
1344 // Returns the cost benefit of using \p Op with the dot product lowering. If
1345 // the returned cost is < 0, the argument is cheaper to use in the
1346 // dot-product lowering.
1347 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1348 if (!isa<Instruction>(Op))
1349 return InstructionCost(0);
1350
1351 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1352 Type *EltTy = VecTy->getElementType();
1353
1354 if (!CanBeFlattened(Op)) {
1355 InstructionCost EmbedCost(0);
1356 // Roughly estimate the cost for embedding the columns into a vector.
1357 for (unsigned I = 1; I < N; ++I)
1358 EmbedCost -=
1359 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1360 std::nullopt, TTI::TCK_RecipThroughput);
1361 return EmbedCost;
1362 }
1363
1364 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1365 InstructionCost OriginalCost =
1366 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1367 EltTy) *
1368 N;
1369 InstructionCost NewCost = TTI.getArithmeticInstrCost(
1370 cast<Instruction>(Op)->getOpcode(), VecTy);
1371 return NewCost - OriginalCost;
1372 }
1373
1374 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1375 // The transpose can be skipped for the dot product lowering, roughly
1376 // estimate the savings as the cost of embedding the columns in a
1377 // vector.
1378 InstructionCost EmbedCost(0);
1379 for (unsigned I = 1; I < N; ++I)
1380 EmbedCost +=
1381 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1382 std::nullopt, TTI::TCK_RecipThroughput);
1383 return EmbedCost;
1384 }
1385
1386 // Costs for loads.
1387 if (N == 1)
1388 return InstructionCost(0);
1389
1390 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1391 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1392 };
1393 auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
1394
1395 // We compare the costs of a vector.reduce.add to sequential add.
1396 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1397 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1398 InstructionCost ReductionCost =
1399 TTI.getArithmeticReductionCost(
1400 AddOpCode, cast<VectorType>(LHS->getType()),
1401 IsIntVec ? std::nullopt : std::optional(FMF)) +
1402 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1403 InstructionCost SequentialAddCost =
1404 TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1405 (LShape.NumColumns - 1) +
1406 TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1407 (LShape.NumColumns);
1408 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1409 return;
1410
1411 FusedInsts.insert(MatMul);
1412 IRBuilder<> Builder(MatMul);
1413 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1414 this](Value *Op) -> Value * {
1415 // Matmul must be the only user of loads because we don't use LowerLoad
1416 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1417 // instead of single vector load).
1418 if (!CanBeFlattened(Op))
1419 return Op;
1420
1421 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1422 ShapeMap[Op] = ShapeMap[Op].t();
1423 return Op;
1424 }
1425
1426 FusedInsts.insert(cast<Instruction>(Op));
1427 // If vector uses the builtin load, lower to a LoadInst
1428 Value *Arg;
1429 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1430 m_Value(Arg)))) {
1431 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1432 Op->replaceAllUsesWith(NewLoad);
1433 cast<Instruction>(Op)->eraseFromParent();
1434 return NewLoad;
1435 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1436 m_Value(Arg)))) {
1437 ToRemove.push_back(cast<Instruction>(Op));
1438 return Arg;
1439 }
1440
1441 return Op;
1442 };
1443 LHS = FlattenArg(LHS);
1444
1445 // Insert mul/fmul and llvm.vector.reduce.fadd
1446 Value *Mul =
1447 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1448
1449 Value *Result;
1450 if (IsIntVec)
1451 Result = Builder.CreateAddReduce(Mul);
1452 else {
1453 Result = Builder.CreateFAddReduce(
1454 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
1455 0.0),
1456 Mul);
1457 cast<Instruction>(Result)->setFastMathFlags(FMF);
1458 }
1459
1460 // pack scalar back into a matrix and then replace matmul inst
1461 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1462 Result, uint64_t(0));
1463 MatMul->replaceAllUsesWith(Result);
1464 FusedInsts.insert(MatMul);
1465 ToRemove.push_back(MatMul);
1466 }
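  // For illustration, a 1x3 * 3x1 double multiply whose cost model favors the
  // reduction is emitted roughly as (names illustrative)
  //   %m = fmul <3 x double> %lhs, %rhs
  //   %r = call reassoc double @llvm.vector.reduce.fadd.v3f64(double 0.0,
  //                                                           <3 x double> %m)
  // followed by an insertelement packing %r into the <1 x double> result that
  // replaces the matrix.multiply call.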
1467
1468 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1469 /// addition.
1470 ///
1471 /// We can fold a transpose into the operand that is used to extract scalars.
1472 /// This is the first operand with row-major layout and the second with
1473 /// column-major layout. If \p IsScalarMatrixTransposed we assume the
1474 /// appropriate operand is transposed.
1475 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1476 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1477 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1478 const unsigned VF = std::max<unsigned>(
1479 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1480 .getFixedValue() /
1481 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1482 1U);
1483 unsigned R = Result.getNumRows();
1484 unsigned C = Result.getNumColumns();
1485 unsigned M = A.getNumColumns();
1486
1487 bool IsFP = Result.getElementType()->isFloatingPointTy();
1488 assert(A.isColumnMajor() == B.isColumnMajor() &&
1489 Result.isColumnMajor() == A.isColumnMajor() &&
1490 "operands must agree on matrix layout");
1491 unsigned NumComputeOps = 0;
1492
1493 Builder.setFastMathFlags(FMF);
1494
1495 if (A.isColumnMajor()) {
1496 // Multiply columns from the first operand with scalars from the second
1497 // operand. Then move along the K axes and accumulate the columns. With
1498 // this the adds can be vectorized without reassociation.
1499 for (unsigned J = 0; J < C; ++J) {
1500 unsigned BlockSize = VF;
1501 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1502 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1503
1504 for (unsigned I = 0; I < R; I += BlockSize) {
1505 // Gradually lower the vectorization factor to cover the remainder.
1506 while (I + BlockSize > R)
1507 BlockSize /= 2;
1508
1509 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1510 : nullptr;
1511 for (unsigned K = 0; K < M; ++K) {
1512 Value *L = A.extractVector(I, K, BlockSize, Builder);
1513 Value *RH = Builder.CreateExtractElement(
1514 B.getColumn(IsScalarMatrixTransposed ? K : J),
1515 IsScalarMatrixTransposed ? J : K);
1516 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1517 Sum =
1518 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1519 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1520 }
1521 Result.setVector(J,
1522 insertVector(Result.getVector(J), I, Sum, Builder));
1523 }
1524 }
1525 } else {
1526 // Multiply rows from the second operand with scalars from the first
1527 // operand. Then move along the K axes and accumulate the rows. With this
1528 // the adds can be vectorized without reassociation.
1529 for (unsigned I = 0; I < R; ++I) {
1530 unsigned BlockSize = VF;
1531 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1532 for (unsigned J = 0; J < C; J += BlockSize) {
1533 // Gradually lower the vectorization factor to cover the remainder.
1534 while (J + BlockSize > C)
1535 BlockSize /= 2;
1536
1537 Value *Sum = nullptr;
1538 for (unsigned K = 0; K < M; ++K) {
1539 Value *R = B.extractVector(K, J, BlockSize, Builder);
1540 Value *LH = Builder.CreateExtractElement(
1541 A.getVector(IsScalarMatrixTransposed ? K : I),
1542 IsScalarMatrixTransposed ? I : K);
1543 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1544 Sum =
1545 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1546 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1547 }
1548 Result.setVector(I,
1549 insertVector(Result.getVector(I), J, Sum, Builder));
1550 }
1551 }
1552 }
1553 Result.addNumComputeOps(NumComputeOps);
1554 }
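  // Sketch of the column-major schedule for R = 4, C = 2, M = 3 with VF = 4:
  // for each result column J the full 4-wide block is accumulated as
  //   Result(:,J) = A(:,0) * splat(B[0][J]) + A(:,1) * splat(B[1][J])
  //               + A(:,2) * splat(B[2][J])
  // where each multiply-add becomes an fmuladd call when contraction is
  // allowed, and remainders are handled by halving the block size.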
1555
1556 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1557 /// copying it to a new location. The new location, or otherwise the original
1558 /// one, is returned.
1559 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1560 CallInst *MatMul) {
1561 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1562 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1563
1564 // If we can statically determine noalias we're good.
1565 if (AA->isNoAlias(LoadLoc, StoreLoc))
1566 return Load->getPointerOperand();
1567
1568 // Create code to check if the memory locations of the Load and Store
1569 // overlap and if they do, copy Load's operand to a new buffer.
1570
1571 // First, create new blocks for the two parts of the check and the copy.
1572 BasicBlock *Check0 = MatMul->getParent();
1573 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1574 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1575 // as we adjust Check0 and Check1's branches.
1576 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1577 for (BasicBlock *Succ : successors(Check0))
1578 DTUpdates.push_back({DT->Delete, Check0, Succ});
1579
1580 BasicBlock *Check1 =
1581 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1582 nullptr, "alias_cont");
1583 BasicBlock *Copy =
1584 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1585 nullptr, "copy");
1586 BasicBlock *Fusion =
1587 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1588 nullptr, "no_alias");
1589
1590 // Check if the loaded memory location begins before the end of the store
1591 // location. If the condition holds, they might overlap, otherwise they are
1592 // guaranteed to not overlap.
1593 IRBuilder<> Builder(MatMul);
1594 Check0->getTerminator()->eraseFromParent();
1595 Builder.SetInsertPoint(Check0);
1596 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1597 Value *StoreBegin = Builder.CreatePtrToInt(
1598 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1599 Value *StoreEnd = Builder.CreateAdd(
1600 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1601 "store.end", true, true);
1602 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1603 IntPtrTy, "load.begin");
1604 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1605 Fusion);
1606
1607 // Check if the store begins before the end of the load location. If the
1608 // condition holds, they alias, otherwise they are guaranteed to not
1609 // overlap.
1610 Check1->getTerminator()->eraseFromParent();
1611 Builder.SetInsertPoint(Check1, Check1->begin());
1612 Value *LoadEnd = Builder.CreateAdd(
1613 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1614 "load.end", true, true);
1615 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1616 Fusion);
1617
1618 // Copy load operand to new alloca.
1619 Builder.SetInsertPoint(Copy, Copy->begin());
1620 auto *VT = cast<FixedVectorType>(Load->getType());
1621 // Use an array type for the alloca, to avoid potentially huge alignment
1622 // requirements for large vector types.
1623 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1624 AllocaInst *Alloca =
1625 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1626
1627 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
1628 Load->getAlign(), LoadLoc.Size.getValue());
1629 Builder.SetInsertPoint(Fusion, Fusion->begin());
1630 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1631 PHI->addIncoming(Load->getPointerOperand(), Check0);
1632 PHI->addIncoming(Load->getPointerOperand(), Check1);
1633 PHI->addIncoming(Alloca, Copy);
1634
1635 // Adjust DT.
1636 DTUpdates.push_back({DT->Insert, Check0, Check1});
1637 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1638 DTUpdates.push_back({DT->Insert, Check1, Copy});
1639 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1640 DT->applyUpdates(DTUpdates);
1641 return PHI;
1642 }
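// Editorial sketch, not part of the pass: the two conditional branches
// emitted above are the standard half-open interval overlap test. Treating
// both accesses as byte ranges [Begin, End), they can only overlap if
// LoadBegin < StoreEnd and StoreBegin < LoadEnd; each check branches straight
// to "no_alias" (Fusion) as soon as one half of the conjunction fails, so
// only the Copy block pays for the alloca + memcpy. A standalone equivalent:
//
//   #include <cstdint>
//   bool mayOverlap(std::uintptr_t LoadBegin, std::uintptr_t LoadSize,
//                   std::uintptr_t StoreBegin, std::uintptr_t StoreSize) {
//     std::uintptr_t LoadEnd = LoadBegin + LoadSize;    // "load.end"
//     std::uintptr_t StoreEnd = StoreBegin + StoreSize; // "store.end"
//     return LoadBegin < StoreEnd && StoreBegin < LoadEnd;
//   }
//
// The PHI in "no_alias" then selects the original pointer when reached from
// Check0 or Check1 and the freshly copied alloca when reached from Copy.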
1643
1644 bool isFusionProfitable(CallInst *MatMul) {
1645 if (ForceFusion)
1646 return true;
1647
1648 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1649 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1650
1651 const unsigned R = LShape.NumRows;
1652 const unsigned C = RShape.NumColumns;
1653 const unsigned M = LShape.NumColumns;
1654 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1655
1656 const unsigned VF = std::max<unsigned>(
1657 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1658 .getFixedValue() /
1659 EltType->getPrimitiveSizeInBits().getFixedValue(),
1660 1U);
1661
1662 // Cost model for tiling
1663 //
1664 // For tiling to be beneficial, we need reuse either along the R or
1665 // the C axis. We vectorize along the R axis so that means at least
1666 // 3 elements.
1667 // TODO: Also consider cost of copying if operands alias.
1668 if (R <= VF && C == 1)
1669 return false;
1670 // Then we need enough elements to exceed the number of vector
1671 // registers we have. Note that this is an oversimplification since
1672 // fusing also takes some extra loads which may exceed the number of
1673 // reloads necessary.
1674 unsigned Op0Regs = (R + VF - 1) / VF * M;
1675 unsigned Op1Regs = (M + VF - 1) / VF * C;
1676 return Op0Regs + Op1Regs >
1677 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1678 }
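// Editorial worked example with hypothetical numbers, not from the source:
// for a double 8x8 * 8x8 multiply on a target with 128-bit vector registers,
// VF = 128 / 64 = 2, so Op0Regs = ceil(8 / 2) * 8 = 32 and Op1Regs = 32.
// Fusion is then reported as profitable whenever 32 + 32 = 64 exceeds the
// number of vector registers the target advertises (e.g. 32).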
1679
1680 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1681 MatrixTy Res;
1682 auto *ColumType = FixedVectorType::get(EltType, R);
1683 for (unsigned I = 0; I < C; ++I)
1684 Res.addVector(ConstantAggregateZero::get(ColumType));
1685 return Res;
1686 }
1687
1688 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1689 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1690 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1691
1692 // Create the main tiling loop nest.
1693 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1694 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1695 Instruction *InsertI = cast<Instruction>(MatMul);
1696 BasicBlock *Start = InsertI->getParent();
1697 BasicBlock *End =
1698 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1699 IRBuilder<> Builder(MatMul);
1700 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1701
1702 Type *TileVecTy =
1703 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1704 MatrixTy TileResult;
1705 // Insert in the inner loop header.
1706 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1707 // Create PHI nodes for the result columns to accumulate across iterations.
1708 SmallVector<PHINode *, 4> ColumnPhis;
1709 for (unsigned I = 0; I < TileSize; I++) {
1710 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1711 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1712 TI.RowLoop.Header->getSingleSuccessor());
1713 TileResult.addVector(Phi);
1714 ColumnPhis.push_back(Phi);
1715 }
1716
1717 // Insert in the inner loop body, which computes
1718 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1719 Builder.SetInsertPoint(InnerBody->getTerminator());
1720 // Load tiles of the operands.
1721 MatrixTy A =
1722 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1723 {TileSize, TileSize}, EltType, Builder);
1724 MatrixTy B =
1725 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1726 {TileSize, TileSize}, EltType, Builder);
1727 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1728 getFastMathFlags(MatMul));
1729 // Store result after the inner loop is done.
1730 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1731 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1732 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1733 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1734
1735 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1736 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1737
1738 // Force unrolling of a few iterations of the inner loop, to make sure there
1739 // is enough work per iteration.
1740 // FIXME: The unroller should make this decision directly instead, but
1741 // currently the cost-model is not up to the task.
1742 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1743 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1744 "llvm.loop.unroll.count", InnerLoopUnrollCount);
1745 }
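// Editorial sketch, not part of the pass: the nest built by
// TileInfo::CreateTiledLoops above corresponds roughly to the following
// C-level pseudocode (square TileSize tiles, column loop outermost, K loop
// innermost and partially unrolled via llvm.loop.unroll.count):
//
//   for (Col = 0; Col < C; Col += TileSize)            // TI.ColumnLoop
//     for (Row = 0; Row < R; Row += TileSize) {        // TI.RowLoop
//       Acc = zero TileSize x TileSize tile;           // result.vec.* PHIs
//       for (K = 0; K < M; K += TileSize)              // TI.KLoop
//         Acc += tileOf(A, Row, K) * tileOf(B, K, Col); // emitMatrixMultiply
//       storeTile(Acc, Store, Row, Col);               // at TI.RowLoop.Latch
//     }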
1746
1747 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1748 StoreInst *Store,
1749 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1750 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1751 "Tiling only supported for column-major matrixes at the moment!");
1752 if (!isFusionProfitable(MatMul))
1753 return;
1754
1755 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1756 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1757
1758 const unsigned R = LShape.NumRows;
1759 const unsigned C = RShape.NumColumns;
1760 const unsigned M = LShape.NumColumns;
1761 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1762
1763 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1764 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1765 Value *CPtr = Store->getPointerOperand();
1766
1767 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1768 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1769 else {
1770 IRBuilder<> Builder(Store);
1771 for (unsigned J = 0; J < C; J += TileSize)
1772 for (unsigned I = 0; I < R; I += TileSize) {
1773 const unsigned TileR = std::min(R - I, unsigned(TileSize));
1774 const unsigned TileC = std::min(C - J, unsigned(TileSize));
1775 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1776
1777 for (unsigned K = 0; K < M; K += TileSize) {
1778 const unsigned TileM = std::min(M - K, unsigned(TileSize));
1779 MatrixTy A =
1780 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1781 LShape, Builder.getInt64(I), Builder.getInt64(K),
1782 {TileR, TileM}, EltType, Builder);
1783 MatrixTy B =
1784 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1785 RShape, Builder.getInt64(K), Builder.getInt64(J),
1786 {TileM, TileC}, EltType, Builder);
1787 emitMatrixMultiply(Res, A, B, Builder, true, false,
1788 getFastMathFlags(MatMul));
1789 }
1790 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1791 Builder.getInt64(I), Builder.getInt64(J), EltType,
1792 Builder);
1793 }
1794 }
1795
1796 // Mark eliminated instructions as fused and remove them.
1797 FusedInsts.insert(Store);
1798 FusedInsts.insert(MatMul);
1799 Store->eraseFromParent();
1800 MatMul->eraseFromParent();
1801 if (LoadOp0->hasNUses(0)) {
1802 FusedInsts.insert(LoadOp0);
1803 LoadOp0->eraseFromParent();
1804 }
1805 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1806 FusedInsts.insert(LoadOp1);
1807 LoadOp1->eraseFromParent();
1808 }
1809 }
1810
1811 /// Try to lower matrix multiply chains by fusing operations.
1812 ///
1813 /// Call finalizeLowering on lowered instructions. Instructions that are
1814 /// completely eliminated by fusion are added to \p FusedInsts.
1815 void LowerMatrixMultiplyFused(CallInst *MatMul,
1816 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1817 if (!FuseMatrix || !DT)
1818 return;
1819
1820 assert(AA && LI && "Analyses should be available");
1821
1822 Value *A = MatMul->getArgOperand(0);
1823 Value *B = MatMul->getArgOperand(1);
1824
1825 // We can fold the transpose into the operand that is used to fetch scalars.
1826 Value *T;
1827 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1828 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1829 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1830 IRBuilder<> Builder(MatMul);
1831 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1832 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1833 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1834 const unsigned R = LShape.NumRows;
1835 const unsigned M = LShape.NumColumns;
1836 const unsigned C = RShape.NumColumns;
1837
1838 MatrixTy MA;
1839 MatrixTy MB;
1840
1841 Value *Transpose;
1842 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1843 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1844 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1845 Transpose = B;
1846 } else {
1847 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1848 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1849 Transpose = A;
1850 }
1851
1852 // Initialize the output
1853 MatrixTy Result(R, C, EltType);
1854
1855 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1856 getFastMathFlags(MatMul));
1857
1858 FusedInsts.insert(MatMul);
1859 if (Transpose->hasOneUse()) {
1860 FusedInsts.insert(cast<Instruction>(Transpose));
1861 ToRemove.push_back(cast<Instruction>(Transpose));
1862 // TODO: add a fake entry for the folded instruction so that this is
1863 // included in the expression in the remark.
1864 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1865 }
1866 finalizeLowering(MatMul, Result, Builder);
1867 return;
1868 }
1869
1870 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1871 return;
1872
1873 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1874 // since the single store user will be lowered as part of this.
1875 auto *LoadOp0 = dyn_cast<LoadInst>(A);
1876 auto *LoadOp1 = dyn_cast<LoadInst>(B);
1877 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1878 if (LoadOp0 && LoadOp1 && Store) {
1879 // The store address must dominate the MatMul instruction, otherwise
1880 // we create invalid IR.
1881 SetVector<Value *> WorkList;
1882 WorkList.insert(Store->getOperand(1));
1883 SmallVector<Instruction *> ToHoist;
1884 for (unsigned I = 0; I != WorkList.size(); ++I) {
1885 Value *Current = WorkList[I];
1886 auto *CurrI = dyn_cast<Instruction>(Current);
1887 if (!CurrI)
1888 continue;
1889 if (isa<PHINode>(CurrI))
1890 return;
1891 if (DT->dominates(CurrI, MatMul))
1892 continue;
1893 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1894 return;
1895 ToHoist.push_back(CurrI);
1896 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1897 }
1898
1899 sort(ToHoist, [this](Instruction *A, Instruction *B) {
1900 return DT->dominates(A, B);
1901 });
1902 for (Instruction *I : ToHoist)
1903 I->moveBefore(MatMul);
1904
1905 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1906 return;
1907 }
1908 }
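// Editorial note, not part of the pass: folding the transpose is legal
// because the operand it is folded into is only ever read element-wise by
// emitMatrixMultiply. For column-major layout the identity used is
//
//   (A * T^t)[i][j] = sum over k of A[i][k] * T[j][k]
//
// so the elements of T can be fetched directly (IsScalarMatrixTransposed set)
// without materializing the transposed matrix; the row-major case is the
// mirror image, with the transpose folded into A instead.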
1909
1910 /// Lowers llvm.matrix.multiply.
1911 void LowerMultiply(CallInst *MatMul) {
1912 IRBuilder<> Builder(MatMul);
1913 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1914 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1915 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1916
1917 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1918 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1919 assert(Lhs.getElementType() == Rhs.getElementType() &&
1920 "Matrix multiply argument element types do not match.");
1921
1922 const unsigned R = LShape.NumRows;
1923 const unsigned C = RShape.NumColumns;
1924 assert(LShape.NumColumns == RShape.NumRows);
1925
1926 // Initialize the output
1927 MatrixTy Result(R, C, EltType);
1928 assert(Lhs.getElementType() == Result.getElementType() &&
1929 "Matrix multiply result element type does not match arguments.");
1930
1931 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1932 getFastMathFlags(MatMul));
1933 finalizeLowering(MatMul, Result, Builder);
1934 }
1935
1936 /// Lowers llvm.matrix.transpose.
1937 void LowerTranspose(CallInst *Inst) {
1938 MatrixTy Result;
1939 IRBuilder<> Builder(Inst);
1940 Value *InputVal = Inst->getArgOperand(0);
1941 VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1942 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1943 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1944
1945 const unsigned NewNumVecs =
1946 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1947 const unsigned NewNumElts =
1948 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1949
1950 for (unsigned I = 0; I < NewNumVecs; ++I) {
1951 // Build a single result vector. First initialize it.
1952 Value *ResultVector = PoisonValue::get(
1953 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
1954 // Go through the old elements and insert them into the resulting vector.
1955 for (auto J : enumerate(InputMatrix.vectors())) {
1956 Value *Elt = Builder.CreateExtractElement(J.value(), I);
1957 // Row and column indices are transposed.
1958 ResultVector =
1959 Builder.CreateInsertElement(ResultVector, Elt, J.index());
1960 }
1961 Result.addVector(ResultVector);
1962 }
1963
1964 // TODO: Improve estimate of operations needed for transposes. Currently we
1965 // just count the insertelement/extractelement instructions, but do not
1966 // account for later simplifications/combines.
1967 finalizeLowering(
1968 Inst,
1969 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
1970 .addNumExposedTransposes(1),
1971 Builder);
1972 }
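// Editorial example, not part of the pass: for a column-major 2x3 input held
// as three 2-element column vectors c0, c1, c2, the loop above emits
// NewNumVecs = 2 result vectors of NewNumElts = 3 elements each,
// r0 = <c0[0], c1[0], c2[0]> and r1 = <c0[1], c1[1], c2[1]>, using one
// extractelement/insertelement pair per element (2 * 2 * 3 = 12 ops charged
// via addNumComputeOps).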
1973
1974 /// Lower load instructions, if shape information is available.
1975 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
1976 auto I = ShapeMap.find(Inst);
1977 if (I == ShapeMap.end())
1978 return false;
1979
1980 LowerLoad(Inst, Ptr, Inst->getAlign(),
1981 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1982 I->second);
1983 return true;
1984 }
1985
1986 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
1987 IRBuilder<> &Builder) {
1988 auto I = ShapeMap.find(StoredVal);
1989 if (I == ShapeMap.end())
1990 return false;
1991
1992 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
1993 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1994 I->second);
1995 return true;
1996 }
1997
1998 /// Lower binary operators, if shape information is available.
1999 bool VisitBinaryOperator(BinaryOperator *Inst) {
2000 auto I = ShapeMap.find(Inst);
2001 if (I == ShapeMap.end())
2002 return false;
2003
2004 Value *Lhs = Inst->getOperand(0);
2005 Value *Rhs = Inst->getOperand(1);
2006
2007 IRBuilder<> Builder(Inst);
2008 ShapeInfo &Shape = I->second;
2009
2010 MatrixTy Result;
2011 MatrixTy A = getMatrix(Lhs, Shape, Builder);
2012 MatrixTy B = getMatrix(Rhs, Shape, Builder);
2013 assert(A.isColumnMajor() == B.isColumnMajor() &&
2014 Result.isColumnMajor() == A.isColumnMajor() &&
2015 "operands must agree on matrix layout");
2016
2017 Builder.setFastMathFlags(getFastMathFlags(Inst));
2018
2019 // Helper to perform binary op on vectors.
2020 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2021 switch (Inst->getOpcode()) {
2022 case Instruction::Add:
2023 return Builder.CreateAdd(LHS, RHS);
2024 case Instruction::Mul:
2025 return Builder.CreateMul(LHS, RHS);
2026 case Instruction::Sub:
2027 return Builder.CreateSub(LHS, RHS);
2028 case Instruction::FAdd:
2029 return Builder.CreateFAdd(LHS, RHS);
2030 case Instruction::FMul:
2031 return Builder.CreateFMul(LHS, RHS);
2032 case Instruction::FSub:
2033 return Builder.CreateFSub(LHS, RHS);
2034 default:
2035 llvm_unreachable("Unsupported binary operator for matrix");
2036 }
2037 };
2038
2039 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2040 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
2041
2042 finalizeLowering(Inst,
2043 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2044 Result.getNumVectors()),
2045 Builder);
2046 return true;
2047 }
2048
2049 /// Lower unary operators, if shape information is available.
2050 bool VisitUnaryOperator(UnaryOperator *Inst) {
2051 auto I = ShapeMap.find(Inst);
2052 if (I == ShapeMap.end())
2053 return false;
2054
2055 Value *Op = Inst->getOperand(0);
2056
2057 IRBuilder<> Builder(Inst);
2058 ShapeInfo &Shape = I->second;
2059
2060 MatrixTy Result;
2061 MatrixTy M = getMatrix(Op, Shape, Builder);
2062
2063 Builder.setFastMathFlags(getFastMathFlags(Inst));
2064
2065 // Helper to perform unary op on vectors.
2066 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2067 switch (Inst->getOpcode()) {
2068 case Instruction::FNeg:
2069 return Builder.CreateFNeg(Op);
2070 default:
2071 llvm_unreachable("Unsupported unary operator for matrix");
2072 }
2073 };
2074
2075 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2076 Result.addVector(BuildVectorOp(M.getVector(I)));
2077
2078 finalizeLowering(Inst,
2079 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2080 Result.getNumVectors()),
2081 Builder);
2082 return true;
2083 }
2084
2085 /// Helper to linearize a matrix expression tree into a string. Currently
2086 /// matrix expressions are linearized by starting at an expression leaf and
2087 /// linearizing bottom up.
2088 struct ExprLinearizer {
2089 unsigned LengthToBreak = 100;
2090 std::string Str;
2091 raw_string_ostream Stream;
2092 unsigned LineLength = 0;
2093 const DataLayout &DL;
2094
2095 /// Mapping from instructions to matrixes. It is used to identify
2096 /// matrix instructions.
2097 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2098
2099 /// Mapping from values to the leaves of all expressions that the value is
2100 /// part of.
2101 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2102
2103 /// Set of matrix expressions in the scope of a given DISubprogram.
2104 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2105
2106 /// Leaf node of the expression to linearize.
2107 Value *Leaf;
2108
2109 /// Used to keep track of sub-expressions that get reused while linearizing
2110 /// the expression. Re-used sub-expressions are marked as (reused).
2111 SmallPtrSet<Value *, 8> ReusedExprs;
2112
2113 ExprLinearizer(const DataLayout &DL,
2114 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2115 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2116 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2117 Value *Leaf)
2118 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2119 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2120
2121 void indent(unsigned N) {
2122 LineLength += N;
2123 for (unsigned i = 0; i < N; i++)
2124 Stream << " ";
2125 }
2126
2127 void lineBreak() {
2128 Stream << "\n";
2129 LineLength = 0;
2130 }
2131
2132 void maybeIndent(unsigned Indent) {
2133 if (LineLength >= LengthToBreak)
2134 lineBreak();
2135
2136 if (LineLength == 0)
2137 indent(Indent);
2138 }
2139
2140 void write(StringRef S) {
2141 LineLength += S.size();
2142 Stream << S;
2143 }
2144
2145 Value *getUnderlyingObjectThroughLoads(Value *V) {
2146 if (Value *Ptr = getPointerOperand(V))
2147 return getUnderlyingObjectThroughLoads(Ptr);
2148 else if (V->getType()->isPointerTy())
2149 return getUnderlyingObject(V);
2150 return V;
2151 }
2152
2153 /// Returns true if \p V is a matrix value in the given subprogram.
2154 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2155
2156 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2157 /// \p SS.
2158 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2159 auto M = Inst2Matrix.find(V);
2160 if (M == Inst2Matrix.end())
2161 SS << "unknown";
2162 else {
2163 SS << M->second.getNumRows();
2164 SS << "x";
2165 SS << M->second.getNumColumns();
2166 }
2167 }
2168
2169 /// Write the called function name. Handles calls to llvm.matrix.*
2170 /// specially: we write the name, followed by the dimensions of the input
2171 /// matrixes, followed by the scalar type name.
2172 void writeFnName(CallInst *CI) {
2173 if (!CI->getCalledFunction())
2174 write("<no called fn>");
2175 else {
2176 StringRef Name = CI->getCalledFunction()->getName();
2177 if (!Name.startswith("llvm.matrix")) {
2178 write(Name);
2179 return;
2180 }
2181 auto *II = cast<IntrinsicInst>(CI);
2182 write(Intrinsic::getBaseName(II->getIntrinsicID())
2183 .drop_front(StringRef("llvm.matrix.").size()));
2184 write(".");
2185 std::string Tmp;
2186 raw_string_ostream SS(Tmp);
2187
2188 switch (II->getIntrinsicID()) {
2189 case Intrinsic::matrix_multiply:
2190 prettyPrintMatrixType(II->getOperand(0), SS);
2191 SS << ".";
2192 prettyPrintMatrixType(II->getOperand(1), SS);
2193 SS << "." << *II->getType()->getScalarType();
2194 break;
2195 case Intrinsic::matrix_transpose:
2196 prettyPrintMatrixType(II->getOperand(0), SS);
2197 SS << "." << *II->getType()->getScalarType();
2198 break;
2199 case Intrinsic::matrix_column_major_load:
2200 prettyPrintMatrixType(II, SS);
2201 SS << "." << *II->getType()->getScalarType();
2202 break;
2203 case Intrinsic::matrix_column_major_store:
2204 prettyPrintMatrixType(II->getOperand(0), SS);
2205 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2206 break;
2207 default:
2208 llvm_unreachable("Unhandled case");
2209 }
2210 SS.flush();
2211 write(Tmp);
2212 }
2213 }
2214
2215 unsigned getNumShapeArgs(CallInst *CI) const {
2216 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2217 switch (II->getIntrinsicID()) {
2218 case Intrinsic::matrix_multiply:
2219 return 3;
2220 case Intrinsic::matrix_transpose:
2221 return 2;
2222 case Intrinsic::matrix_column_major_load:
2223 case Intrinsic::matrix_column_major_store:
2224 return 3;
2225 default:
2226 return 0;
2227 }
2228 }
2229 return 0;
2230 }
2231
2232 /// Special printing for values: for pointers, we print whether they refer
2233 /// to an external (function) address or a stack address; for other values,
2234 /// we print either the constant or "scalar"/"matrix".
2235 void write(Value *V) {
2236 V = getUnderlyingObjectThroughLoads(V);
2237 if (V->getType()->isPointerTy()) {
2238 if (isa<AllocaInst>(V)) {
2239 Stream << "stack addr";
2240 LineLength += StringRef("stack addr").size();
2241 } else {
2242 Stream << "addr";
2243 LineLength += StringRef("addr").size();
2244 }
2245 if (!V->getName().empty()) {
2246 Stream << " %" << V->getName() << "";
2247 LineLength += V->getName().size() + 2;
2248 }
2249 return;
2250 }
2251
2252 std::string Tmp;
2253 raw_string_ostream TmpStream(Tmp);
2254
2255 if (auto *CI = dyn_cast<ConstantInt>(V))
2256 TmpStream << CI->getValue();
2257 else if (isa<Constant>(V))
2258 TmpStream << "constant";
2259 else {
2260 if (isMatrix(V))
2261 TmpStream << "matrix";
2262 else
2263 TmpStream << "scalar";
2264 }
2265 TmpStream.flush();
2266 Tmp = std::string(StringRef(Tmp).trim());
2267 LineLength += Tmp.size();
2268 Stream << Tmp;
2269 }
2270
2271 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2272 /// Expressions that are re-used multiple times are prefixed with (reused)
2273 /// at the re-used root instruction.
2274 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2275 bool ParentShared) {
2276 auto *I = cast<Instruction>(Expr);
2277 maybeIndent(Indent);
2278 SmallVector<Value *, 8> Ops;
2279
2280 // Is Expr shared with other expression leaves?
2281 bool ExprShared = false;
2282
2283 // Deal with shared subtrees. Mark them as shared, if required.
2284 if (!ParentShared) {
2285 auto SI = Shared.find(Expr);
2286 assert(SI != Shared.end() && SI->second.count(Leaf));
2287
2288 for (Value *S : SI->second) {
2289 if (S == Leaf)
2290 continue;
2291 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2292 write("shared with remark at line " + std::to_string(DL.getLine()) +
2293 " column " + std::to_string(DL.getCol()) + " (");
2294 }
2295 ExprShared = SI->second.size() > 1;
2296 }
2297
2298 bool Reused = !ReusedExprs.insert(Expr).second;
2299 if (Reused && !ParentReused)
2300 write("(reused) ");
2301
2302 if (auto *CI = dyn_cast<CallInst>(I)) {
2303 writeFnName(CI);
2304
2305 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2306 } else if (isa<BitCastInst>(Expr)) {
2307 // Special case bitcasts, which are used to materialize matrixes from
2308 // non-matrix ops.
2309 write("matrix");
2310 return;
2311 } else {
2312 Ops.append(I->value_op_begin(), I->value_op_end());
2313 write(std::string(I->getOpcodeName()));
2314 }
2315
2316 write(std::string("("));
2317
2318 unsigned NumOpsToBreak = 1;
2319 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2320 NumOpsToBreak = 2;
2321
2322 for (Value *Op : Ops) {
2323 if (Ops.size() > NumOpsToBreak)
2324 lineBreak();
2325
2326 maybeIndent(Indent + 1);
2327 if (isMatrix(Op))
2328 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2329 else
2330 write(Op);
2331 if (Op != Ops.back())
2332 write(", ");
2333 }
2334
2335 write(")");
2336 }
2337
2338 const std::string &getResult() {
2339 Stream.flush();
2340 return Str;
2341 }
2342 };
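// Editorial illustration, not part of the pass (output format approximate):
// for a {load, load} -> multiply -> store chain the linearizer produces text
// along the lines of
//
//   store(
//    multiply.2x6.6x2.double(
//     load(addr %A),
//     load(addr %B)),
//    addr %C)
//
// with "(reused)" and "shared with remark at line ..." annotations added when
// a sub-expression feeds more than one leaf.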
2343
2344 /// Generate remarks for matrix operations in a function. To generate remarks
2345 /// for matrix expressions, the following approach is used:
2346 /// 1. Use the inlined-at debug information to group matrix operations to the
2347 /// DISubprograms they are contained in.
2348 /// 2. Collect leaves of matrix expressions (done in
2349 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2350 /// mapping. Leaves are lowered matrix instructions without other matrix
2351 /// users (like stores) in the current subprogram.
2352 /// 3. For each leaf, create a remark containing a linearized version of the
2353 /// matrix expression. The expression is linearized by a recursive
2354 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2355 /// that multiple leaves can share sub-expressions. Shared subexpressions
2356 /// are explicitly marked as shared().
2357 struct RemarkGenerator {
2358 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2359 OptimizationRemarkEmitter &ORE;
2360 Function &Func;
2361 const DataLayout &DL;
2362
2363 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2364 OptimizationRemarkEmitter &ORE, Function &Func)
2365 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2366 DL(Func.getParent()->getDataLayout()) {}
2367
2368 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2369 /// instructions in Inst2Matrix returning void or without any users in
2370 /// \p ExprsInSubprogram. Currently that should only include stores.
2371 SmallVector<Value *, 4>
2372 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2373 SmallVector<Value *, 4> Leaves;
2374 for (auto *Expr : ExprsInSubprogram)
2375 if (Expr->getType()->isVoidTy() ||
2376 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2377 return ExprsInSubprogram.count(U);
2378 }))
2379 Leaves.push_back(Expr);
2380 return Leaves;
2381 }
2382
2383 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2384 /// to all visited expressions in \p Shared. Limit the matrix operations to
2385 /// the ones in \p ExprsInSubprogram.
2386 void collectSharedInfo(Value *Leaf, Value *V,
2387 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2388 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2389
2390 if (!ExprsInSubprogram.count(V))
2391 return;
2392
2393 auto I = Shared.insert({V, {}});
2394 I.first->second.insert(Leaf);
2395
2396 for (Value *Op : cast<Instruction>(V)->operand_values())
2397 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2398 }
2399
2400 /// Calculate the number of exclusive and shared op counts for the expression
2401 /// starting at \p Root. Expressions used multiple times are counted once.
2402 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2403 std::pair<OpInfoTy, OpInfoTy>
2404 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2405 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2406 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2407 if (!ExprsInSubprogram.count(Root))
2408 return {};
2409
2410 // Already counted this expression. Stop.
2411 if (!ReusedExprs.insert(Root).second)
2412 return {};
2413
2414 OpInfoTy SharedCount;
2415 OpInfoTy Count;
2416
2417 auto I = Shared.find(Root);
2418 auto CM = Inst2Matrix.find(Root);
2419 if (I->second.size() == 1)
2420 Count = CM->second.getOpInfo();
2421 else
2422 SharedCount = CM->second.getOpInfo();
2423
2424 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2425 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2426 Count += C.first;
2427 SharedCount += C.second;
2428 }
2429 return {Count, SharedCount};
2430 }
2431
2432 void emitRemarks() {
2433 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2434 return;
2435
2436 // Map matrix operations to their containing subprograms, by traversing
2437 // the inlinedAt chain. If the function does not have a DISubprogram, we
2438 // only map them to the containing function.
2439 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2440 for (const auto &KV : Inst2Matrix) {
2441 if (Func.getSubprogram()) {
2442 auto *I = cast<Instruction>(KV.first);
2443 DILocation *Context = I->getDebugLoc();
2444 while (Context) {
2445 auto I =
2446 Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2447 I.first->second.push_back(KV.first);
2448 Context = DebugLoc(Context).getInlinedAt();
2449 }
2450 } else {
2451 auto I = Subprog2Exprs.insert({nullptr, {}});
2452 I.first->second.push_back(KV.first);
2453 }
2454 }
2455 for (auto &KV : Subprog2Exprs) {
2456 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2457 KV.second.end());
2458 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2459
2460 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2461 for (Value *Leaf : Leaves)
2462 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2463
2464 // Generate remarks for each leaf.
2465 for (auto *L : Leaves) {
2466
2467 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2468 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2469 while (Context) {
2470 if (getSubprogram(Context->getScope()) == KV.first) {
2471 Loc = Context;
2472 break;
2473 }
2474 Context = DebugLoc(Context).getInlinedAt();
2475 }
2476
2477 SmallPtrSet<Value *, 8> ReusedExprs;
2478 OpInfoTy Counts, SharedCounts;
2479 std::tie(Counts, SharedCounts) =
2480 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2481
2482 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2483 cast<Instruction>(L)->getParent());
2484
2485 Rem << "Lowered with ";
2486 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2487 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2488 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2489 << " compute ops, "
2490 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2491 << " exposed transposes";
2492
2493 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2494 SharedCounts.NumComputeOps > 0) {
2495 Rem << ",\nadditionally "
2496 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2497 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2498 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2499 << " compute ops"
2500 << " are shared with other expressions";
2501 }
2502
2503 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2504 ORE.emit(Rem);
2505 }
2506 }
2507 }
2508
2509 std::string
2510 linearize(Value *L,
2511 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2512 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2513 const DataLayout &DL) {
2514 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2515 Lin.linearizeExpr(L, 0, false, false);
2516 return Lin.getResult();
2517 }
2518 };
2519};
2520} // namespace
2521
2522PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2523 FunctionAnalysisManager &AM) {
2524 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2525 OptimizationRemarkEmitter *ORE = nullptr;
2526 AAResults *AA = nullptr;
2527 DominatorTree *DT = nullptr;
2528 LoopInfo *LI = nullptr;
2529
2530 if (!Minimal) {
2531 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
2532 AA = &AM.getResult<AAManager>(F);
2533 DT = &AM.getResult<DominatorTreeAnalysis>(F);
2534 LI = &AM.getResult<LoopAnalysis>(F);
2535 }
2536
2537 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2538 if (LMT.Visit()) {
2539 PreservedAnalyses PA;
2540 if (!Minimal) {
2541 PA.preserve<LoopAnalysis>();
2542 PA.preserve<DominatorTreeAnalysis>();
2543 }
2544 return PA;
2545 }
2546 return PreservedAnalyses::all();
2547}
2548
2549void LowerMatrixIntrinsicsPass::printPipeline(
2550 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2551 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2552 OS, MapClassName2PassName);
2553 OS << '<';
2554 if (Minimal)
2555 OS << "minimal";
2556 OS << '>';
2557}
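// Editorial note, not part of the file: under the new pass manager this pass
// is exposed as "lower-matrix-intrinsics" (printPipeline above appends
// "<minimal>" when the analysis-free variant is selected), so a textual
// pipeline such as -passes='lower-matrix-intrinsics' should round-trip
// through it. Treat the exact parameter spelling as an assumption; it is
// defined in the pass registry, not in this file.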