1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 // * Improve fusion:
13 // * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14 // transposed.
15 // * Improve cost-modeling, e.g. choose different number of rows/columns
16 // columns for tiles, consider cost of copies on alias.
17 //
18 //===----------------------------------------------------------------------===//
19 
22 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Analysis/LoopInfo.h"
30 #include "llvm/IR/CFG.h"
31 #include "llvm/IR/DataLayout.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/Instructions.h"
36 #include "llvm/IR/IntrinsicInst.h"
37 #include "llvm/IR/MatrixBuilder.h"
38 #include "llvm/IR/PatternMatch.h"
39 #include "llvm/InitializePasses.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Alignment.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Transforms/Scalar.h"
48 
49 #include <cmath>
50 
51 using namespace llvm;
52 using namespace PatternMatch;
53 
54 #define DEBUG_TYPE "lower-matrix-intrinsics"
55 
56 static cl::opt<bool>
57  FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
58  cl::desc("Enable/disable fusing matrix instructions."));
59 // TODO: Allow and use non-square tiles.
60 static cl::opt<unsigned> TileSize(
61  "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
62  cl::desc(
63  "Tile size for matrix instruction fusion using square-shaped tiles."));
64 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
65  cl::Hidden,
66  cl::desc("Generate loop nest for tiling."));
67 static cl::opt<bool> ForceFusion(
68  "force-fuse-matrix", cl::init(false), cl::Hidden,
69  cl::desc("Force matrix instruction fusion even if not profitable."));
71  "matrix-allow-contract", cl::init(false), cl::Hidden,
72  cl::desc("Allow the use of FMAs if available and profitable. This may "
73  "result in different results, due to less rounding error."));
74 
75 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
76 
77 static cl::opt<MatrixLayoutTy> MatrixLayout(
78  "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
79  cl::desc("Sets the default matrix layout"),
80  cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
81  "Use column-major layout"),
82  clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
83  "Use row-major layout")));
84 
85 static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
86  cl::init(false));
87 
88 /// Helper function to either return Scope, if it is a subprogram, or the
89 /// attached subprogram for a local scope.
91  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
92  return Subprogram;
93  return cast<DILocalScope>(Scope)->getSubprogram();
94 }
95 
96 /// Return true if V is a splat of a value (which is used when multiplying a
97 /// matrix with a scalar).
98 static bool isSplat(Value *V) {
99  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
100  return SV->isZeroEltSplat();
101  return false;
102 }
103 
104 /// Match any mul operation (fp or integer).
105 template <typename LTy, typename RTy>
106 auto m_AnyMul(const LTy &L, const RTy &R) {
107  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
108 }
109 
110 namespace {
111 
112 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
113 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
114 // assuming \p Stride elements between the starts of two consecutive vectors.
115 // \p Stride must be >= \p NumElements.
116 // For column-major matrixes, the function computes the address of a column
117 // vector and \p NumElements must be set to the number of elements in a column
118 // (= number of rows of the matrix). For row-major matrixes, the function
119 // computes the address of a row vector and \p NumElements must be set to the
120 // number of elements in a row (= number of columns of the matrix).
121 //
122 // Consider a 4x4 matrix in column-major layout like below
123 //
124 // 0 1 2 3
125 // 0 v_0_0 v_0_1 v_0_2 v_0_3
126 // 1 v_1_0 v_1_1 v_1_2 v_1_3
127 // 2 v_2_0 v_2_1 v_2_2 v_2_3
128 // 3 v_3_0 v_3_1 v_3_2 v_3_3
129 
130 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
131 // we need a pointer to the first element of the submatrix as base pointer.
132 // Then we can use computeVectorAddr to compute the addresses for the columns
133 // of the sub-matrix.
134 //
135 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
136 // -> just returns Base
137 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
138 // -> returns Base + (1 * 4)
139 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
140 // -> returns Base + (2 * 4)
141 //
142 // The graphic below illustrates the number of elements in a column (marked
143 // with |) and the number of skipped elements (marked with }).
144 //
145 // v_0_0 v_0_1 {v_0_2 {v_0_3
146 // Base Col 1 Col 2
147 // | | |
148 // v_1_0 |v_1_1 |v_1_2 |v_1_3
149 // v_2_0 |v_2_1 |v_2_2 |v_2_3
150 // v_3_0 {v_3_1 {v_3_2 v_3_3
151 //
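// As a rough sketch (illustrative register names, assuming typed pointers; not
// verbatim output of this pass), a call such as
//   computeVectorAddr(%base, 1, %stride, 4, double, Builder)
// emits IR along the lines of:
//
//   %vec.start = mul i64 1, %stride
//   %vec.gep   = getelementptr double, double* %base, i64 %vec.start
//   %vec.cast  = bitcast double* %vec.gep to <4 x double>*
//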
152 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
153  unsigned NumElements, Type *EltType,
154  IRBuilder<> &Builder) {
155 
156  assert((!isa<ConstantInt>(Stride) ||
157  cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
158  "Stride must be >= the number of elements in the result vector.");
159  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
160 
161  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
162  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
163 
164  // Get pointer to the start of the selected vector. Skip GEP creation,
165  // if we select vector 0.
166  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
167  VecStart = BasePtr;
168  else
169  VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
170 
171  // Cast elementwise vector start pointer to a pointer to a vector
172  // (EltType x NumElements)*.
173  auto *VecType = FixedVectorType::get(EltType, NumElements);
174  Type *VecPtrType = PointerType::get(VecType, AS);
175  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
176 }
177 
178 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
179 ///
180 /// Currently, the lowering for each matrix intrinsic is done as follows:
181 /// 1. Propagate the shape information from intrinsics to connected
182 /// instructions.
183 /// 2. Lower instructions with shape information (assuming column-major layout).
184 /// The lowering works similarly using row-major layout.
185 /// 2.1. Get column vectors for each argument. If we already lowered the
186 /// definition of an argument, use the produced column vectors directly.
187 /// If not, split the operand vector containing an embedded matrix into
188 /// a set of column vectors.
189 /// 2.2. Lower the instruction in terms of column major operations, which
190 /// yields a set of column vectors containing the result matrix. Note that we
191 /// lower all instructions that have shape information. Besides the
192 /// intrinsics, this includes stores for example.
193 /// 2.3. Update uses of the lowered instruction. If we have shape information
194 /// for a user, there is nothing to do, as we will look up the result
195 /// column matrix when lowering the user. For other uses, we embed the
196 /// result matrix in a flat vector and update the use.
197 /// 2.4. Cache the result column matrix for the instruction we lowered.
198 /// 3. After we lowered all instructions in a function, remove the now
199 /// obsolete instructions.
200 ///
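/// As a small illustrative example (not verbatim pass output): for
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
/// both operands and the result get a 2x2 shape. Each operand is split into two
/// 2-element column vectors, the multiply is emitted as shuffles plus fmul/fadd
/// (or fmuladd), and any user without shape information receives the result
/// embedded back into a flat <4 x double>.
///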
201 class LowerMatrixIntrinsics {
202  Function &Func;
203  const DataLayout &DL;
204  const TargetTransformInfo &TTI;
205  AliasAnalysis *AA;
206  DominatorTree *DT;
207  LoopInfo *LI;
208  OptimizationRemarkEmitter *ORE;
209 
210  /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
211  struct OpInfoTy {
212  /// Number of stores emitted to generate this matrix.
213  unsigned NumStores = 0;
214  /// Number of loads emitted to generate this matrix.
215  unsigned NumLoads = 0;
216  /// Number of compute operations emitted to generate this matrix.
217  unsigned NumComputeOps = 0;
218  /// Most of the time transposes can be fused with matrix multiplies or can
219  /// be folded away via algebraic simplifications. This is the number of
220  /// transposes that we failed to make "free" via such optimizations.
221  unsigned NumExposedTransposes = 0;
222 
223  OpInfoTy &operator+=(const OpInfoTy &RHS) {
224  NumStores += RHS.NumStores;
225  NumLoads += RHS.NumLoads;
226  NumComputeOps += RHS.NumComputeOps;
227  NumExposedTransposes += RHS.NumExposedTransposes;
228  return *this;
229  }
230  };
231 
232  /// Wrapper class representing a matrix as a set of vectors, either in row or
233  /// column major layout. All vectors must have the same vector type.
234  class MatrixTy {
235  SmallVector<Value *, 16> Vectors;
236 
237  OpInfoTy OpInfo;
238 
239  bool IsColumnMajor = true;
240 
241  public:
242  MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
243  MatrixTy(ArrayRef<Value *> Vectors)
244  : Vectors(Vectors.begin(), Vectors.end()),
245  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
246  MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
247  : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
248 
249  unsigned D = isColumnMajor() ? NumColumns : NumRows;
250  for (unsigned J = 0; J < D; ++J)
252  EltTy, isColumnMajor() ? NumRows : NumColumns)));
253  }
254 
255  Value *getVector(unsigned i) const { return Vectors[i]; }
256  Value *getColumn(unsigned i) const {
257  assert(isColumnMajor() && "only supported for column-major matrixes");
258  return Vectors[i];
259  }
260  Value *getRow(unsigned i) const {
261  assert(!isColumnMajor() && "only supported for row-major matrixes");
262  return Vectors[i];
263  }
264 
265  void setVector(unsigned i, Value *V) { Vectors[i] = V; }
266 
267  Type *getElementType() const { return getVectorTy()->getElementType(); }
268 
269  unsigned getNumVectors() const {
270  if (isColumnMajor())
271  return getNumColumns();
272  return getNumRows();
273  }
274 
275  unsigned getNumColumns() const {
276  if (isColumnMajor())
277  return Vectors.size();
278  else {
279  assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
280  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
281  }
282  }
283  unsigned getNumRows() const {
284  if (isColumnMajor()) {
285  assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
286  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
287  } else
288  return Vectors.size();
289  }
290 
291  void addVector(Value *V) { Vectors.push_back(V); }
292  VectorType *getColumnTy() {
293  assert(isColumnMajor() && "only supported for column-major matrixes");
294  return getVectorTy();
295  }
296 
297  VectorType *getVectorTy() const {
298  return cast<VectorType>(Vectors[0]->getType());
299  }
300 
302  assert(isColumnMajor() &&
303  "columns() only supported for column-major matrixes");
304  return make_range(Vectors.begin(), Vectors.end());
305  }
306 
308  return make_range(Vectors.begin(), Vectors.end());
309  }
310 
311  /// Embed the vectors of the matrix into a flat vector by concatenating
312  /// them.
313  Value *embedInVector(IRBuilder<> &Builder) const {
314  return Vectors.size() == 1 ? Vectors[0]
315  : concatenateVectors(Builder, Vectors);
316  }
317 
318  MatrixTy &addNumLoads(unsigned N) {
319  OpInfo.NumLoads += N;
320  return *this;
321  }
322 
323  void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
324 
325  MatrixTy &addNumStores(unsigned N) {
326  OpInfo.NumStores += N;
327  return *this;
328  }
329 
330  MatrixTy &addNumExposedTransposes(unsigned N) {
331  OpInfo.NumExposedTransposes += N;
332  return *this;
333  }
334 
335  MatrixTy &addNumComputeOps(unsigned N) {
336  OpInfo.NumComputeOps += N;
337  return *this;
338  }
339 
340  unsigned getNumStores() const { return OpInfo.NumStores; }
341  unsigned getNumLoads() const { return OpInfo.NumLoads; }
342  unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
343 
344  const OpInfoTy &getOpInfo() const { return OpInfo; }
345 
346  bool isColumnMajor() const { return IsColumnMajor; }
347 
348  unsigned getStride() const {
349  if (isColumnMajor())
350  return getNumRows();
351  return getNumColumns();
352  }
353 
354  /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
355  /// matrix is column-major, the result vector is extracted from a column
356  /// vector, otherwise from a row vector.
357  Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
358  IRBuilder<> &Builder) const {
359  Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
360  assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
361  NumElts &&
362  "Extracted vector will contain poison values");
363  return Builder.CreateShuffleVector(
364  Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
365  "block");
366  }
367  };
368 
369  struct ShapeInfo {
370  unsigned NumRows;
371  unsigned NumColumns;
372 
373  bool IsColumnMajor;
374 
375  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
376  : NumRows(NumRows), NumColumns(NumColumns),
377  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
378 
379  ShapeInfo(Value *NumRows, Value *NumColumns)
380  : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
381  cast<ConstantInt>(NumColumns)->getZExtValue()) {}
382 
383  bool operator==(const ShapeInfo &other) {
384  return NumRows == other.NumRows && NumColumns == other.NumColumns;
385  }
386  bool operator!=(const ShapeInfo &other) { return !(*this == other); }
387 
388  /// Returns true if shape-information is defined, meaning both dimensions
389  /// are != 0.
390  operator bool() const {
391  assert(NumRows == 0 || NumColumns != 0);
392  return NumRows != 0;
393  }
394 
395  unsigned getStride() const {
396  if (IsColumnMajor)
397  return NumRows;
398  return NumColumns;
399  }
400 
401  unsigned getNumVectors() const {
402  if (IsColumnMajor)
403  return NumColumns;
404  return NumRows;
405  }
406 
407  /// Returns the transposed shape.
408  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
409  };
410 
411  /// Maps instructions to their shape information. The shape information
412  /// describes the shape to be used while lowering. This matches the shape of
413  /// the result value of the instruction, with the only exceptions being store
414  /// instructions and the matrix_column_major_store intrinsics. For those, the
415  /// shape information indicates that those instructions should be lowered
416  /// using shape information as well. A ValueMap is used so that when
417  /// sub-passes like optimizeTransposes perform RAUW, the map stays
418  /// up-to-date.
419  ValueMap<Value *, ShapeInfo> ShapeMap;
420 
421  /// List of instructions to remove. While lowering, we do not replace all
422  /// users of a lowered instruction if shape information is available; those
423  /// need to be removed after we have finished lowering.
424  SmallVector<Instruction *, 16> ToRemove;
425 
426  /// Map from instructions to their produced column matrix.
427  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
428 
429 private:
430  static FastMathFlags getFastMathFlags(Instruction *Inst) {
431  FastMathFlags FMF;
432 
433  if (isa<FPMathOperator>(*Inst))
434  FMF = Inst->getFastMathFlags();
435 
437 
438  return FMF;
439  }
440 
441 public:
442  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
443  AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
445  : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
446  LI(LI), ORE(ORE) {}
447 
448  unsigned getNumOps(Type *VT) {
449  assert(isa<VectorType>(VT) && "Expected vector type");
450  return getNumOps(VT->getScalarType(),
451  cast<FixedVectorType>(VT)->getNumElements());
452  }
453 
454  /// Is this the minimal version executed in the backend pipelines.
455  bool isMinimal() const {
456  return !DT;
457  }
458 
459  /// Return the estimated number of vector ops required for an operation on
460  /// \p VT * N.
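  /// For example (illustrative only): with 128-bit vector registers, an
  /// operation on 8 x float (8 * 32 = 256 bits) counts as
  /// ceil(256 / 128) = 2 vector ops.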
461  unsigned getNumOps(Type *ST, unsigned N) {
462  return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
463  double(TTI.getRegisterBitWidth(
465  .getFixedSize()));
466  }
467 
468  /// Return the set of vectors that a matrix value is lowered to.
469  ///
470  /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
471  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
472  /// into vectors.
473  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
474  IRBuilder<> &Builder) {
475  VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
476  assert(VType && "MatrixVal must be a vector type");
477  assert(cast<FixedVectorType>(VType)->getNumElements() ==
478  SI.NumRows * SI.NumColumns &&
479  "The vector size must match the number of matrix elements");
480 
481  // Check if we lowered MatrixVal using shape information. In that case,
482  // return the existing matrix, if it matches the requested shape
483  // information. If there is a mismatch, embed the result in a flat
484  // vector and split it later.
485  auto Found = Inst2ColumnMatrix.find(MatrixVal);
486  if (Found != Inst2ColumnMatrix.end()) {
487  MatrixTy &M = Found->second;
488  // Return the found matrix, if its shape matches the requested shape
489  // information
490  if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
491  return M;
492 
493  MatrixVal = M.embedInVector(Builder);
494  }
495 
496  // Otherwise split MatrixVal.
497  SmallVector<Value *, 16> SplitVecs;
498  for (unsigned MaskStart = 0;
499  MaskStart < cast<FixedVectorType>(VType)->getNumElements();
500  MaskStart += SI.getStride()) {
501  Value *V = Builder.CreateShuffleVector(
502  MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
503  "split");
504  SplitVecs.push_back(V);
505  }
506 
507  return {SplitVecs};
508  }
509 
510  /// If \p V already has a known shape return false. Otherwise set the shape
511  /// for instructions that support it.
512  bool setShapeInfo(Value *V, ShapeInfo Shape) {
513  assert(Shape && "Shape not set");
514  if (isa<UndefValue>(V) || !supportsShapeInfo(V))
515  return false;
516 
517  auto SIter = ShapeMap.find(V);
518  if (SIter != ShapeMap.end()) {
519  LLVM_DEBUG(dbgs() << " not overriding existing shape: "
520  << SIter->second.NumRows << " "
521  << SIter->second.NumColumns << " for " << *V << "\n");
522  return false;
523  }
524 
525  ShapeMap.insert({V, Shape});
526  LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
527  << " for " << *V << "\n");
528  return true;
529  }
530 
531  bool isUniformShape(Value *V) {
532  Instruction *I = dyn_cast<Instruction>(V);
533  if (!I)
534  return true;
535 
536  switch (I->getOpcode()) {
537  case Instruction::FAdd:
538  case Instruction::FSub:
539  case Instruction::FMul: // Scalar multiply.
540  case Instruction::FNeg:
541  case Instruction::Add:
542  case Instruction::Mul:
543  case Instruction::Sub:
544  return true;
545  default:
546  return false;
547  }
548  }
549 
550  /// Returns true if shape information can be used for \p V. The supported
551  /// instructions must match the instructions that can be lowered by this pass.
552  bool supportsShapeInfo(Value *V) {
553  Instruction *Inst = dyn_cast<Instruction>(V);
554  if (!Inst)
555  return false;
556 
557  IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
558  if (II)
559  switch (II->getIntrinsicID()) {
560  case Intrinsic::matrix_multiply:
561  case Intrinsic::matrix_transpose:
562  case Intrinsic::matrix_column_major_load:
563  case Intrinsic::matrix_column_major_store:
564  return true;
565  default:
566  return false;
567  }
568  return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
569  }
570 
571  /// Propagate the shape information of instructions to their users.
572  /// The work list contains instructions for which we can compute the shape,
573  /// either based on the information provided by matrix intrinsics or known
574  /// shapes of operands.
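  /// For example (illustrative only): visiting
  ///   %c = call ... @llvm.matrix.multiply(... %a, ... %b, i32 2, i32 4, i32 3)
  /// records a 2x3 shape for %c, and %c's users are added to the work list so
  /// the shape can be propagated further.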
576  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
577  SmallVector<Instruction *, 32> NewWorkList;
578  // Pop an element for which we are guaranteed to have at least one of the
579  // operand shapes. Add the shape for this instruction and then add its users
580  // to the work list.
581  LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
582  while (!WorkList.empty()) {
583  Instruction *Inst = WorkList.pop_back_val();
584 
585  // New entry, set the value and insert operands
586  bool Propagate = false;
587 
588  Value *MatrixA;
589  Value *MatrixB;
590  Value *M;
591  Value *N;
592  Value *K;
593  if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
594  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
595  m_Value(N), m_Value(K)))) {
596  Propagate = setShapeInfo(Inst, {M, K});
597  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
598  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
599  // Flip dimensions.
600  Propagate = setShapeInfo(Inst, {N, M});
601  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
602  m_Value(MatrixA), m_Value(), m_Value(),
603  m_Value(), m_Value(M), m_Value(N)))) {
604  Propagate = setShapeInfo(Inst, {N, M});
605  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
606  m_Value(), m_Value(), m_Value(), m_Value(M),
607  m_Value(N)))) {
608  Propagate = setShapeInfo(Inst, {M, N});
609  } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
610  auto OpShape = ShapeMap.find(MatrixA);
611  if (OpShape != ShapeMap.end())
612  setShapeInfo(Inst, OpShape->second);
613  continue;
614  } else if (isUniformShape(Inst)) {
615  // Find the first operand that has a known shape and use that.
616  for (auto &Op : Inst->operands()) {
617  auto OpShape = ShapeMap.find(Op.get());
618  if (OpShape != ShapeMap.end()) {
619  Propagate |= setShapeInfo(Inst, OpShape->second);
620  break;
621  }
622  }
623  }
624 
625  if (Propagate) {
626  NewWorkList.push_back(Inst);
627  for (auto *User : Inst->users())
628  if (ShapeMap.count(User) == 0)
629  WorkList.push_back(cast<Instruction>(User));
630  }
631  }
632 
633  return NewWorkList;
634  }
635 
636  /// Propagate the shape to operands of instructions with shape information.
637  /// \p Worklist contains the instruction for which we already know the shape.
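  /// For example (illustrative only): knowing that
  ///   %c = call ... @llvm.matrix.multiply(... %a, ... %b, i32 2, i32 4, i32 3)
  /// has shape 2x3 lets us record shape 2x4 for %a and 4x3 for %b, if they do
  /// not have a known shape yet.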
639  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
640  SmallVector<Instruction *, 32> NewWorkList;
641 
642  auto pushInstruction = [](Value *V,
643  SmallVectorImpl<Instruction *> &WorkList) {
644  Instruction *I = dyn_cast<Instruction>(V);
645  if (I)
646  WorkList.push_back(I);
647  };
648  // Pop an element with known shape. Traverse the operands; if their shape
649  // derives from the result shape and is unknown, set it and add them to the
650  // worklist.
651  LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
652  while (!WorkList.empty()) {
653  Value *V = WorkList.pop_back_val();
654 
655  size_t BeforeProcessingV = WorkList.size();
656  if (!isa<Instruction>(V))
657  continue;
658 
659  Value *MatrixA;
660  Value *MatrixB;
661  Value *M;
662  Value *N;
663  Value *K;
664  if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
665  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
666  m_Value(N), m_Value(K)))) {
667  if (setShapeInfo(MatrixA, {M, N}))
668  pushInstruction(MatrixA, WorkList);
669 
670  if (setShapeInfo(MatrixB, {N, K}))
671  pushInstruction(MatrixB, WorkList);
672 
673  } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
674  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
675  // Flip dimensions.
676  if (setShapeInfo(MatrixA, {M, N}))
677  pushInstruction(MatrixA, WorkList);
678  } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
679  m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
680  m_Value(M), m_Value(N)))) {
681  if (setShapeInfo(MatrixA, {M, N})) {
682  pushInstruction(MatrixA, WorkList);
683  }
684  } else if (isa<LoadInst>(V) ||
685  match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
686  // Nothing to do, no matrix input.
687  } else if (isa<StoreInst>(V)) {
688  // Nothing to do. We forward-propagated to this so we would just
689  // backward propagate to an instruction with an already known shape.
690  } else if (isUniformShape(V)) {
691  // Propagate to all operands.
692  ShapeInfo Shape = ShapeMap[V];
693  for (Use &U : cast<Instruction>(V)->operands()) {
694  if (setShapeInfo(U.get(), Shape))
695  pushInstruction(U.get(), WorkList);
696  }
697  }
698  // After we discovered new shape info for new instructions in the
699  // worklist, we use their users as seeds for the next round of forward
700  // propagation.
701  for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
702  for (User *U : WorkList[I]->users())
703  if (isa<Instruction>(U) && V != U)
704  NewWorkList.push_back(cast<Instruction>(U));
705  }
706  return NewWorkList;
707  }
708 
709  /// (Op0 op Op1)^T -> Op0^T op Op1^T
710  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
711  /// them on both sides of \p Operation.
712  Instruction *distributeTransposes(
713  Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
715  function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
716  Operation) {
717  Value *T0 = Builder.CreateMatrixTranspose(
718  Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
719  // We are being run after shape propagation; add shape information for newly
720  // created instructions so that we lower them later.
721  setShapeInfo(T0, Shape0.t());
722  Value *T1 = Builder.CreateMatrixTranspose(
723  Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
724  setShapeInfo(T1, Shape1.t());
725  return Operation(T0, Shape0.t(), T1, Shape1.t());
726  }
727 
728  /// Try moving transposes in order to fold them away or into multiplies.
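  /// Schematically (with t() denoting llvm.matrix.transpose), the rewrites
  /// applied below are:
  ///   t(t(A))        -> A
  ///   t(splat k)     -> splat k
  ///   t(A * B)       -> t(B) * t(A)
  ///   t(A * splat k) -> t(A) * splat k
  ///   t(A) * t(B)    -> t(B * A)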
729  void optimizeTransposes() {
730  auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
731  // We need to remove Old from the ShapeMap otherwise RAUW will replace it
732  // with New. We should only add New if it supportsShapeInfo, so we insert
733  // it conditionally instead.
734  auto S = ShapeMap.find(&Old);
735  if (S != ShapeMap.end()) {
736  ShapeMap.erase(S);
737  if (supportsShapeInfo(New))
738  ShapeMap.insert({New, S->second});
739  }
740  Old.replaceAllUsesWith(New);
741  };
742 
743  // First sink all transposes inside matmuls, hoping that we end up with NN,
744  // NT or TN variants.
745  for (BasicBlock &BB : reverse(Func)) {
746  for (auto II = BB.rbegin(); II != BB.rend();) {
747  Instruction &I = *II;
748  // We may remove II. By default continue on the next/prev instruction.
749  ++II;
750  // If we were to erase II, move again.
751  auto EraseFromParent = [&II, &BB](Value *V) {
752  auto *Inst = cast<Instruction>(V);
753  if (Inst->use_empty()) {
754  if (II != BB.rend() && Inst == &*II) {
755  ++II;
756  }
757  Inst->eraseFromParent();
758  }
759  };
760 
761  // If we're creating a new instruction, continue from there.
762  Instruction *NewInst = nullptr;
763 
764  IRBuilder<> IB(&I);
766 
767  Value *TA, *TAMA, *TAMB;
768  ConstantInt *R, *K, *C;
769  if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
771  // Transpose of a transpose is a nop
772  Value *TATA;
773  if (match(TA,
774  m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
775  ReplaceAllUsesWith(I, TATA);
776  EraseFromParent(&I);
777  EraseFromParent(TA);
778  }
779  // k^T -> k
780  else if (isSplat(TA)) {
781  ReplaceAllUsesWith(I, TA);
782  EraseFromParent(&I);
783  }
784  // (A * B)^t -> B^t * A^t
785  // RxK KxC CxK KxR
786  else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
787  m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
788  m_ConstantInt(K), m_ConstantInt(C)))) {
789  NewInst = distributeTransposes(
790  TAMB, {K, C}, TAMA, {R, K}, Builder,
791  [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
792  return Builder.CreateMatrixMultiply(
793  T0, T1, Shape0.NumRows, Shape0.NumColumns,
794  Shape1.NumColumns, "mmul");
795  });
796  ReplaceAllUsesWith(I, NewInst);
797  EraseFromParent(&I);
798  EraseFromParent(TA);
799  // Same as above, but with a mul, which occurs when multiplied
800  // with a scalar.
801  // (A * k)^t -> A^t * k
802  // R x C RxC
803  } else if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
804  (isSplat(TAMA) || isSplat(TAMB))) {
805  IRBuilder<> LocalBuilder(&I);
806  // We know that the transposed operand is of shape RxC.
807  // And when multiplied with a scalar, the shape is preserved.
808  NewInst = distributeTransposes(
809  TAMA, {R, C}, TAMB, {R, C}, Builder,
810  [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
811  bool IsFP = I.getType()->isFPOrFPVectorTy();
812  auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
813  : LocalBuilder.CreateMul(T0, T1, "mmul");
814  auto *Result = cast<Instruction>(Mul);
815  setShapeInfo(Result, Shape0);
816  return Result;
817  });
818  ReplaceAllUsesWith(I, NewInst);
819  EraseFromParent(&I);
820  EraseFromParent(TA);
821  }
822  }
823 
824  // If we replaced I with a new instruction, continue from there.
825  if (NewInst)
826  II = std::next(BasicBlock::reverse_iterator(NewInst));
827  }
828  }
829 
830  // If we have a TT matmul, lift the transpose. We may be able to fold it into
831  // a consuming multiply.
832  for (BasicBlock &BB : Func) {
834  Value *A, *B, *AT, *BT;
835  ConstantInt *R, *K, *C;
836  // A^t * B^t -> (B * A)^t
837  if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
838  m_Value(A), m_Value(B), m_ConstantInt(R),
839  m_ConstantInt(K), m_ConstantInt(C))) &&
840  match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
841  match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
842  IRBuilder<> IB(&I);
844  Value *M = Builder.CreateMatrixMultiply(
845  BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
846  setShapeInfo(M, {C, R});
847  Instruction *NewInst = Builder.CreateMatrixTranspose(
848  M, C->getZExtValue(), R->getZExtValue());
849  ReplaceAllUsesWith(I, NewInst);
850  if (I.use_empty())
851  I.eraseFromParent();
852  if (A->use_empty())
853  cast<Instruction>(A)->eraseFromParent();
854  if (A != B && B->use_empty())
855  cast<Instruction>(B)->eraseFromParent();
856  }
857  }
858  }
859  }
860 
861  bool Visit() {
863 
864  // Initially only the shape of matrix intrinsics is known.
865  // Initialize the work list with ops carrying shape information.
866  for (BasicBlock &BB : Func)
867  for (Instruction &Inst : BB) {
868  IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
869  if (!II)
870  continue;
871 
872  switch (II->getIntrinsicID()) {
873  case Intrinsic::matrix_multiply:
874  case Intrinsic::matrix_transpose:
875  case Intrinsic::matrix_column_major_load:
876  case Intrinsic::matrix_column_major_store:
877  WorkList.push_back(&Inst);
878  break;
879  default:
880  break;
881  }
882  }
883 
884  // Avoid unnecessary work if there are no matrix intrinsics in the function.
885  if (WorkList.empty())
886  return false;
887 
888  // Propagate shapes until nothing changes any longer.
889  while (!WorkList.empty()) {
890  WorkList = propagateShapeForward(WorkList);
891  WorkList = propagateShapeBackward(WorkList);
892  }
893 
894  if (!isMinimal()) {
895  optimizeTransposes();
897  dbgs() << "Dump after matrix transpose optimization:\n";
898  Func.print(dbgs());
899  }
900  }
901 
902  bool Changed = false;
903  SmallVector<CallInst *, 16> MaybeFusableInsts;
904  SmallVector<Instruction *, 16> MatrixInsts;
905 
906  // First, collect all instructions with shape information and candidates for
907  // fusion (currently only matrix multiplies).
909  for (auto *BB : RPOT)
910  for (Instruction &I : *BB) {
911  if (ShapeMap.find(&I) == ShapeMap.end())
912  continue;
913  if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
914  MaybeFusableInsts.push_back(cast<CallInst>(&I));
915  MatrixInsts.push_back(&I);
916  }
917 
918  // Second, try to fuse candidates.
919  SmallPtrSet<Instruction *, 16> FusedInsts;
920  for (CallInst *CI : MaybeFusableInsts)
921  LowerMatrixMultiplyFused(CI, FusedInsts);
922  Changed = !FusedInsts.empty();
923 
924  // Third, lower remaining instructions with shape information.
925  for (Instruction *Inst : MatrixInsts) {
926  if (FusedInsts.count(Inst))
927  continue;
928 
929  IRBuilder<> Builder(Inst);
930 
931  if (CallInst *CInst = dyn_cast<CallInst>(Inst))
932  Changed |= VisitCallInst(CInst);
933 
934  Value *Op1;
935  Value *Op2;
936  if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
937  Changed |= VisitBinaryOperator(BinOp);
938  if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
939  Changed |= VisitUnaryOperator(UnOp);
940  if (match(Inst, m_Load(m_Value(Op1))))
941  Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
942  else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
943  Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
944  }
945 
946  if (ORE) {
947  RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
948  RemarkGen.emitRemarks();
949  }
950 
951  // Delete the instructions backwards, as this reduces the likelihood of
952  // having to update as many def-use and use-def chains.
953  //
954  // Because we add to ToRemove during fusion we can't guarantee that defs
955  // are before uses. Change uses to poison temporarily as these should get
956  // removed as well.
957  //
958  // For verification, we keep track of where we changed uses to poison in
959  // PoisonedInsts and then check that we in fact remove them.
960  SmallSet<Instruction *, 16> PoisonedInsts;
961  for (auto *Inst : reverse(ToRemove)) {
962  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
963  if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
964  PoisonedInsts.insert(Poisoned);
965  U.set(PoisonValue::get(Inst->getType()));
966  }
967  Inst->eraseFromParent();
968  PoisonedInsts.erase(Inst);
969  }
970  if (!PoisonedInsts.empty()) {
971  // If we didn't remove all poisoned instructions, it's a hard error.
972  dbgs() << "Poisoned but present instructions:\n";
973  for (auto *I : PoisonedInsts)
974  dbgs() << *I << "\n";
975  llvm_unreachable("Poisoned but instruction not removed");
976  }
977 
978  return Changed;
979  }
980 
981  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
982  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
983  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
984  Type *EltPtrType = PointerType::get(EltType, AS);
985  return Builder.CreatePointerCast(BasePtr, EltPtrType);
986  }
987 
988  /// Replace intrinsic calls
989  bool VisitCallInst(CallInst *Inst) {
990  if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
991  return false;
992 
993  switch (Inst->getCalledFunction()->getIntrinsicID()) {
994  case Intrinsic::matrix_multiply:
995  LowerMultiply(Inst);
996  break;
997  case Intrinsic::matrix_transpose:
998  LowerTranspose(Inst);
999  break;
1000  case Intrinsic::matrix_column_major_load:
1001  LowerColumnMajorLoad(Inst);
1002  break;
1003  case Intrinsic::matrix_column_major_store:
1004  LowerColumnMajorStore(Inst);
1005  break;
1006  default:
1007  return false;
1008  }
1009  return true;
1010  }
1011 
1012  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1013  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1014  /// ConstantInt, reduce the initial alignment based on the byte offset. For
1015  /// non-ConstantInt strides, return the common alignment of the initial
1016  /// alignment and the element size in bytes.
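  /// For example (illustrative only): with double elements, Stride = 5 and an
  /// initial alignment of 16, the vector at Idx = 1 starts at byte offset
  /// 1 * 5 * 8 = 40, so the returned alignment is commonAlignment(16, 40) = 8.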
1017  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1018  MaybeAlign A) const {
1019  Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1020  if (Idx == 0)
1021  return InitialAlign;
1022 
1023  TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1024  if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1025  uint64_t StrideInBytes =
1026  ConstStride->getZExtValue() * ElementSizeInBits / 8;
1027  return commonAlignment(InitialAlign, Idx * StrideInBytes);
1028  }
1029  return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1030  }
1031 
1032  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1033  /// vectors.
1034  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1035  bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1036  auto *VType = cast<VectorType>(Ty);
1037  Type *EltTy = VType->getElementType();
1038  Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1039  Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
1040  MatrixTy Result;
1041  for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1042  Value *GEP = computeVectorAddr(
1043  EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1044  Stride, Shape.getStride(), EltTy, Builder);
1045  Value *Vector = Builder.CreateAlignedLoad(
1046  VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1047  IsVolatile, "col.load");
1048 
1049  Result.addVector(Vector);
1050  }
1051  return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1052  Result.getNumVectors());
1053  }
1054 
1055  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1056  /// starting at \p MatrixPtr[I][J].
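  /// For example (illustrative only): in a column-major 8x8 matrix, the tile
  /// starting at (I, J) = (2, 4) begins at element offset
  /// J * Stride + I = 4 * 8 + 2 = 34 from \p MatrixPtr.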
1057  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1058  ShapeInfo MatrixShape, Value *I, Value *J,
1059  ShapeInfo ResultShape, Type *EltTy,
1060  IRBuilder<> &Builder) {
1061 
1062  Value *Offset = Builder.CreateAdd(
1063  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1064 
1065  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1066  Value *EltPtr =
1067  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1068  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1069  auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1070  ResultShape.NumColumns);
1071  Type *TilePtrTy = PointerType::get(TileTy, AS);
1072  Value *TilePtr =
1073  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1074 
1075  return loadMatrix(TileTy, TilePtr, Align,
1076  Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1077  ResultShape, Builder);
1078  }
1079 
1080  /// Lower a load instruction with shape information.
1081  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1082  bool IsVolatile, ShapeInfo Shape) {
1083  IRBuilder<> Builder(Inst);
1084  finalizeLowering(Inst,
1085  loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1086  Shape, Builder),
1087  Builder);
1088  }
1089 
1090  /// Lowers llvm.matrix.column.major.load.
1091  ///
1092  /// The intrinsic loads a matrix from memory using a stride between columns.
1093  void LowerColumnMajorLoad(CallInst *Inst) {
1095  "Intrinsic only supports column-major layout!");
1096  Value *Ptr = Inst->getArgOperand(0);
1097  Value *Stride = Inst->getArgOperand(1);
1098  LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1099  cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1100  {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1101  }
1102 
1103  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1104  /// MatrixPtr[I][J].
1105  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1106  MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1107  Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1108  Value *Offset = Builder.CreateAdd(
1109  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1110 
1111  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1112  Value *EltPtr =
1113  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1114  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1115  auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1116  StoreVal.getNumColumns());
1117  Type *TilePtrTy = PointerType::get(TileTy, AS);
1118  Value *TilePtr =
1119  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1120 
1121  storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
1122  Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1123  }
1124 
1125  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1126  /// vectors.
1127  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1128  MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1129  IRBuilder<> &Builder) {
1130  auto VType = cast<VectorType>(Ty);
1131  Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
1132  for (auto Vec : enumerate(StoreVal.vectors())) {
1133  Value *GEP = computeVectorAddr(
1134  EltPtr,
1135  Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1136  Vec.index()),
1137  Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1138  Builder.CreateAlignedStore(Vec.value(), GEP,
1139  getAlignForIndex(Vec.index(), Stride,
1140  VType->getElementType(),
1141  MAlign),
1142  IsVolatile);
1143  }
1144  return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1145  StoreVal.getNumVectors());
1146  }
1147 
1148  /// Lower a store instruction with shape information.
1149  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1150  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1151  IRBuilder<> Builder(Inst);
1152  auto StoreVal = getMatrix(Matrix, Shape, Builder);
1153  finalizeLowering(Inst,
1154  storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1155  IsVolatile, Builder),
1156  Builder);
1157  }
1158 
1159  /// Lowers llvm.matrix.column.major.store.
1160  ///
1161  /// The intrinsic stores a matrix back to memory using a stride between columns.
1162  void LowerColumnMajorStore(CallInst *Inst) {
1164  "Intrinsic only supports column-major layout!");
1165  Value *Matrix = Inst->getArgOperand(0);
1166  Value *Ptr = Inst->getArgOperand(1);
1167  Value *Stride = Inst->getArgOperand(2);
1168  LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1169  cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1170  {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1171  }
1172 
1173  // Set elements I..I+NumElts-1 to Block
1174  Value *insertVector(Value *Col, unsigned I, Value *Block,
1175  IRBuilder<> &Builder) {
1176 
1177  // First, bring Block to the same size as Col
1178  unsigned BlockNumElts =
1179  cast<FixedVectorType>(Block->getType())->getNumElements();
1180  unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1181  assert(NumElts >= BlockNumElts && "Too few elements for current block");
1182 
1183  Block = Builder.CreateShuffleVector(
1184  Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1185 
1186  // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1187  // 8, 4, 5, 6
1189  unsigned i;
1190  for (i = 0; i < I; i++)
1191  Mask.push_back(i);
1192 
1193  unsigned VecNumElts =
1194  cast<FixedVectorType>(Col->getType())->getNumElements();
1195  for (; i < I + BlockNumElts; i++)
1196  Mask.push_back(i - I + VecNumElts);
1197 
1198  for (; i < VecNumElts; i++)
1199  Mask.push_back(i);
1200 
1201  return Builder.CreateShuffleVector(Col, Block, Mask);
1202  }
1203 
1204  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1205  IRBuilder<> &Builder, bool AllowContraction,
1206  unsigned &NumComputeOps) {
1207  NumComputeOps += getNumOps(A->getType());
1208  if (!Sum)
1209  return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1210 
1211  if (UseFPOp) {
1212  if (AllowContraction) {
1213  // Use fmuladd for floating point operations and let the backend decide
1214  // if that's profitable.
1216  Func.getParent(), Intrinsic::fmuladd, A->getType());
1217  return Builder.CreateCall(FMulAdd, {A, B, Sum});
1218  }
1219  NumComputeOps += getNumOps(A->getType());
1220  Value *Mul = Builder.CreateFMul(A, B);
1221  return Builder.CreateFAdd(Sum, Mul);
1222  }
1223 
1224  NumComputeOps += getNumOps(A->getType());
1225  Value *Mul = Builder.CreateMul(A, B);
1226  return Builder.CreateAdd(Sum, Mul);
1227  }
1228 
1229  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1230  /// users with shape information, there's nothing to do: they will use the
1231  /// cached value when they are lowered. For other users, \p Matrix is
1232  /// flattened and the uses are updated to use it. Also marks \p Inst for
1233  /// deletion.
1234  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1235  IRBuilder<> &Builder) {
1236  auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1237  (void)inserted;
1238  assert(inserted.second && "multiple matrix lowering mapping");
1239 
1240  ToRemove.push_back(Inst);
1241  Value *Flattened = nullptr;
1242  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1243  if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1244  if (!Flattened)
1245  Flattened = Matrix.embedInVector(Builder);
1246  U.set(Flattened);
1247  }
1248  }
1249  }
1250 
1251  /// Compute \p Result += \p A * \p B for input matrices with left-associating
1252  /// addition.
1253  ///
1254  /// We can fold a transpose into the operand that is used to extract scalars.
1255  /// This is the first operand with row-major and the second with
1256  /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1257  /// operand is transposed.
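  /// Sketch of the column-major path (illustrative only): for Result = A * B
  /// with A of shape RxM and B of shape MxC, column J of the result is built as
  ///   Result[:, J] += A[:, K] * splat(B[K, J])   for K = 0..M-1,
  /// processed in row blocks of at most VF elements so the adds vectorize
  /// without reassociation.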
1258  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1259  const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1260  bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1261  const unsigned VF = std::max<unsigned>(
1263  .getFixedSize() /
1264  Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
1265  1U);
1266  unsigned R = Result.getNumRows();
1267  unsigned C = Result.getNumColumns();
1268  unsigned M = A.getNumColumns();
1269 
1270  bool IsFP = Result.getElementType()->isFloatingPointTy();
1271  assert(A.isColumnMajor() == B.isColumnMajor() &&
1272  Result.isColumnMajor() == A.isColumnMajor() &&
1273  "operands must agree on matrix layout");
1274  unsigned NumComputeOps = 0;
1275 
1276  Builder.setFastMathFlags(FMF);
1277 
1278  if (A.isColumnMajor()) {
1279  // Multiply columns from the first operand with scalars from the second
1280  // operand. Then move along the K axes and accumulate the columns. With
1281  // this the adds can be vectorized without reassociation.
1282  for (unsigned J = 0; J < C; ++J) {
1283  unsigned BlockSize = VF;
1284  // If Result is zero, we don't need to accumulate in the K==0 iteration.
1285  bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1286 
1287  for (unsigned I = 0; I < R; I += BlockSize) {
1288  // Gradually lower the vectorization factor to cover the remainder.
1289  while (I + BlockSize > R)
1290  BlockSize /= 2;
1291 
1292  Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1293  : nullptr;
1294  for (unsigned K = 0; K < M; ++K) {
1295  Value *L = A.extractVector(I, K, BlockSize, Builder);
1296  Value *RH = Builder.CreateExtractElement(
1297  B.getColumn(IsScalarMatrixTransposed ? K : J),
1298  IsScalarMatrixTransposed ? J : K);
1299  Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1300  Sum =
1301  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1302  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1303  }
1304  Result.setVector(J,
1305  insertVector(Result.getVector(J), I, Sum, Builder));
1306  }
1307  }
1308  } else {
1309  // Multiply rows from the second operand with scalars from the first
1310  // operand. Then move along the K axes and accumulate the rows. With this
1311  // the adds can be vectorized without reassociation.
1312  for (unsigned I = 0; I < R; ++I) {
1313  unsigned BlockSize = VF;
1314  bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1315  for (unsigned J = 0; J < C; J += BlockSize) {
1316  // Gradually lower the vectorization factor to cover the remainder.
1317  while (J + BlockSize > C)
1318  BlockSize /= 2;
1319 
1320  Value *Sum = nullptr;
1321  for (unsigned K = 0; K < M; ++K) {
1322  Value *R = B.extractVector(K, J, BlockSize, Builder);
1323  Value *LH = Builder.CreateExtractElement(
1324  A.getVector(IsScalarMatrixTransposed ? K : I),
1325  IsScalarMatrixTransposed ? I : K);
1326  Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1327  Sum =
1328  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1329  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1330  }
1331  Result.setVector(I,
1332  insertVector(Result.getVector(I), J, Sum, Builder));
1333  }
1334  }
1335  }
1336  Result.addNumComputeOps(NumComputeOps);
1337  }
1338 
1339  /// Ensure that the memory in \p Load does not alias \p Store by potentially
1340  /// copying it to a new location. The new location, or otherwise the
1341  /// original one, is returned.
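  /// The runtime check emitted below has roughly this shape (illustrative):
  ///   check0:   if (load.begin < store.end) goto check1;  else goto no_alias;
  ///   check1:   if (store.begin < load.end) goto copy;    else goto no_alias;
  ///   copy:     copy the loaded memory into a fresh alloca
  ///   no_alias: continue with a phi of the original or copied pointer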
1342  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1343  CallInst *MatMul) {
1346 
1347  // If we can statically determine noalias we're good.
1348  if (AA->isNoAlias(LoadLoc, StoreLoc))
1349  return Load->getPointerOperand();
1350 
1351  // Create code to check if the memory locations of the Load and Store
1352  // overlap and if they do, copy Load's operand to a new buffer.
1353 
1354  // First, create new blocks for the second part of the check and the copy.
1355  BasicBlock *Check0 = MatMul->getParent();
1356  // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1357  // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1358  // as we adjust Check0 and Check1's branches.
1360  for (BasicBlock *Succ : successors(Check0))
1361  DTUpdates.push_back({DT->Delete, Check0, Succ});
1362 
1363  BasicBlock *Check1 =
1364  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1365  nullptr, "alias_cont");
1366  BasicBlock *Copy =
1367  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1368  nullptr, "copy");
1369  BasicBlock *Fusion =
1370  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1371  nullptr, "no_alias");
1372 
1373  // Check if the loaded memory location begins before the end of the store
1374  // location. If the condition holds, they might overlap, otherwise they are
1375  // guaranteed to not overlap.
1376  IRBuilder<> Builder(MatMul);
1377  Check0->getTerminator()->eraseFromParent();
1378  Builder.SetInsertPoint(Check0);
1379  Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1380  Value *StoreBegin = Builder.CreatePtrToInt(
1381  const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1382  Value *StoreEnd = Builder.CreateAdd(
1383  StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1384  "store.end", true, true);
1385  Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1386  IntPtrTy, "load.begin");
1387  Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1388  Fusion);
1389 
1390  // Check if the store begins before the end of the load location. If the
1391  // condition holds, they alias, otherwise they are guaranteed to not
1392  // overlap.
1393  Check1->getTerminator()->eraseFromParent();
1394  Builder.SetInsertPoint(Check1, Check1->begin());
1395  Value *LoadEnd = Builder.CreateAdd(
1396  LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1397  "load.end", true, true);
1398  Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1399  Fusion);
1400 
1401  // Copy load operand to new alloca.
1402  Builder.SetInsertPoint(Copy, Copy->begin());
1403  auto *VT = cast<FixedVectorType>(Load->getType());
1404  // Use an array type for the alloca, to avoid potentially huge alignment
1405  // requirements for large vector types.
1406  auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1407  AllocaInst *Alloca =
1408  Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1409  Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
1410 
1411  Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
1412  Load->getAlign(), LoadLoc.Size.getValue());
1413  Builder.SetInsertPoint(Fusion, Fusion->begin());
1414  PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1415  PHI->addIncoming(Load->getPointerOperand(), Check0);
1416  PHI->addIncoming(Load->getPointerOperand(), Check1);
1417  PHI->addIncoming(BC, Copy);
1418 
1419  // Adjust DT.
1420  DTUpdates.push_back({DT->Insert, Check0, Check1});
1421  DTUpdates.push_back({DT->Insert, Check0, Fusion});
1422  DTUpdates.push_back({DT->Insert, Check1, Copy});
1423  DTUpdates.push_back({DT->Insert, Check1, Fusion});
1424  DT->applyUpdates(DTUpdates);
1425  return PHI;
1426  }
1427 
1428  bool isFusionProfitable(CallInst *MatMul) {
1429  if (ForceFusion)
1430  return true;
1431 
1432  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1433  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1434 
1435  const unsigned R = LShape.NumRows;
1436  const unsigned C = RShape.NumColumns;
1437  const unsigned M = LShape.NumColumns;
1438  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1439 
1440  const unsigned VF = std::max<unsigned>(
1442  .getFixedSize() /
1443  EltType->getPrimitiveSizeInBits().getFixedSize(),
1444  1U);
1445 
1446  // Cost model for tiling
1447  //
1448  // For tiling to be beneficial, we need reuse either along the R or
1449  // the C axis. We vectorize along the R axis so that means at least
1450  // 3 elements.
1451  // TODO: Also consider cost of copying if operands alias.
1452  if (R <= VF && C == 1)
1453  return false;
1454  // Then we need enough elements to exceed the number of vector
1455  // registers we have. Note that this is an oversimplification since
1456  // fusing also takes some extra loads which may exceed the number of
1457  // reloads necessary.
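    // For example (illustrative only): with VF = 2, fusing a 16x16 * 16x16
    // multiply would need Op0Regs = Op1Regs = (16 / 2) * 16 = 128 vector
    // registers' worth of operand data, far more than are typically available,
    // so fusion is considered profitable.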
1458  unsigned Op0Regs = (R + VF - 1) / VF * M;
1459  unsigned Op1Regs = (M + VF - 1) / VF * C;
1460  return Op0Regs + Op1Regs >
1462  }
1463 
1464  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1465  MatrixTy Res;
1466  auto *ColumnType = FixedVectorType::get(EltType, R);
1467  for (unsigned I = 0; I < C; ++I)
1468  Res.addVector(ConstantAggregateZero::get(ColumnType));
1469  return Res;
1470  }
1471 
1472  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1473  Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1474  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1475 
1476  // Create the main tiling loop nest.
1477  TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1479  Instruction *InsertI = cast<Instruction>(MatMul);
1480  BasicBlock *Start = InsertI->getParent();
1481  BasicBlock *End =
1482  SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1483  IRBuilder<> Builder(MatMul);
1484  BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1485 
1486  Type *TileVecTy =
1488  MatrixTy TileResult;
1489  // Insert in the inner loop header.
1490  Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1491  // Create PHI nodes for the result columns to accumulate across iterations.
1492  SmallVector<PHINode *, 4> ColumnPhis;
1493  for (unsigned I = 0; I < TileSize; I++) {
1494  auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1495  Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1496  TI.RowLoop.Header->getSingleSuccessor());
1497  TileResult.addVector(Phi);
1498  ColumnPhis.push_back(Phi);
1499  }
1500 
1501  // Insert in the inner loop body, which computes
1502  // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1503  Builder.SetInsertPoint(InnerBody->getTerminator());
1504  // Load tiles of the operands.
1505  MatrixTy A =
1506  loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1507  {TileSize, TileSize}, EltType, Builder);
1508  MatrixTy B =
1509  loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1510  {TileSize, TileSize}, EltType, Builder);
1511  emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1512  getFastMathFlags(MatMul));
1513  // Store result after the inner loop is done.
1514  Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1515  storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1516  Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1517  TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1518 
1519  for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1520  ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1521 
1522  // Force unrolling of a few iterations of the inner loop, to make sure there
1523  // is enough work per iteration.
1524  // FIXME: The unroller should make this decision directly instead, but
1525  // currently the cost-model is not up to the task.
1526  unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1527  addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1528  "llvm.loop.unroll.count", InnerLoopUnrollCount);
1529  }
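 // A rough sketch of the loop nest createTiledLoops emits (pseudocode based on
 // the TileInfo helper; the index names are illustrative, not identifiers from
 // this file):
 //   for (Col = 0; Col < RShape.NumColumns; Col += TileSize)
 //     for (Row = 0; Row < LShape.NumRows; Row += TileSize) {
 //       Res = 0;                                // PHIs in the K-loop header
 //       for (K = 0; K < LShape.NumColumns; K += TileSize)
 //         Res += load(A, Row, K) * load(B, K, Col);
 //       store(Res, C, Row, Col);                // in the row-loop latch
 //     }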
1530 
1531  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1532  StoreInst *Store,
1533  SmallPtrSetImpl<Instruction *> &FusedInsts) {
 1534  assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
 1535  "Tiling only supported for column-major matrixes at the moment!");
1536  if (!isFusionProfitable(MatMul))
1537  return;
1538 
1539  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1540  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1541 
1542  const unsigned R = LShape.NumRows;
1543  const unsigned C = RShape.NumColumns;
1544  const unsigned M = LShape.NumColumns;
1545  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1546 
1547  Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1548  Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1549  Value *CPtr = Store->getPointerOperand();
1550 
1551  if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1552  createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1553  else {
 1554  IRBuilder<> Builder(Store);
 1555  for (unsigned J = 0; J < C; J += TileSize)
1556  for (unsigned I = 0; I < R; I += TileSize) {
1557  const unsigned TileR = std::min(R - I, unsigned(TileSize));
1558  const unsigned TileC = std::min(C - J, unsigned(TileSize));
1559  MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1560 
1561  for (unsigned K = 0; K < M; K += TileSize) {
1562  const unsigned TileM = std::min(M - K, unsigned(TileSize));
1563  MatrixTy A =
1564  loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1565  LShape, Builder.getInt64(I), Builder.getInt64(K),
1566  {TileR, TileM}, EltType, Builder);
1567  MatrixTy B =
1568  loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1569  RShape, Builder.getInt64(K), Builder.getInt64(J),
1570  {TileM, TileC}, EltType, Builder);
1571  emitMatrixMultiply(Res, A, B, Builder, true, false,
1572  getFastMathFlags(MatMul));
1573  }
1574  storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1575  Builder.getInt64(I), Builder.getInt64(J), EltType,
1576  Builder);
1577  }
1578  }
1579 
1580  // Mark eliminated instructions as fused and remove them.
1581  FusedInsts.insert(Store);
1582  FusedInsts.insert(MatMul);
1583  Store->eraseFromParent();
1584  MatMul->eraseFromParent();
1585  if (LoadOp0->hasNUses(0)) {
1586  FusedInsts.insert(LoadOp0);
1587  LoadOp0->eraseFromParent();
1588  }
1589  if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1590  FusedInsts.insert(LoadOp1);
1591  LoadOp1->eraseFromParent();
1592  }
1593  }
1594 
1595  /// Try to lower matrix multiply chains by fusing operations.
1596  ///
1597  /// Call finalizeLowering on lowered instructions. Instructions that are
1598  /// completely eliminated by fusion are added to \p FusedInsts.
1599  void LowerMatrixMultiplyFused(CallInst *MatMul,
1600  SmallPtrSetImpl<Instruction *> &FusedInsts) {
1601  if (!FuseMatrix || !DT)
1602  return;
1603 
1604  assert(AA && LI && "Analyses should be available");
1605 
1606  Value *A = MatMul->getArgOperand(0);
1607  Value *B = MatMul->getArgOperand(1);
1608 
1609  // We can fold the transpose into the operand that is used to fetch scalars.
1610  Value *T;
 1611  if (MatrixLayout == MatrixLayoutTy::ColumnMajor
 1612  ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1613  : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1614  IRBuilder<> Builder(MatMul);
1615  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1616  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1617  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1618  const unsigned R = LShape.NumRows;
1619  const unsigned M = LShape.NumColumns;
1620  const unsigned C = RShape.NumColumns;
1621 
1622  MatrixTy MA;
1623  MatrixTy MB;
1624 
1625  Value *Transpose;
 1626  if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
 1627  MA = getMatrix(A, ShapeInfo(R, M), Builder);
1628  MB = getMatrix(T, ShapeInfo(C, M), Builder);
1629  Transpose = B;
1630  } else {
1631  MA = getMatrix(T, ShapeInfo(R, M), Builder);
1632  MB = getMatrix(B, ShapeInfo(C, M), Builder);
1633  Transpose = A;
1634  }
1635 
1636  // Initialize the output
1637  MatrixTy Result(R, C, EltType);
1638 
1639  emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1640  getFastMathFlags(MatMul));
1641 
1642  FusedInsts.insert(MatMul);
1643  if (Transpose->hasOneUse()) {
1644  FusedInsts.insert(cast<Instruction>(Transpose));
1645  ToRemove.push_back(cast<Instruction>(Transpose));
1646  // TODO: add a fake entry for the folded instruction so that this is
1647  // included in the expression in the remark.
1648  Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1649  }
1650  finalizeLowering(MatMul, Result, Builder);
1651  return;
1652  }
1653 
1654  if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1655  return;
1656 
1657  // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1658  // since the single store user will be lowered as part of this.
1659  auto *LoadOp0 = dyn_cast<LoadInst>(A);
1660  auto *LoadOp1 = dyn_cast<LoadInst>(B);
1661  auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1662  if (LoadOp0 && LoadOp1 && Store) {
1663  // The store address must dominate the MatMul instruction, otherwise
1664  // we create invalid IR.
1665  SetVector<Value *> WorkList;
1666  WorkList.insert(Store->getOperand(1));
 1667  SmallVector<Instruction *> ToHoist;
 1668  for (unsigned I = 0; I != WorkList.size(); ++I) {
1669  Value *Current = WorkList[I];
1670  auto *CurrI = dyn_cast<Instruction>(Current);
1671  if (!CurrI)
1672  continue;
1673  if (isa<PHINode>(CurrI))
1674  return;
1675  if (DT->dominates(CurrI, MatMul))
1676  continue;
1677  if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1678  return;
1679  ToHoist.push_back(CurrI);
1680  WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1681  }
1682 
1683  sort(ToHoist, [this](Instruction *A, Instruction *B) {
1684  return DT->dominates(A, B);
1685  });
1686  for (Instruction *I : ToHoist)
1687  I->moveBefore(MatMul);
1688 
1689  emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1690  return;
1691  }
1692  }
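 // The {ld, ld} -> matmul -> st chain handled above corresponds to IR of
 // roughly this shape (illustrative example shown in a comment; the exact
 // mangled intrinsic suffix depends on the operand types):
 //   %a = load <8 x double>, ptr %A
 //   %b = load <8 x double>, ptr %B
 //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v8f64.v8f64(
 //            <8 x double> %a, <8 x double> %b, i32 2, i32 4, i32 2)
 //   store <4 x double> %c, ptr %C
 // i.e. a 2x4 matrix times a 4x2 matrix stored to memory; emitSIMDTiling
 // replaces the whole chain with tiled loads, multiplies and stores, so the
 // large operand vectors never have to be materialized at once.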
1693 
1694  /// Lowers llvm.matrix.multiply.
1695  void LowerMultiply(CallInst *MatMul) {
1696  IRBuilder<> Builder(MatMul);
1697  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1698  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1699  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1700 
1701  const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1702  const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1703  assert(Lhs.getElementType() == Rhs.getElementType() &&
1704  "Matrix multiply argument element types do not match.");
1705 
1706  const unsigned R = LShape.NumRows;
1707  const unsigned C = RShape.NumColumns;
1708  assert(LShape.NumColumns == RShape.NumRows);
1709 
1710  // Initialize the output
1711  MatrixTy Result(R, C, EltType);
1712  assert(Lhs.getElementType() == Result.getElementType() &&
1713  "Matrix multiply result element type does not match arguments.");
1714 
1715  emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1716  getFastMathFlags(MatMul));
1717  finalizeLowering(MatMul, Result, Builder);
1718  }
1719 
1720  /// Lowers llvm.matrix.transpose.
1721  void LowerTranspose(CallInst *Inst) {
1722  MatrixTy Result;
1723  IRBuilder<> Builder(Inst);
1724  Value *InputVal = Inst->getArgOperand(0);
1725  VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1726  ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1727  MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1728 
1729  const unsigned NewNumVecs =
1730  InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1731  const unsigned NewNumElts =
1732  InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1733 
1734  for (unsigned I = 0; I < NewNumVecs; ++I) {
1735  // Build a single result vector. First initialize it.
1736  Value *ResultVector = PoisonValue::get(
1737  FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
1738  // Go through the old elements and insert it into the resulting vector.
1739  for (auto J : enumerate(InputMatrix.vectors())) {
1740  Value *Elt = Builder.CreateExtractElement(J.value(), I);
1741  // Row and column indices are transposed.
1742  ResultVector =
1743  Builder.CreateInsertElement(ResultVector, Elt, J.index());
1744  }
1745  Result.addVector(ResultVector);
1746  }
1747 
1748  // TODO: Improve estimate of operations needed for transposes. Currently we
1749  // just count the insertelement/extractelement instructions, but do not
1750  // account for later simplifications/combines.
1751  finalizeLowering(
1752  Inst,
1753  Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
1754  .addNumExposedTransposes(1),
1755  Builder);
1756  }
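 // Example (column-major, illustrative): a 2x3 input is held as 3 column
 // vectors of 2 elements. The loops above build NewNumVecs = 2 result vectors
 // of NewNumElts = 3 elements each, where result vector I collects element I
 // of every input column, so rows become columns.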
1757 
1758  /// Lower load instructions, if shape information is available.
1759  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
1760  auto I = ShapeMap.find(Inst);
1761  if (I == ShapeMap.end())
1762  return false;
1763 
1764  LowerLoad(Inst, Ptr, Inst->getAlign(),
1765  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1766  I->second);
1767  return true;
1768  }
1769 
1770  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
1771  IRBuilder<> &Builder) {
1772  auto I = ShapeMap.find(StoredVal);
1773  if (I == ShapeMap.end())
1774  return false;
1775 
1776  LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
1777  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1778  I->second);
1779  return true;
1780  }
1781 
1782  /// Lower binary operators, if shape information is available.
1783  bool VisitBinaryOperator(BinaryOperator *Inst) {
1784  auto I = ShapeMap.find(Inst);
1785  if (I == ShapeMap.end())
1786  return false;
1787 
1788  Value *Lhs = Inst->getOperand(0);
1789  Value *Rhs = Inst->getOperand(1);
1790 
1791  IRBuilder<> Builder(Inst);
1792  ShapeInfo &Shape = I->second;
1793 
1794  MatrixTy Result;
1795  MatrixTy A = getMatrix(Lhs, Shape, Builder);
1796  MatrixTy B = getMatrix(Rhs, Shape, Builder);
1797  assert(A.isColumnMajor() == B.isColumnMajor() &&
1798  Result.isColumnMajor() == A.isColumnMajor() &&
1799  "operands must agree on matrix layout");
1800 
1801  Builder.setFastMathFlags(getFastMathFlags(Inst));
1802 
1803  // Helper to perform binary op on vectors.
1804  auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
1805  switch (Inst->getOpcode()) {
1806  case Instruction::Add:
1807  return Builder.CreateAdd(LHS, RHS);
1808  case Instruction::Mul:
1809  return Builder.CreateMul(LHS, RHS);
1810  case Instruction::Sub:
1811  return Builder.CreateSub(LHS, RHS);
1812  case Instruction::FAdd:
1813  return Builder.CreateFAdd(LHS, RHS);
1814  case Instruction::FMul:
1815  return Builder.CreateFMul(LHS, RHS);
1816  case Instruction::FSub:
1817  return Builder.CreateFSub(LHS, RHS);
1818  default:
1819  llvm_unreachable("Unsupported binary operator for matrix");
1820  }
1821  };
1822 
1823  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1824  Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
1825 
1826  finalizeLowering(Inst,
1827  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1828  Result.getNumVectors()),
1829  Builder);
1830  return true;
1831  }
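 // Example (illustrative): an fadd of two values with shape 2x2 in
 // column-major layout is lowered to two fadd instructions on <2 x float>
 // column vectors; the compute-op count recorded for the remark statistics is
 // getNumOps of the vector type times the number of result vectors.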
1832 
1833  /// Lower unary operators, if shape information is available.
1834  bool VisitUnaryOperator(UnaryOperator *Inst) {
1835  auto I = ShapeMap.find(Inst);
1836  if (I == ShapeMap.end())
1837  return false;
1838 
1839  Value *Op = Inst->getOperand(0);
1840 
1841  IRBuilder<> Builder(Inst);
1842  ShapeInfo &Shape = I->second;
1843 
1844  MatrixTy Result;
1845  MatrixTy M = getMatrix(Op, Shape, Builder);
1846 
1847  Builder.setFastMathFlags(getFastMathFlags(Inst));
1848 
1849  // Helper to perform unary op on vectors.
1850  auto BuildVectorOp = [&Builder, Inst](Value *Op) {
1851  switch (Inst->getOpcode()) {
1852  case Instruction::FNeg:
1853  return Builder.CreateFNeg(Op);
1854  default:
1855  llvm_unreachable("Unsupported unary operator for matrix");
1856  }
1857  };
1858 
1859  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1860  Result.addVector(BuildVectorOp(M.getVector(I)));
1861 
1862  finalizeLowering(Inst,
1863  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1864  Result.getNumVectors()),
1865  Builder);
1866  return true;
1867  }
1868 
1869  /// Helper to linearize a matrix expression tree into a string. Currently
 1870  /// matrix expressions are linearized by starting at an expression leaf and
1871  /// linearizing bottom up.
1872  struct ExprLinearizer {
1873  unsigned LengthToBreak = 100;
1874  std::string Str;
1875  raw_string_ostream Stream;
1876  unsigned LineLength = 0;
1877  const DataLayout &DL;
1878 
1879  /// Mapping from instructions to matrixes. It is used to identify
1880  /// matrix instructions.
1881  const MapVector<Value *, MatrixTy> &Inst2Matrix;
1882 
1883  /// Mapping from values to the leaves of all expressions that the value is
1884  /// part of.
 1885  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
 1886 
1887  /// Set of matrix expressions in the scope of a given DISubprogram.
1888  const SmallSetVector<Value *, 32> &ExprsInSubprogram;
1889 
1890  /// Leaf node of the expression to linearize.
1891  Value *Leaf;
1892 
1893  /// Used to keep track of sub-expressions that get reused while linearizing
1894  /// the expression. Re-used sub-expressions are marked as (reused).
1895  SmallPtrSet<Value *, 8> ReusedExprs;
1896 
1897  ExprLinearizer(const DataLayout &DL,
1898  const MapVector<Value *, MatrixTy> &Inst2Matrix,
1899  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
1900  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
1901  Value *Leaf)
1902  : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
1903  ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
1904 
1905  void indent(unsigned N) {
1906  LineLength += N;
1907  for (unsigned i = 0; i < N; i++)
1908  Stream << " ";
1909  }
1910 
1911  void lineBreak() {
1912  Stream << "\n";
1913  LineLength = 0;
1914  }
1915 
1916  void maybeIndent(unsigned Indent) {
1917  if (LineLength >= LengthToBreak)
1918  lineBreak();
1919 
1920  if (LineLength == 0)
1921  indent(Indent);
1922  }
1923 
1924  void write(StringRef S) {
1925  LineLength += S.size();
1926  Stream << S;
1927  }
1928 
1929  Value *getUnderlyingObjectThroughLoads(Value *V) {
1930  if (Value *Ptr = getPointerOperand(V))
1931  return getUnderlyingObjectThroughLoads(Ptr);
1932  else if (V->getType()->isPointerTy())
1933  return getUnderlyingObject(V);
1934  return V;
1935  }
1936 
1937  /// Returns true if \p V is a matrix value in the given subprogram.
1938  bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
1939 
 1940  /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
1941  /// \p SS.
1942  void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
1943  auto M = Inst2Matrix.find(V);
1944  if (M == Inst2Matrix.end())
1945  SS << "unknown";
1946  else {
1947  SS << M->second.getNumRows();
1948  SS << "x";
1949  SS << M->second.getNumColumns();
1950  }
1951  }
1952 
1953  /// Write the called function name. Handles calls to llvm.matrix.*
1954  /// specially: we write the name, followed by the dimensions of the input
1955  /// matrixes, followed by the scalar type name.
1956  void writeFnName(CallInst *CI) {
1957  if (!CI->getCalledFunction())
1958  write("<no called fn>");
1959  else {
 1960  StringRef Name = CI->getCalledFunction()->getName();
 1961  if (!Name.startswith("llvm.matrix")) {
1962  write(Name);
1963  return;
1964  }
1965  auto *II = cast<IntrinsicInst>(CI);
 1966  write(Intrinsic::getBaseName(II->getIntrinsicID())
 1967  .drop_front(StringRef("llvm.matrix.").size()));
1968  write(".");
1969  std::string Tmp;
1970  raw_string_ostream SS(Tmp);
1971 
1972  switch (II->getIntrinsicID()) {
1973  case Intrinsic::matrix_multiply:
1974  prettyPrintMatrixType(II->getOperand(0), SS);
1975  SS << ".";
1976  prettyPrintMatrixType(II->getOperand(1), SS);
1977  SS << "." << *II->getType()->getScalarType();
1978  break;
1979  case Intrinsic::matrix_transpose:
1980  prettyPrintMatrixType(II->getOperand(0), SS);
1981  SS << "." << *II->getType()->getScalarType();
1982  break;
1983  case Intrinsic::matrix_column_major_load:
1984  prettyPrintMatrixType(II, SS);
1985  SS << "." << *II->getType()->getScalarType();
1986  break;
1987  case Intrinsic::matrix_column_major_store:
1988  prettyPrintMatrixType(II->getOperand(0), SS);
1989  SS << "." << *II->getOperand(0)->getType()->getScalarType();
1990  break;
1991  default:
1992  llvm_unreachable("Unhandled case");
1993  }
1994  SS.flush();
1995  write(Tmp);
1996  }
1997  }
1998 
1999  unsigned getNumShapeArgs(CallInst *CI) const {
2000  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2001  switch (II->getIntrinsicID()) {
2002  case Intrinsic::matrix_multiply:
2003  return 3;
2004  case Intrinsic::matrix_transpose:
2005  return 2;
2006  case Intrinsic::matrix_column_major_load:
2007  case Intrinsic::matrix_column_major_store:
2008  return 3;
2009  default:
2010  return 0;
2011  }
2012  }
2013  return 0;
2014  }
2015 
 2016  /// Special printing for values: for pointers, we print if they refer to an
 2017  /// (function) external address or a stack address; for other values we
 2018  /// either print the constant or "scalar"/"matrix".
2019  void write(Value *V) {
2020  V = getUnderlyingObjectThroughLoads(V);
2021  if (V->getType()->isPointerTy()) {
2022  if (isa<AllocaInst>(V)) {
2023  Stream << "stack addr";
2024  LineLength += StringRef("stack addr").size();
2025  } else {
2026  Stream << "addr";
2027  LineLength += StringRef("addr").size();
2028  }
2029  if (!V->getName().empty()) {
2030  Stream << " %" << V->getName() << "";
2031  LineLength += V->getName().size() + 2;
2032  }
2033  return;
2034  }
2035 
2036  std::string Tmp;
2037  raw_string_ostream TmpStream(Tmp);
2038 
2039  if (auto *CI = dyn_cast<ConstantInt>(V))
2040  TmpStream << CI->getValue();
2041  else if (isa<Constant>(V))
2042  TmpStream << "constant";
2043  else {
2044  if (isMatrix(V))
2045  TmpStream << "matrix";
2046  else
2047  TmpStream << "scalar";
2048  }
2049  TmpStream.flush();
2050  Tmp = std::string(StringRef(Tmp).trim());
2051  LineLength += Tmp.size();
2052  Stream << Tmp;
2053  }
2054 
2055  /// Linearize expression \p Expr starting at an indentation of \p Indent.
2056  /// Expressions that are re-used multiple times are prefixed with (reused)
2057  /// at the re-used root instruction.
2058  void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2059  bool ParentShared) {
2060  auto *I = cast<Instruction>(Expr);
2061  maybeIndent(Indent);
 2062  SmallVector<Value *, 8> Ops;
 2063 
2064  // Is Expr shared with other expression leaves?
2065  bool ExprShared = false;
2066 
2067  // Deal with shared subtrees. Mark them as shared, if required.
2068  if (!ParentShared) {
2069  auto SI = Shared.find(Expr);
2070  assert(SI != Shared.end() && SI->second.count(Leaf));
2071 
2072  for (Value *S : SI->second) {
2073  if (S == Leaf)
2074  continue;
2075  DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2076  write("shared with remark at line " + std::to_string(DL.getLine()) +
2077  " column " + std::to_string(DL.getCol()) + " (");
2078  }
2079  ExprShared = SI->second.size() > 1;
2080  }
2081 
2082  bool Reused = !ReusedExprs.insert(Expr).second;
2083  if (Reused && !ParentReused)
2084  write("(reused) ");
2085 
2086  if (auto *CI = dyn_cast<CallInst>(I)) {
2087  writeFnName(CI);
2088 
2089  Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2090  } else if (isa<BitCastInst>(Expr)) {
2091  // Special case bitcasts, which are used to materialize matrixes from
2092  // non-matrix ops.
2093  write("matrix");
2094  return;
2095  } else {
2096  Ops.append(I->value_op_begin(), I->value_op_end());
2097  write(std::string(I->getOpcodeName()));
2098  }
2099 
2100  write(std::string("("));
2101 
2102  unsigned NumOpsToBreak = 1;
2103  if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2104  NumOpsToBreak = 2;
2105 
2106  for (Value *Op : Ops) {
2107  if (Ops.size() > NumOpsToBreak)
2108  lineBreak();
2109 
2110  maybeIndent(Indent + 1);
2111  if (isMatrix(Op))
2112  linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2113  else
2114  write(Op);
2115  if (Op != Ops.back())
2116  write(", ");
2117  }
2118 
2119  write(")");
2120  }
2121 
2122  const std::string &getResult() {
2123  Stream.flush();
2124  return Str;
2125  }
2126  };
2127 
2128  /// Generate remarks for matrix operations in a function. To generate remarks
2129  /// for matrix expressions, the following approach is used:
2130  /// 1. Use the inlined-at debug information to group matrix operations to the
2131  /// DISubprograms they are contained in.
2132  /// 2. Collect leaves of matrix expressions (done in
2133  /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2134  // mapping. Leaves are lowered matrix instructions without other matrix
2135  // users (like stores) in the current subprogram.
2136  /// 3. For each leaf, create a remark containing a linearizied version of the
2137  /// matrix expression. The expression is linearized by a recursive
2138  /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2139  /// that multiple leaves can share sub-expressions. Shared subexpressions
2140  /// are explicitly marked as shared().
2141  struct RemarkGenerator {
2142  const MapVector<Value *, MatrixTy> &Inst2Matrix;
 2143  OptimizationRemarkEmitter &ORE;
 2144  Function &Func;
2145  const DataLayout &DL;
2146 
2147  RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2148  OptimizationRemarkEmitter &ORE, Function &Func)
2149  : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2150  DL(Func.getParent()->getDataLayout()) {}
2151 
2152  /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2153  /// instructions in Inst2Matrix returning void or without any users in
2154  /// \p ExprsInSubprogram. Currently that should only include stores.
 2155  SmallVector<Value *, 4>
 2156  getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2157  SmallVector<Value *, 4> Leaves;
2158  for (auto *Expr : ExprsInSubprogram)
2159  if (Expr->getType()->isVoidTy() ||
2160  !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2161  return ExprsInSubprogram.count(U);
2162  }))
2163  Leaves.push_back(Expr);
2164  return Leaves;
2165  }
2166 
2167  /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2168  /// to all visited expressions in \p Shared. Limit the matrix operations to
2169  /// the ones in \p ExprsInSubprogram.
2170  void collectSharedInfo(Value *Leaf, Value *V,
2171  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2172  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2173 
2174  if (!ExprsInSubprogram.count(V))
2175  return;
2176 
2177  auto I = Shared.insert({V, {}});
2178  I.first->second.insert(Leaf);
2179 
2180  for (Value *Op : cast<Instruction>(V)->operand_values())
2181  collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2182  }
2183 
2184  /// Calculate the number of exclusive and shared op counts for expression
2185  /// starting at \p V. Expressions used multiple times are counted once.
2186  /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2187  std::pair<OpInfoTy, OpInfoTy>
2188  sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2189  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2190  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2191  if (!ExprsInSubprogram.count(Root))
2192  return {};
2193 
2194  // Already counted this expression. Stop.
2195  if (!ReusedExprs.insert(Root).second)
2196  return {};
2197 
2198  OpInfoTy SharedCount;
2199  OpInfoTy Count;
2200 
2201  auto I = Shared.find(Root);
2202  auto CM = Inst2Matrix.find(Root);
2203  if (I->second.size() == 1)
2204  Count = CM->second.getOpInfo();
2205  else
2206  SharedCount = CM->second.getOpInfo();
2207 
2208  for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2209  auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2210  Count += C.first;
2211  SharedCount += C.second;
2212  }
2213  return {Count, SharedCount};
2214  }
2215 
2216  void emitRemarks() {
2217  if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2218  return;
2219 
 2220  // Map matrix operations to their containing subprograms, by traversing
2221  // the inlinedAt chain. If the function does not have a DISubprogram, we
2222  // only map them to the containing function.
 2223  MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
 2224  for (const auto &KV : Inst2Matrix) {
2225  if (Func.getSubprogram()) {
2226  auto *I = cast<Instruction>(KV.first);
2227  DILocation *Context = I->getDebugLoc();
2228  while (Context) {
2229  auto I =
2230  Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2231  I.first->second.push_back(KV.first);
 2232  Context = DebugLoc(Context).getInlinedAt();
 2233  }
2234  } else {
2235  auto I = Subprog2Exprs.insert({nullptr, {}});
2236  I.first->second.push_back(KV.first);
2237  }
2238  }
2239  for (auto &KV : Subprog2Exprs) {
2240  SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2241  KV.second.end());
2242  auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2243 
 2244  DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
 2245  for (Value *Leaf : Leaves)
2246  collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2247 
2248  // Generate remarks for each leaf.
2249  for (auto *L : Leaves) {
2250 
2251  DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2252  DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2253  while (Context) {
2254  if (getSubprogram(Context->getScope()) == KV.first) {
2255  Loc = Context;
2256  break;
2257  }
 2258  Context = DebugLoc(Context).getInlinedAt();
 2259  }
2260 
2261  SmallPtrSet<Value *, 8> ReusedExprs;
2262  OpInfoTy Counts, SharedCounts;
2263  std::tie(Counts, SharedCounts) =
2264  sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2265 
2266  OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2267  cast<Instruction>(L)->getParent());
2268 
2269  Rem << "Lowered with ";
2270  Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2271  << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2272  << ore::NV("NumComputeOps", Counts.NumComputeOps)
2273  << " compute ops, "
2274  << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2275  << " exposed transposes";
2276 
2277  if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2278  SharedCounts.NumComputeOps > 0) {
2279  Rem << ",\nadditionally "
2280  << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2281  << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2282  << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2283  << " compute ops"
2284  << " are shared with other expressions";
2285  }
2286 
2287  Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2288  ORE.emit(Rem);
2289  }
2290  }
2291  }
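 // The emitted remark therefore reads roughly as follows (the numbers are
 // illustrative, not output from a real run):
 //   Lowered with 4 stores, 8 loads, 16 compute ops, 0 exposed transposes,
 //   additionally 0 stores, 4 loads, 4 compute ops are shared with other
 //   expressions
 // followed by the linearized expression string produced by linearize() below.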
2292 
2293  std::string
2294  linearize(Value *L,
2295  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2296  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2297  const DataLayout &DL) {
2298  ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2299  Lin.linearizeExpr(L, 0, false, false);
2300  return Lin.getResult();
2301  }
2302  };
2303 };
2304 } // namespace
2305 
 2306 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
 2307  FunctionAnalysisManager &AM) {
 2308  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2309  OptimizationRemarkEmitter *ORE = nullptr;
2310  AAResults *AA = nullptr;
2311  DominatorTree *DT = nullptr;
2312  LoopInfo *LI = nullptr;
2313 
2314  if (!Minimal) {
 2315  ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 2316  AA = &AM.getResult<AAManager>(F);
2317  DT = &AM.getResult<DominatorTreeAnalysis>(F);
2318  LI = &AM.getResult<LoopAnalysis>(F);
2319  }
2320 
2321  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2322  if (LMT.Visit()) {
2323  PreservedAnalyses PA;
2324  if (!Minimal) {
2325  PA.preserve<LoopAnalysis>();
 2326  PA.preserve<DominatorTreeAnalysis>();
 2327  }
2328  return PA;
2329  }
2330  return PreservedAnalyses::all();
2331 }
2332 
 2333 void LowerMatrixIntrinsicsPass::printPipeline(
 2334  raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
 2335  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
 2336  OS, MapClassName2PassName);
2337  OS << "<";
2338  if (Minimal)
2339  OS << "minimal";
2340  OS << ">";
2341 }
2342 
2343 namespace {
2344 
2345 class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
2346 public:
2347  static char ID;
2348 
2349  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
 2350  initializeLowerMatrixIntrinsicsLegacyPassPass(
 2351  *PassRegistry::getPassRegistry());
 2352  }
2353 
2354  bool runOnFunction(Function &F) override {
2355  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2356  auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
2357  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
2358  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2359  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
2360  LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
2361  bool C = LMT.Visit();
2362  return C;
2363  }
2364 
2365  void getAnalysisUsage(AnalysisUsage &AU) const override {
 2366  AU.addRequired<TargetTransformInfoWrapperPass>();
 2367  AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
 2368  AU.addRequired<AAResultsWrapperPass>();
 2369  AU.addRequired<DominatorTreeWrapperPass>();
 2370  AU.addPreserved<DominatorTreeWrapperPass>();
 2371  AU.addRequired<LoopInfoWrapperPass>();
 2372  AU.addPreserved<LoopInfoWrapperPass>();
 2373  }
2374 };
2375 } // namespace
2376 
2377 static const char pass_name[] = "Lower the matrix intrinsics";
 2378 char LowerMatrixIntrinsicsLegacyPass::ID = 0;
 2379 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
 2380  false, false)
 2381 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
 2382 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
 2383 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 2384 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 2385 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
 2386  false, false)
 2387 
 2388 Pass *llvm::createLowerMatrixIntrinsicsPass() {
 2389  return new LowerMatrixIntrinsicsLegacyPass();
2390 }
2391 
2392 namespace {
2393 
2394 /// A lightweight version of the matrix lowering pass that only requires TTI.
2395 /// Advanced features that require DT, AA or ORE like tiling are disabled. This
2396 /// is used to lower matrix intrinsics if the main lowering pass is not run, for
2397 /// example with -O0.
2398 class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
2399 public:
2400  static char ID;
2401 
2402  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
 2403  initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
 2404  *PassRegistry::getPassRegistry());
 2405  }
2406 
2407  bool runOnFunction(Function &F) override {
2408  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2409  LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
2410  bool C = LMT.Visit();
2411  return C;
2412  }
2413 
2414  void getAnalysisUsage(AnalysisUsage &AU) const override {
 2415  AU.addRequired<TargetTransformInfoWrapperPass>();
 2416  AU.setPreservesCFG();
2417  }
2418 };
2419 } // namespace
2420 
2421 static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
 2422 char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
 2423 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
 2424  "lower-matrix-intrinsics-minimal", pass_name_minimal,
 2425  false, false)
 2426 INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
 2427  "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
 2428  false)
2429 
 2430 Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
 2431  return new LowerMatrixIntrinsicsMinimalLegacyPass();
2432 }
llvm::Intrinsic::getBaseName
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition: Function.cpp:937
i
i
Definition: README.txt:29
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
pass_name
static const char pass_name[]
Definition: LowerMatrixIntrinsics.cpp:2377
llvm::AAManager
A manager for alias analyses.
Definition: AliasAnalysis.h:876
llvm::TargetIRAnalysis
Analysis pass providing the TargetTransformInfo.
Definition: TargetTransformInfo.h:2585
llvm::Function::isIntrinsic
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition: Function.h:210
llvm::MemoryLocation::get
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
Definition: MemoryLocation.cpp:35
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
BlockSize
static const int BlockSize
Definition: TarWriter.cpp:33
llvm::DILocalScope::getSubprogram
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Definition: DebugInfoMetadata.cpp:961
M
We currently emits eax Perhaps this is what we really should generate is Is imull three or four cycles eax eax The current instruction priority is based on pattern complexity The former is more complex because it folds a load so the latter will not be emitted Perhaps we should use AddedComplexity to give LEA32r a higher priority We should always try to match LEA first since the LEA matching code does some estimate to determine whether the match is profitable if we care more about code then imull is better It s two bytes shorter than movl leal On a Pentium M
Definition: README.txt:252
llvm::make_range
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
Definition: iterator_range.h:53
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:113
llvm::AArch64PACKey::ID
ID
Definition: AArch64BaseInfo.h:818
llvm::Value::hasOneUse
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
llvm::Intrinsic::getDeclaration
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1481
llvm::User::operands
op_range operands()
Definition: User.h:242
PHI
Rewrite undef for PHI
Definition: AMDGPURewriteUndefForPHI.cpp:101
llvm::AllocaInst::getAlign
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition: Instructions.h:121
IntrinsicInst.h
llvm::Type::isPointerTy
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:237
ceil
We have fiadd patterns now but the followings have the same cost and complexity We need a way to specify the later is more profitable def def The FP stackifier should handle simple permutates to reduce number of shuffle e g ceil
Definition: README-FPStack.txt:54
llvm::AnalysisManager::getResult
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:774
DebugInfoMetadata.h
llvm::MemoryLocation::Ptr
const Value * Ptr
The address of the start of the location.
Definition: MemoryLocation.h:218
llvm::ValueMap::end
iterator end()
Definition: ValueMap.h:136
Scalar.h
llvm::PassInfoMixin< LowerMatrixIntrinsicsPass >
LowerLoad
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition: X86ISelLowering.cpp:26077
llvm::TypeSize::getFixedSize
ScalarTy getFixedSize() const
Definition: TypeSize.h:444
llvm::Function
Definition: Function.h:60
Pass.h
llvm::TargetTransformInfo::getRegisterBitWidth
TypeSize getRegisterBitWidth(RegisterKind K) const
Definition: TargetTransformInfo.cpp:644
llvm::IntrinsicInst::getIntrinsicID
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:53
LowerStore
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition: X86ISelLowering.cpp:25988
llvm::raw_string_ostream
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:628
llvm::write
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs)
Definition: DWP.cpp:549
llvm::PointerType::get
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Definition: Type.cpp:727
llvm::SetVector::size
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:77
llvm::Type::getScalarType
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:328
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1199
llvm::enumerate
detail::enumerator< R > enumerate(R &&TheRange)
Given an input range, returns a new range whose values are are pair (A,B) such that A is the 0-based ...
Definition: STLExtras.h:2238
llvm::PatternMatch::m_Load
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
Definition: PatternMatch.h:1563
llvm::PatternMatch::m_CombineOr
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:218
llvm::TargetTransformInfo
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
Definition: TargetTransformInfo.h:172
insertVector
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition: SROA.cpp:2227
ToRemove
ReachingDefAnalysis InstSet & ToRemove
Definition: ARMLowOverheadLoops.cpp:547
llvm::IRBuilder<>
llvm::AAResults::isNoAlias
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
Definition: AliasAnalysis.h:348
DomTreeUpdater.h
ValueTracking.h
OptimizationRemarkEmitter.h
llvm::DominatorTree
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:166
llvm::OptimizationRemarkEmitter::allowExtraAnalysis
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
Definition: OptimizationRemarkEmitter.h:98
llvm::cl::Hidden
@ Hidden
Definition: CommandLine.h:140
llvm::DILocation
Debug location.
Definition: DebugInfoMetadata.h:1595
ForceFusion
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
llvm::MemoryLocation::Size
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
Definition: MemoryLocation.h:227
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
llvm::DebugLoc::getInlinedAt
DILocation * getInlinedAt() const
Definition: DebugLoc.cpp:39
llvm::sys::path::end
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:235
llvm::sys::path::begin
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:226
llvm::LoopInfoWrapperPass
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:1293
llvm::DominatorTreeBase::Insert
static constexpr UpdateKind Insert
Definition: GenericDomTree.h:242
T1
#define T1
Definition: Mips16ISelLowering.cpp:340
llvm::SmallSet
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:136
llvm::operator!=
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2012
Vector
So we should use XX3Form_Rcr to implement intrinsic Convert DP outs ins xscvdpsp No builtin are required Round &Convert QP DP(dword[1] is set to zero) No builtin are required Round to Quad Precision because you need to assign rounding mode in instruction Provide builtin(set f128:$vT,(int_ppc_vsx_xsrqpi f128:$vB))(set f128 yields< n x< ty > >< result > yields< ty >< result > No builtin are required Load Store Vector
Definition: README_P9.txt:497
vectors
hexagon Hexagon specific predictive commoning for HVX vectors
Definition: HexagonVectorLoopCarriedReuse.cpp:221
T
#define T
Definition: Mips16ISelLowering.cpp:341
llvm::MapVector
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:37
llvm::SmallPtrSet
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:450
llvm::ore::NV
DiagnosticInfoOptimizationBase::Argument NV
Definition: OptimizationRemarkEmitter.h:136
llvm::SPII::Store
@ Store
Definition: SparcInstrInfo.h:33
llvm::VectorType::getElementType
Type * getElementType() const
Definition: DerivedTypes.h:422
llvm::Value::user_begin
user_iterator user_begin()
Definition: Value.h:397
llvm::successors
auto successors(MachineBasicBlock *BB)
Definition: MachineSSAContext.h:29
llvm::CallBase::arg_begin
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1316
RHS
Value * RHS
Definition: X86PartialReduction.cpp:76
llvm::UnaryOperator
Definition: InstrTypes.h:101
llvm::FastMathFlags
Convenience struct for specifying and reasoning about fast-math flags.
Definition: FMF.h:21
llvm::initializeLowerMatrixIntrinsicsMinimalLegacyPassPass
void initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(PassRegistry &)
llvm::LoadInst::getAlign
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:216
llvm::DominatorTreeBase::applyUpdates
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
Definition: GenericDomTree.h:544
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:101
llvm::NVPTX::PTXLdStInstCode::VecType
VecType
Definition: NVPTX.h:121
llvm::commonAlignment
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:213
F
#define F(x, y, z)
Definition: MD5.cpp:55
llvm::RISCVFenceField::R
@ R
Definition: RISCVBaseInfo.h:265
llvm::DomTreeUpdater::UpdateStrategy::Lazy
@ Lazy
llvm::TileInfo
A helper struct to create IR loop nests for tiling in IR of the following form: for ColumnLoop....
Definition: MatrixUtils.h:31
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:55
getSubprogram
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
Definition: LowerMatrixIntrinsics.cpp:90
AliasAnalysis.h
MatrixBuilder.h
Context
LLVMContext & Context
Definition: NVVMIntrRange.cpp:66
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
llvm::DominatorTree::dominates
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
llvm::X86AS::SS
@ SS
Definition: X86.h:201
llvm::BitmaskEnumDetail::Mask
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:80
CommandLine.h
llvm::cast
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition: Casting.h:565
llvm::UnaryOperator::getOpcode
UnaryOps getOpcode() const
Definition: InstrTypes.h:171
extractVector
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition: SROA.cpp:2205
LHS
Value * LHS
Definition: X86PartialReduction.cpp:75
llvm::ConstantInt
This is the shared class of boolean and integer constants.
Definition: Constants.h:79
isSplat
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
Definition: LowerMatrixIntrinsics.cpp:98
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:24
MatrixUtils.h
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
isZero
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:524
llvm::AAResults
Definition: AliasAnalysis.h:294
E
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
Operation
PowerPC Reduce CR logical Operation
Definition: PPCReduceCRLogicals.cpp:735
llvm::createLowerMatrixIntrinsicsPass
Pass * createLowerMatrixIntrinsicsPass()
Definition: LowerMatrixIntrinsics.cpp:2388
llvm::User
Definition: User.h:44
C
(vector float) vec_cmpeq(*A, *B) C
Definition: README_ALTIVEC.txt:86
llvm::ARM_PROC::A
@ A
Definition: ARMBaseInfo.h:34
TileSize
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
llvm::CallBase::getCalledFunction
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1396
llvm::BasicBlock::begin
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:306
llvm::operator+=
std::string & operator+=(std::string &buffer, StringRef string)
Definition: StringRef.h:895
SI
@ SI
Definition: SIInstrInfo.cpp:7882
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
llvm::LocationSize::getValue
uint64_t getValue() const
Definition: MemoryLocation.h:159
t
bitcast float %x to i32 %s=and i32 %t, 2147483647 %d=bitcast i32 %s to float ret float %d } declare float @fabsf(float %n) define float @bar(float %x) nounwind { %d=call float @fabsf(float %x) ret float %d } This IR(from PR6194):target datalayout="e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128" target triple="x86_64-apple-darwin10.0.0" %0=type { double, double } %struct.float3=type { float, float, float } define void @test(%0, %struct.float3 *nocapture %res) nounwind noinline ssp { entry:%tmp18=extractvalue %0 %0, 0 t
Definition: README-SSE.txt:788
llvm::ms_demangle::QualifierMangleMode::Result
@ Result
llvm::BitTracker
Definition: BitTracker.h:35
llvm::Value::uses
iterator_range< use_iterator > uses()
Definition: Value.h:376
false
Definition: StackSlotColoring.cpp:141
llvm::MaybeAlign
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
B
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
llvm::BinaryOperator::getOpcode
BinaryOps getOpcode() const
Definition: InstrTypes.h:392
llvm::PatternMatch::m_ConstantInt
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:147
llvm::Instruction
Definition: Instruction.h:42
llvm::Type::getScalarSizeInBits
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition: Type.cpp:189
llvm::DominatorTreeWrapperPass
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:302
llvm::raw_ostream
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
llvm::LowerMatrixIntrinsicsPass::run
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition: LowerMatrixIntrinsics.cpp:2306
llvm::codeview::EncodedFramePtrReg::BasePtr
@ BasePtr
llvm::raw_ostream::flush
void flush()
Definition: raw_ostream.h:185
llvm::UndefValue::get
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1713
llvm::DomTreeUpdater
Definition: DomTreeUpdater.h:28
llvm::ConstantInt::get
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:879
LoopUtils.h
llvm::getUnderlyingObject
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments and pointer casts from the specified value,...
Definition: ValueTracking.cpp:4498
llvm::SmallVectorImpl::append
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:687
PatternMatch.h
llvm::TargetTransformInfo::RGK_FixedWidthVector
@ RGK_FixedWidthVector
Definition: TargetTransformInfo.h:964
llvm::FixedVectorType::get
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:684
llvm::StoreInst::getAlign
Align getAlign() const
Definition: Instructions.h:341
llvm::Align
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
llvm::FastMathFlags::setAllowContract
void setAllowContract(bool B=true)
Definition: FMF.h:92
llvm::ValueMap::count
size_type count(const KeyT &Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: ValueMap.h:152
llvm::Value::use_empty
bool use_empty() const
Definition: Value.h:344
llvm::CallingConv::ID
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
MatrixLayoutTy
MatrixLayoutTy
Definition: LowerMatrixIntrinsics.cpp:75
INITIALIZE_PASS_END
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:58
CFG.h
LoopInfo.h
llvm::sort
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1657
llvm::function_ref
An efficient, type-erasing, non-owning reference to a callable.
Definition: STLFunctionalExtras.h:36
llvm::StringRef::empty
constexpr bool empty() const
empty - Check if the string is empty.
Definition: StringRef.h:134
llvm::SmallPtrSetImplBase::empty
bool empty() const
Definition: SmallPtrSet.h:92
llvm::VectorType
Base class of all SIMD vector types.
Definition: DerivedTypes.h:389
VectorUtils.h
llvm::cl::opt< bool >
llvm::StoreInst::isVolatile
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:333
llvm::StoreInst
An instruction for storing to memory.
Definition: Instructions.h:297
llvm::cl::values
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
Definition: CommandLine.h:705
llvm::Instruction::eraseFromParent
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:81
llvm::AMDGPU::Hwreg::Offset
Offset
Definition: SIDefines.h:416
llvm::LowerMatrixIntrinsicsPass::printPipeline
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
Definition: LowerMatrixIntrinsics.cpp:2333
llvm::MapVector::find
iterator find(const KeyT &Key)
Definition: MapVector.h:148
llvm::getPointerOperand
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
Definition: Instructions.h:5374
uint64_t
llvm::TargetTransformInfoWrapperPass
Wrapper pass for TargetTransformInfo.
Definition: TargetTransformInfo.h:2641
D
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
intrinsics
expand Expand reduction intrinsics
Definition: ExpandReductions.cpp:198
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
INITIALIZE_PASS_DEPENDENCY
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
llvm::PreservedAnalyses::preserve
void preserve()
Mark an analysis as preserved.
Definition: PassManager.h:173
llvm::addStringMetadataToLoop
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
Definition: LoopUtils.cpp:215
llvm::TargetTransformInfo::getRegisterClassForType
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
Definition: TargetTransformInfo.cpp:635
llvm::omp::AddressSpace::Shared
@ Shared
llvm::DenseMap
Definition: DenseMap.h:714
llvm::PatternMatch::m_Store
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
Definition: PatternMatch.h:1570
llvm::codeview::FrameCookieKind::Copy
@ Copy
llvm::ValueMap::erase
bool erase(const KeyT &Val)
Definition: ValueMap.h:191
llvm::LoopInfoBase::getLoopFor
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Definition: LoopInfo.h:992
I
#define I(x, y, z)
Definition: MD5.cpp:58
llvm::AArch64PACKey::IB
@ IB
Definition: AArch64BaseInfo.h:820
llvm::cl::init
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:447
llvm::make_early_inc_range
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:716
llvm::concatenateVectors
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
Definition: VectorUtils.cpp:1026
IRBuilder.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
FuseMatrix
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
llvm::OptimizationRemarkEmitter::emit
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Definition: OptimizationRemarkEmitter.cpp:77
llvm::ValueMap::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: ValueMap.h:173
llvm::operator==
bool operator==(uint64_t V1, const APInt &V2)
Definition: APInt.h:2010
llvm::Instruction::getFastMathFlags
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
Definition: Instruction.cpp:319
Ptr
@ Ptr
Definition: TargetLibraryInfo.cpp:60
llvm::TTI
TargetTransformInfo TTI
Definition: TargetTransformInfo.h:167
llvm::Type::isVoidTy
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:139
llvm::ArrayType::get
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:638
MatrixLayoutTy::ColumnMajor
@ ColumnMajor
PrintAfterTransposeOpt
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
LowerMatrixIntrinsics.h
llvm::CallBase::arg_end
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1322
llvm::SmallPtrSetImpl::count
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:383
llvm::SmallSet::erase
bool erase(const T &V)
Definition: SmallSet.h:206
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:651
llvm::PatternMatch::m_Value
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
llvm::SetVector::insert
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:141
getType
static M68kRelType getType(unsigned Kind, MCSymbolRefExpr::VariantKind &Modifier, bool &IsPCRel)
Definition: M68kELFObjectWriter.cpp:48
llvm::size
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1690
llvm::Function::getIntrinsicID
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition: Function.h:205
llvm::ArrayRef
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: APInt.h:32
llvm::LoopInfo
Definition: LoopInfo.h:1108
llvm::BinaryOperator
Definition: InstrTypes.h:188
llvm::OptimizationRemarkEmitter
The optimization diagnostic interface.
Definition: OptimizationRemarkEmitter.h:33
Matrix
Live Register Matrix
Definition: LiveRegMatrix.cpp:44
llvm::min
Expected< ExpressionValue > min(const ExpressionValue &Lhs, const ExpressionValue &Rhs)
Definition: FileCheck.cpp:357
llvm::MapVector::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: MapVector.h:118
Mul
BinaryOperator * Mul
Definition: X86PartialReduction.cpp:70
llvm::any_of
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1716
llvm::CallBase::getParamAlign
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:1742
DataLayout.h
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass iff it does not add or remove basic blocks from the function or modify terminator instructions in any way.
Definition: Pass.cpp:265
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition: StringRef.h:50
llvm_unreachable
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Definition: ErrorHandling.h:143
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::AnalysisUsage::addPreserved
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
Definition: PassAnalysisSupport.h:98
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:532
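An illustrative sketch of the common rewrite pattern after lowering; Inst and Flattened are hypothetical values, not names from this file.
// Rewire every user of the original matrix intrinsic call to the flattened
// vector result, then delete the now-dead call.
Inst->replaceAllUsesWith(Flattened);
Inst->eraseFromParent();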
getParent
static const Function * getParent(const Value *V)
Definition: BasicAliasAnalysis.cpp:804
llvm::ms_demangle::IntrinsicFunctionKind::New
@ New
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition: AArch64SLSHardening.cpp:76
clEnumValN
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition: CommandLine.h:680
llvm::MatrixBuilder
Definition: MatrixBuilder.h:33
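MatrixBuilder is the front-end helper that emits the llvm.matrix.* intrinsics this pass later lowers. A hedged sketch of emitting a multiply; Builder, A and B are assumed to exist, and the 4x4 shapes are purely illustrative.
#include "llvm/IR/MatrixBuilder.h"
// Emit C = A * B as a call to llvm.matrix.multiply for two 4x4 operands,
// each passed as a flat 16-element vector.
MatrixBuilder MB(Builder);
Value *C = MB.CreateMatrixMultiply(A, B, /*LHSRows=*/4, /*LHSColumns=*/4,
                                   /*RHSColumns=*/4);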
S
Definition: README.txt:210
pass_name_minimal
static const char pass_name_minimal[]
Definition: LowerMatrixIntrinsics.cpp:2421
llvm::TargetTransformInfo::getNumberOfRegisters
unsigned getNumberOfRegisters(unsigned ClassID) const
Definition: TargetTransformInfo.cpp:631
llvm::ifs::IFSSymbolType::Func
@ Func
llvm::Value::getName
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:308
llvm::ValueMap
See the file comment.
Definition: ValueMap.h:85
llvm::LoadInst
An instruction for reading from memory.
Definition: Instructions.h:173
users
iv users
Definition: IVUsers.cpp:48
llvm::MapVector::end
iterator end()
Definition: MapVector.h:72
llvm::ConstantInt::getZExtValue
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate for the type of this constant.
Definition: Constants.h:142
llvm::StringRef::size
constexpr size_t size() const
size - Get the string size.
Definition: StringRef.h:137
llvm::RecurKind::FMulAdd
@ FMulAdd
Fused multiply-add of floats (a * b + c).
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:85
llvm::Twine
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
Definition: Twine.h:81
Alignment.h
llvm::OptimizationRemarkEmitterWrapperPass
OptimizationRemarkEmitter legacy analysis pass.
Definition: OptimizationRemarkEmitter.h:146
MatrixLayout
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
llvm::GraphProgram::Name
Name
Definition: GraphWriter.h:50
AllowContractEnabled
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
llvm::DIScope
Base class for scope-like contexts.
Definition: DebugInfoMetadata.h:509
llvm::AMDGPU::SendMsg::Op
Op
Definition: SIDefines.h:348
llvm::PreservedAnalyses::all
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
MatrixLayoutTy::RowMajor
@ RowMajor
llvm::SPII::Load
@ Load
Definition: SparcInstrInfo.h:32
llvm::TypeSize
Definition: TypeSize.h:435
Function.h
llvm::Value::hasNUses
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:148
llvm::SetVector< T, SmallVector< T, N >, SmallDenseSet< T, N > >::count
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:215
llvm::Type::getPointerTo
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
Definition: Type.cpp:774
llvm::ValueMap::find
iterator find(const KeyT &Val)
Definition: ValueMap.h:156
llvm::MCID::Add
@ Add
Definition: MCInstrDesc.h:185
llvm::SmallSet::insert
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:178
llvm::ReversePostOrderTraversal
Definition: PostOrderIterator.h:291
llvm::IntrinsicInst
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:46
llvm::DominatorTreeAnalysis
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:267
llvm::BasicBlock::reverse_iterator
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:89
llvm::X86II::TA
@ TA
Definition: X86BaseInfo.h:813
llvm::OptimizationRemark
Diagnostic information for applied optimization remarks.
Definition: DiagnosticInfo.h:689
llvm::Pass
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:91
Instructions.h
PostOrderIterator.h
llvm::LoadInst::isVolatile
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition: Instructions.h:210
llvm::initializeLowerMatrixIntrinsicsLegacyPassPass
void initializeLowerMatrixIntrinsicsLegacyPassPass(PassRegistry &)
SmallVector.h
m_AnyMul
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
Definition: LowerMatrixIntrinsics.cpp:106
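A usage sketch for this file-local matcher; V is a hypothetical instruction being inspected, and the snippet relies on the file's existing "using namespace PatternMatch".
Value *A = nullptr, *B = nullptr;
if (match(V, m_AnyMul(m_Value(A), m_Value(B)))) {
  // V is either an integer 'mul' or a floating-point 'fmul' of A and B.
}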
BT
BitTracker BT
Definition: BitTracker.cpp:73
llvm::CallBase::getArgOperand
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1341
N
#define N
llvm::AAResultsWrapperPass
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
Definition: AliasAnalysis.h:924
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:91
llvm::SmallVectorImpl::pop_back_val
T pop_back_val()
Definition: SmallVector.h:677
llvm::to_string
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:85
TargetTransformInfo.h
llvm::iterator_range
A range adaptor for a pair of iterators.
Definition: iterator_range.h:30
llvm::PHINode
Definition: Instructions.h:2698
llvm::BasicBlock::getTerminator
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
Definition: BasicBlock.h:119
DEBUG_TYPE
#define DEBUG_TYPE
Definition: LowerMatrixIntrinsics.cpp:54
minimal
lower matrix intrinsics minimal
Definition: LowerMatrixIntrinsics.cpp:2427
llvm::DISubprogram
Subprogram description.
Definition: DebugInfoMetadata.h:1841
llvm::SmallVectorImpl< Instruction * >
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, false, false) INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass
llvm::reverse
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:485
llvm::SmallPtrSetImpl< Instruction * >
llvm::SmallSetVector
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:307
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition: InstructionSimplify.h:42
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:308
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1473
BB
Definition: README.txt:39
GEP
Hexagon Common GEP
Definition: HexagonCommonGEP.cpp:171
llvm::AnalysisUsage::addRequired
AnalysisUsage & addRequired()
Definition: PassAnalysisSupport.h:75
TileUseLoops
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
llvm::DebugLoc
A debug info location.
Definition: DebugLoc.h:33
llvm::AllocaInst
an instruction to allocate memory on the stack
Definition: Instructions.h:58
llvm::FastMathFlags::allowContract
bool allowContract() const
Definition: FMF.h:71
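A simplified sketch (an assumption, not the pass's exact logic) of how the allow-contract fast-math flag and the -matrix-allow-contract option could combine to gate FMA formation; MulInst is a hypothetical floating-point multiply instruction.
bool UseFMA = AllowContractEnabled ||
              (isa<FPMathOperator>(MulInst) &&
               MulInst->getFastMathFlags().allowContract());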
llvm::User::getOperand
Value * getOperand(unsigned i) const
Definition: User.h:169
llvm::cl::desc
Definition: CommandLine.h:413
llvm::createLowerMatrixIntrinsicsMinimalPass
Pass * createLowerMatrixIntrinsicsMinimalPass()
Definition: LowerMatrixIntrinsics.cpp:2430
llvm::createSequentialMask
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
Definition: VectorUtils.cpp:971
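A hedged sketch of extracting a contiguous block of elements with a sequential shuffle mask; Builder and Flat are hypothetical names, not taken from this file.
// Mask is {2, 3, 4, 5}: four consecutive indices starting at 2, no undefs.
SmallVector<int, 16> Mask =
    createSequentialMask(/*Start=*/2, /*NumInts=*/4, /*NumUndefs=*/0);
Value *SubVec = Builder.CreateShuffleVector(Flat, Mask);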
llvm::SetVector< Value * >
llvm::ConstantAggregateZero::get
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1587
BasicBlockUtils.h
llvm::PatternMatch::m_FMul
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
Definition: PatternMatch.h:1051
llvm::SplitBlock
BasicBlock * SplitBlock(BasicBlock *Old, Instruction *SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Definition: BasicBlockUtils.cpp:918
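A hedged sketch of splitting a block while keeping the dominator tree and loop info current; MatMul, DT and LI are hypothetical names, not taken from this file.
// Everything from MatMul onward moves into a new successor block, so tiled
// code can be emitted in its own region.
BasicBlock *Cont = SplitBlock(MatMul->getParent(), MatMul, &DT, &LI,
                              /*MSSAU=*/nullptr, "continue");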
llvm::pdb::PDB_SymType::Block
@ Block
InitializePasses.h
llvm::OptimizationRemarkEmitterAnalysis
Definition: OptimizationRemarkEmitter.h:164
llvm::SmallSet::empty
bool empty() const
Definition: SmallSet.h:158
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
Debug.h
llvm::Value::users
iterator_range< user_iterator > users()
Definition: Value.h:421
llvm::StringRef::drop_front
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition: StringRef.h:601
llvm::MemoryLocation
Representation for a specific memory location.
Definition: MemoryLocation.h:210
llvm::LoopAnalysis
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:1268
llvm::PatternMatch::m_Mul
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
Definition: PatternMatch.h:1045
llvm::DominatorTreeBase::Delete
static constexpr UpdateKind Delete
Definition: GenericDomTree.h:243
llvm::Use
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
llvm::Type::getPrimitiveSizeInBits
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition: Type.cpp:164
llvm::SmallPtrSetImpl::insert
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
llvm::PoisonValue::get
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1732
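A small sketch of building a poison-initialized placeholder vector that columns can later be inserted into; Ctx is a hypothetical LLVMContext and the element type and count are illustrative.
auto *VecTy = FixedVectorType::get(Type::getDoubleTy(Ctx), 4);
Value *Init = PoisonValue::get(VecTy);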