LowerMatrixIntrinsics.cpp
1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 // * Improve fusion:
13 // * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14 // transposed.
15 // * Improve cost-modeling, e.g. choose different number of rows/columns
16 //   for tiles, consider cost of copies on alias.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21 #include "llvm/ADT/GraphTraits.h"
22 #include "llvm/ADT/PostOrderIterator.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Analysis/AliasAnalysis.h"
25 #include "llvm/Analysis/DomTreeUpdater.h"
26 #include "llvm/Analysis/LoopInfo.h"
27 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
28 #include "llvm/Analysis/TargetTransformInfo.h"
29 #include "llvm/Analysis/VectorUtils.h"
30 #include "llvm/IR/CFG.h"
31 #include "llvm/IR/DataLayout.h"
32 #include "llvm/IR/DebugInfoMetadata.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/Instructions.h"
36 #include "llvm/IR/IntrinsicInst.h"
37 #include "llvm/IR/MatrixBuilder.h"
38 #include "llvm/IR/PatternMatch.h"
39 #include "llvm/InitializePasses.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Alignment.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Transforms/Scalar.h"
45 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
46 #include "llvm/Transforms/Utils/LoopUtils.h"
47 #include "llvm/Transforms/Utils/MatrixUtils.h"
48 
49 using namespace llvm;
50 using namespace PatternMatch;
51 
52 #define DEBUG_TYPE "lower-matrix-intrinsics"
53 
54 static cl::opt<bool>
55  FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
56  cl::desc("Enable/disable fusing matrix instructions."));
57 // TODO: Allow and use non-square tiles.
59  "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
60  cl::desc(
61  "Tile size for matrix instruction fusion using square-shaped tiles."));
62 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
63  cl::Hidden,
64  cl::desc("Generate loop nest for tiling."));
66  "force-fuse-matrix", cl::init(false), cl::Hidden,
67  cl::desc("Force matrix instruction fusion even if not profitable."));
69  "matrix-allow-contract", cl::init(false), cl::Hidden,
70  cl::desc("Allow the use of FMAs if available and profitable. This may "
71  "result in different results, due to less rounding error."));
72 
73 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
74 
75 static cl::opt<MatrixLayoutTy> MatrixLayout(
76     "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
77     cl::desc("Sets the default matrix layout"),
78     cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
79                           "Use column-major layout"),
80                clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
81                           "Use row-major layout")));
82 
83 /// Helper function to either return \p Scope, if it is a subprogram, or the
84 /// subprogram attached to a local scope.
85 static DISubprogram *getSubprogram(DIScope *Scope) {
86   if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
87  return Subprogram;
88  return cast<DILocalScope>(Scope)->getSubprogram();
89 }
90 
91 namespace {
92 
93 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
94 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
95 // assuming \p Stride elements between the starts of two consecutive vectors.
96 // \p Stride must be >= \p NumElements.
97 // For column-major matrixes, the function computes the address of a column
98 // vector and \p NumElements must be set to the number of elements in a column
99 // (= number of rows of the matrix). For row-major matrixes, the function
100 // computes the address of a row vector and \p NumElements must be set to the
101 // number of elements in a row (= number of columns of the matrix).
102 //
103 // Consider a 4x4 matrix in column-major layout like below:
104 //
105 //        0      1      2      3
106 // 0   v_0_0  v_0_1  v_0_2  v_0_3
107 // 1   v_1_0  v_1_1  v_1_2  v_1_3
108 // 2   v_2_0  v_2_1  v_2_2  v_2_3
109 // 3   v_3_0  v_3_1  v_3_2  v_3_3
110 
111 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
112 // we need a pointer to the first element of the submatrix as base pointer.
113 // Then we can use computeVectorAddr to compute the addresses for the columns
114 // of the sub-matrix.
115 //
116 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
117 // -> just returns Base
118 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
119 // -> returns Base + (1 * 4)
120 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
121 // -> returns Base + (2 * 4)
122 //
123 // The graphic below illustrates the number of elements in a column (marked
124 // with |) and the number of skipped elements (marked with }).
125 //
126 //          v_0_0  v_0_1 {v_0_2 {v_0_3
127 //                 Base   Col 1  Col 2
128 //                   |      |      |
129 //          v_1_0 |v_1_1 |v_1_2 |v_1_3
130 //          v_2_0 |v_2_1 |v_2_2 |v_2_3
131 //          v_3_0 {v_3_1 {v_3_2  v_3_3
132 //
133 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
134  unsigned NumElements, Type *EltType,
135  IRBuilder<> &Builder) {
136 
137  assert((!isa<ConstantInt>(Stride) ||
138  cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
139  "Stride must be >= the number of elements in the result vector.");
140  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
141 
142  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
143  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
144 
145  // Get pointer to the start of the selected vector. Skip GEP creation,
146  // if we select vector 0.
147  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
148  VecStart = BasePtr;
149  else
150  VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
151 
152  // Cast elementwise vector start pointer to a pointer to a vector
153  // (EltType x NumElements)*.
154  auto *VecType = FixedVectorType::get(EltType, NumElements);
155  Type *VecPtrType = PointerType::get(VecType, AS);
156  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
157 }
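// Illustrative sketch (not part of the original source): for the 2x3 sub-matrix
// example above, requesting column 1 with element type double and a non-constant
// %stride produces IR roughly like:
//   %vec.start = mul i64 1, %stride
//   %vec.gep   = getelementptr double, double* %base, i64 %vec.start
//   %vec.cast  = bitcast double* %vec.gep to <2 x double>*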
158 
159 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
160 ///
161 /// Currently, the lowering for each matrix intrinsic is done as follows:
162 /// 1. Propagate the shape information from intrinsics to connected
163 /// instructions.
164 /// 2. Lower instructions with shape information (assuming column-major layout).
165 /// The lowering works similarly using row-major layout.
166 /// 2.1. Get column vectors for each argument. If we already lowered the
167 /// definition of an argument, use the produced column vectors directly.
168 /// If not, split the operand vector containing an embedded matrix into
169 ///      a set of column vectors.
170 /// 2.2. Lower the instruction in terms of column major operations, which
171 ///      yields a set of column vectors containing the result matrix. Note that we
172 /// lower all instructions that have shape information. Besides the
173 /// intrinsics, this includes stores for example.
174 /// 2.3. Update uses of the lowered instruction. If we have shape information
175 /// for a user, there is nothing to do, as we will look up the result
176 /// column matrix when lowering the user. For other uses, we embed the
177 /// result matrix in a flat vector and update the use.
178 /// 2.4. Cache the result column matrix for the instruction we lowered
179 /// 3. After we lowered all instructions in a function, remove the now
180 /// obsolete instructions.
181 ///
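// Illustrative sketch (not part of the original source): a 2x2 float multiply
// expressed with the matrix intrinsic; this pass rewrites such calls into plain
// shufflevector/extractelement/fmul/fadd (or fmuladd) instructions over the flat
// vector values:
//
//   %c = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(
//            <4 x float> %a, <4 x float> %b, i32 2, i32 2, i32 2)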
182 class LowerMatrixIntrinsics {
183  Function &Func;
184  const DataLayout &DL;
185  const TargetTransformInfo &TTI;
186  AliasAnalysis *AA;
187  DominatorTree *DT;
188  LoopInfo *LI;
189   OptimizationRemarkEmitter *ORE;
190 
191  /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
192  struct OpInfoTy {
193  /// Number of stores emitted to generate this matrix.
194  unsigned NumStores = 0;
195  /// Number of loads emitted to generate this matrix.
196  unsigned NumLoads = 0;
197  /// Number of compute operations emitted to generate this matrix.
198  unsigned NumComputeOps = 0;
199  /// Most of the time transposes can be fused with matrix multiplies or can
200  /// be folded away via algebraic simplifications. This is the number of
201  /// transposes that we failed to make "free" via such optimizations.
202  unsigned NumExposedTransposes = 0;
203 
204  OpInfoTy &operator+=(const OpInfoTy &RHS) {
205  NumStores += RHS.NumStores;
206  NumLoads += RHS.NumLoads;
207  NumComputeOps += RHS.NumComputeOps;
208  NumExposedTransposes += RHS.NumExposedTransposes;
209  return *this;
210  }
211  };
212 
213  /// Wrapper class representing a matrix as a set of vectors, either in row or
214  /// column major layout. All vectors must have the same vector type.
215  class MatrixTy {
216  SmallVector<Value *, 16> Vectors;
217 
218  OpInfoTy OpInfo;
219 
220  bool IsColumnMajor = true;
221 
222  public:
223  MatrixTy()
224  : Vectors(),
225  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
226  MatrixTy(ArrayRef<Value *> Vectors)
227  : Vectors(Vectors.begin(), Vectors.end()),
228  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
229  MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
230  : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
231 
232  unsigned D = isColumnMajor() ? NumColumns : NumRows;
233  for (unsigned J = 0; J < D; ++J)
234         addVector(UndefValue::get(FixedVectorType::get(
235             EltTy, isColumnMajor() ? NumRows : NumColumns)));
236  }
237 
238  Value *getVector(unsigned i) const { return Vectors[i]; }
239  Value *getColumn(unsigned i) const {
240  assert(isColumnMajor() && "only supported for column-major matrixes");
241  return Vectors[i];
242  }
243  Value *getRow(unsigned i) const {
244  assert(!isColumnMajor() && "only supported for row-major matrixes");
245  return Vectors[i];
246  }
247 
248  void setVector(unsigned i, Value *V) { Vectors[i] = V; }
249 
250  Type *getElementType() const { return getVectorTy()->getElementType(); }
251 
252  unsigned getNumVectors() const {
253  if (isColumnMajor())
254  return getNumColumns();
255  return getNumRows();
256  }
257 
258  unsigned getNumColumns() const {
259  if (isColumnMajor())
260  return Vectors.size();
261  else {
262         assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
263  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
264  }
265  }
266  unsigned getNumRows() const {
267  if (isColumnMajor()) {
268  assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
269  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
270  } else
271  return Vectors.size();
272  }
273 
274  void addVector(Value *V) { Vectors.push_back(V); }
275  VectorType *getColumnTy() {
276  assert(isColumnMajor() && "only supported for column-major matrixes");
277  return getVectorTy();
278  }
279 
280  VectorType *getVectorTy() const {
281  return cast<VectorType>(Vectors[0]->getType());
282  }
283 
284     iterator_range<SmallVector<Value *, 16>::iterator> columns() {
285       assert(isColumnMajor() &&
286  "columns() only supported for column-major matrixes");
287  return make_range(Vectors.begin(), Vectors.end());
288  }
289 
290     iterator_range<SmallVector<Value *, 16>::iterator> vectors() {
291       return make_range(Vectors.begin(), Vectors.end());
292  }
293 
294  /// Embed the vectors of the matrix into a flat vector by concatenating
295  /// them.
296  Value *embedInVector(IRBuilder<> &Builder) const {
297  return Vectors.size() == 1 ? Vectors[0]
298  : concatenateVectors(Builder, Vectors);
299  }
300 
301  MatrixTy &addNumLoads(unsigned N) {
302  OpInfo.NumLoads += N;
303  return *this;
304  }
305 
306  void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
307 
308  MatrixTy &addNumStores(unsigned N) {
309  OpInfo.NumStores += N;
310  return *this;
311  }
312 
313  MatrixTy &addNumExposedTransposes(unsigned N) {
314  OpInfo.NumExposedTransposes += N;
315  return *this;
316  }
317 
318  MatrixTy &addNumComputeOps(unsigned N) {
319  OpInfo.NumComputeOps += N;
320  return *this;
321  }
322 
323  unsigned getNumStores() const { return OpInfo.NumStores; }
324  unsigned getNumLoads() const { return OpInfo.NumLoads; }
325  unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
326 
327  const OpInfoTy &getOpInfo() const { return OpInfo; }
328 
329  bool isColumnMajor() const { return IsColumnMajor; }
330 
331  unsigned getStride() const {
332  if (isColumnMajor())
333  return getNumRows();
334  return getNumColumns();
335  }
336 
337  /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
338  /// matrix is column-major, the result vector is extracted from a column
339  /// vector, otherwise from a row vector.
340  Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
341  IRBuilder<> &Builder) const {
342  Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
343  return Builder.CreateShuffleVector(
344  Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
345  "block");
346  }
347  };
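  // For illustration (assumption, not original source text): with the default
  // column-major layout, a 4x2 matrix of floats is held as two <4 x float>
  // values, one per column, so getNumVectors() == 2, getNumRows() == 4 and
  // getStride() == 4.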
348 
349  struct ShapeInfo {
350  unsigned NumRows;
351  unsigned NumColumns;
352 
353  bool IsColumnMajor;
354 
355  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
356  : NumRows(NumRows), NumColumns(NumColumns),
357  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
358 
359  ShapeInfo(Value *NumRows, Value *NumColumns)
360  : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
361  cast<ConstantInt>(NumColumns)->getZExtValue()) {}
362 
363  bool operator==(const ShapeInfo &other) {
364  return NumRows == other.NumRows && NumColumns == other.NumColumns;
365  }
366  bool operator!=(const ShapeInfo &other) { return !(*this == other); }
367 
368  /// Returns true if shape-information is defined, meaning both dimensions
369  /// are != 0.
370  operator bool() const {
371  assert(NumRows == 0 || NumColumns != 0);
372  return NumRows != 0;
373  }
374 
375  unsigned getStride() const {
376  if (IsColumnMajor)
377  return NumRows;
378  return NumColumns;
379  }
380 
381  unsigned getNumVectors() const {
382  if (IsColumnMajor)
383  return NumColumns;
384  return NumRows;
385  }
386  };
387 
388  /// Maps instructions to their shape information. The shape information
389  /// describes the shape to be used while lowering. This matches the shape of
390  /// the result value of the instruction, with the only exceptions being store
391  /// instructions and the matrix_column_major_store intrinsics. For those, the
392  /// shape information indicates that those instructions should be lowered
393  /// using shape information as well. A ValueMap is used so that when
394   /// sub-passes like optimizeTransposes perform RAUW, the map stays
395  /// up-to-date.
396   ValueMap<Value *, ShapeInfo> ShapeMap;
397 
398   /// List of instructions to remove. While lowering, we are not replacing all
399   /// users of a lowered instruction, if shape information is available. Those
400   /// instructions need to be removed after we have finished lowering.
401   SmallVector<Instruction *, 16> ToRemove;
402 
403  /// Map from instructions to their produced column matrix.
404  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
405 
406 private:
407  static FastMathFlags getFastMathFlags(Instruction *Inst) {
408  FastMathFlags FMF;
409 
410  if (isa<FPMathOperator>(*Inst))
411  FMF = Inst->getFastMathFlags();
412 
413     FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
414 
415  return FMF;
416  }
417 
418 public:
419  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
420  AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
421                         OptimizationRemarkEmitter *ORE)
422       : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
423  LI(LI), ORE(ORE) {}
424 
425  unsigned getNumOps(Type *VT) {
426  assert(isa<VectorType>(VT) && "Expected vector type");
427  return getNumOps(VT->getScalarType(),
428  cast<FixedVectorType>(VT)->getNumElements());
429  }
430 
431  /// Is this the minimal version executed in the backend pipelines.
432  bool isMinimal() const {
433  return !DT;
434  }
435 
436  /// Return the estimated number of vector ops required for an operation on
437  /// \p VT * N.
438  unsigned getNumOps(Type *ST, unsigned N) {
439  return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
440  double(TTI.getRegisterBitWidth(
441                          TargetTransformInfo::RGK_FixedWidthVector)
442                      .getFixedSize()));
443  }
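  // Worked example (illustrative; the 128-bit register width is an assumption
  // about the target): a <8 x double> value occupies 8 * 64 = 512 bits, so with
  // 128-bit vector registers getNumOps returns ceil(512 / 128) = 4 vector ops.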
444 
445  /// Return the set of vectors that a matrix value is lowered to.
446  ///
447   /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
448  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
449  /// into vectors.
450  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
451  IRBuilder<> &Builder) {
452  VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
453  assert(VType && "MatrixVal must be a vector type");
454  assert(cast<FixedVectorType>(VType)->getNumElements() ==
455  SI.NumRows * SI.NumColumns &&
456  "The vector size must match the number of matrix elements");
457 
458  // Check if we lowered MatrixVal using shape information. In that case,
459  // return the existing matrix, if it matches the requested shape
460  // information. If there is a mis-match, embed the result in a flat
461  // vector and split it later.
462  auto Found = Inst2ColumnMatrix.find(MatrixVal);
463  if (Found != Inst2ColumnMatrix.end()) {
464  MatrixTy &M = Found->second;
465  // Return the found matrix, if its shape matches the requested shape
466  // information
467  if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
468  return M;
469 
470  MatrixVal = M.embedInVector(Builder);
471  }
472 
473  // Otherwise split MatrixVal.
474  SmallVector<Value *, 16> SplitVecs;
475  for (unsigned MaskStart = 0;
476  MaskStart < cast<FixedVectorType>(VType)->getNumElements();
477  MaskStart += SI.getStride()) {
478  Value *V = Builder.CreateShuffleVector(
479  MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
480  "split");
481  SplitVecs.push_back(V);
482  }
483 
484  return {SplitVecs};
485  }
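  // Illustrative sketch (names are hypothetical, not from the original source):
  // a flat <6 x double> holding a column-major 3x2 matrix is split into its two
  // columns with sequential shuffle masks <0, 1, 2> and <3, 4, 5>, yielding two
  // <3 x double> values.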
486 
487  /// If \p V already has a known shape return false. Otherwise set the shape
488  /// for instructions that support it.
489  bool setShapeInfo(Value *V, ShapeInfo Shape) {
490  assert(Shape && "Shape not set");
491  if (isa<UndefValue>(V) || !supportsShapeInfo(V))
492  return false;
493 
494  auto SIter = ShapeMap.find(V);
495  if (SIter != ShapeMap.end()) {
496  LLVM_DEBUG(dbgs() << " not overriding existing shape: "
497  << SIter->second.NumRows << " "
498  << SIter->second.NumColumns << " for " << *V << "\n");
499  return false;
500  }
501 
502  ShapeMap.insert({V, Shape});
503  LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
504  << " for " << *V << "\n");
505  return true;
506  }
507 
508  bool isUniformShape(Value *V) {
509  Instruction *I = dyn_cast<Instruction>(V);
510  if (!I)
511  return true;
512 
513  switch (I->getOpcode()) {
514  case Instruction::FAdd:
515  case Instruction::FSub:
516  case Instruction::FMul: // Scalar multiply.
517  case Instruction::FNeg:
518  case Instruction::Add:
519  case Instruction::Mul:
520  case Instruction::Sub:
521  return true;
522  default:
523  return false;
524  }
525  }
526 
527  /// Returns true if shape information can be used for \p V. The supported
528  /// instructions must match the instructions that can be lowered by this pass.
529  bool supportsShapeInfo(Value *V) {
530  Instruction *Inst = dyn_cast<Instruction>(V);
531  if (!Inst)
532  return false;
533 
534  IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
535  if (II)
536  switch (II->getIntrinsicID()) {
537  case Intrinsic::matrix_multiply:
538  case Intrinsic::matrix_transpose:
539  case Intrinsic::matrix_column_major_load:
540  case Intrinsic::matrix_column_major_store:
541  return true;
542  default:
543  return false;
544  }
545  return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
546  }
547 
548  /// Propagate the shape information of instructions to their users.
549  /// The work list contains instructions for which we can compute the shape,
550  /// either based on the information provided by matrix intrinsics or known
551  /// shapes of operands.
552   SmallVector<Instruction *, 32>
553   propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
554  SmallVector<Instruction *, 32> NewWorkList;
555     // Pop an element for which we are guaranteed to have at least one of the
556  // operand shapes. Add the shape for this and then add users to the work
557  // list.
558  LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
559  while (!WorkList.empty()) {
560  Instruction *Inst = WorkList.pop_back_val();
561 
562  // New entry, set the value and insert operands
563  bool Propagate = false;
564 
565  Value *MatrixA;
566  Value *MatrixB;
567  Value *M;
568  Value *N;
569  Value *K;
570  if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
571  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
572  m_Value(N), m_Value(K)))) {
573  Propagate = setShapeInfo(Inst, {M, K});
574  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
575  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
576  // Flip dimensions.
577  Propagate = setShapeInfo(Inst, {N, M});
578  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
579  m_Value(MatrixA), m_Value(), m_Value(),
580  m_Value(), m_Value(M), m_Value(N)))) {
581  Propagate = setShapeInfo(Inst, {N, M});
582  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
583  m_Value(), m_Value(), m_Value(), m_Value(M),
584  m_Value(N)))) {
585  Propagate = setShapeInfo(Inst, {M, N});
586  } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
587  auto OpShape = ShapeMap.find(MatrixA);
588  if (OpShape != ShapeMap.end())
589  setShapeInfo(Inst, OpShape->second);
590  continue;
591  } else if (isUniformShape(Inst)) {
592  // Find the first operand that has a known shape and use that.
593  for (auto &Op : Inst->operands()) {
594  auto OpShape = ShapeMap.find(Op.get());
595  if (OpShape != ShapeMap.end()) {
596  Propagate |= setShapeInfo(Inst, OpShape->second);
597  break;
598  }
599  }
600  }
601 
602  if (Propagate) {
603  NewWorkList.push_back(Inst);
604  for (auto *User : Inst->users())
605  if (ShapeMap.count(User) == 0)
606  WorkList.push_back(cast<Instruction>(User));
607  }
608  }
609 
610  return NewWorkList;
611  }
612 
613  /// Propagate the shape to operands of instructions with shape information.
614   /// \p WorkList contains the instructions for which we already know the shape.
615   SmallVector<Instruction *, 32>
616   propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
617  SmallVector<Instruction *, 32> NewWorkList;
618 
619  auto pushInstruction = [](Value *V,
620  SmallVectorImpl<Instruction *> &WorkList) {
621  Instruction *I = dyn_cast<Instruction>(V);
622  if (I)
623  WorkList.push_back(I);
624  };
625     // Pop an element with known shape. Traverse the operands, if their shape
626     // derives from the result shape and is unknown, set it and add the operands
627     // to the worklist.
628  LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
629  while (!WorkList.empty()) {
630  Value *V = WorkList.pop_back_val();
631 
632  size_t BeforeProcessingV = WorkList.size();
633  if (!isa<Instruction>(V))
634  continue;
635 
636  Value *MatrixA;
637  Value *MatrixB;
638  Value *M;
639  Value *N;
640  Value *K;
641  if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
642  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
643  m_Value(N), m_Value(K)))) {
644  if (setShapeInfo(MatrixA, {M, N}))
645  pushInstruction(MatrixA, WorkList);
646 
647  if (setShapeInfo(MatrixB, {N, K}))
648  pushInstruction(MatrixB, WorkList);
649 
650  } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
651  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
652  // Flip dimensions.
653  if (setShapeInfo(MatrixA, {M, N}))
654  pushInstruction(MatrixA, WorkList);
655  } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
656  m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
657  m_Value(M), m_Value(N)))) {
658  if (setShapeInfo(MatrixA, {M, N})) {
659  pushInstruction(MatrixA, WorkList);
660  }
661  } else if (isa<LoadInst>(V) ||
662  match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
663  // Nothing to do, no matrix input.
664  } else if (isa<StoreInst>(V)) {
665  // Nothing to do. We forward-propagated to this so we would just
666  // backward propagate to an instruction with an already known shape.
667  } else if (isUniformShape(V)) {
668  // Propagate to all operands.
669  ShapeInfo Shape = ShapeMap[V];
670  for (Use &U : cast<Instruction>(V)->operands()) {
671  if (setShapeInfo(U.get(), Shape))
672  pushInstruction(U.get(), WorkList);
673  }
674  }
675  // After we discovered new shape info for new instructions in the
676  // worklist, we use their users as seeds for the next round of forward
677  // propagation.
678  for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
679  for (User *U : WorkList[I]->users())
680  if (isa<Instruction>(U) && V != U)
681  NewWorkList.push_back(cast<Instruction>(U));
682  }
683  return NewWorkList;
684  }
685 
686  /// Try moving transposes in order to fold them away or into multiplies.
687  void optimizeTransposes() {
688  auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
689  // We need to remove Old from the ShapeMap otherwise RAUW will replace it
690       // with New. We should only add New if it supportsShapeInfo, so we insert
691  // it conditionally instead.
692  auto S = ShapeMap.find(&Old);
693  if (S != ShapeMap.end()) {
694  ShapeMap.erase(S);
695  if (supportsShapeInfo(New))
696  ShapeMap.insert({New, S->second});
697  }
698  Old.replaceAllUsesWith(New);
699  };
700 
701  // First sink all transposes inside matmuls, hoping that we end up with NN,
702  // NT or TN variants.
703  for (BasicBlock &BB : reverse(Func)) {
704  for (auto II = BB.rbegin(); II != BB.rend();) {
705  Instruction &I = *II;
706  // We may remove II. By default continue on the next/prev instruction.
707  ++II;
708  // If we were to erase II, move again.
709  auto EraseFromParent = [&II](Value *V) {
710  auto *Inst = cast<Instruction>(V);
711  if (Inst->use_empty()) {
712  if (Inst == &*II) {
713  ++II;
714  }
715  Inst->eraseFromParent();
716  }
717  };
718 
719  // If we're creating a new instruction, continue from there.
720  Instruction *NewInst = nullptr;
721 
722  IRBuilder<> IB(&I);
723         MatrixBuilder<IRBuilder<>> Builder(IB);
724 
725  Value *TA, *TAMA, *TAMB;
726  ConstantInt *R, *K, *C;
727  if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
728 
729  // Transpose of a transpose is a nop
730  Value *TATA;
731  if (match(TA,
732  m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
733  ReplaceAllUsesWith(I, TATA);
734  EraseFromParent(&I);
735  EraseFromParent(TA);
736  }
737 
738         // (A * B)^t -> B^t * A^t
739         //  RxK  KxC    CxK  KxR
740  else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
741  m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
742  m_ConstantInt(K), m_ConstantInt(C)))) {
743  Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
744  C->getZExtValue(),
745  TAMB->getName() + "_t");
746  // We are being run after shape prop, add shape for newly created
747  // instructions so that we lower them later.
748  setShapeInfo(T0, {C, K});
749  Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
750  K->getZExtValue(),
751  TAMA->getName() + "_t");
752  setShapeInfo(T1, {K, R});
753  NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
754  K->getZExtValue(),
755  R->getZExtValue(), "mmul");
756  ReplaceAllUsesWith(I, NewInst);
757  EraseFromParent(&I);
758  EraseFromParent(TA);
759  }
760  }
761 
762  // If we replaced I with a new instruction, continue from there.
763  if (NewInst)
764  II = std::next(BasicBlock::reverse_iterator(NewInst));
765  }
766  }
767 
768     // If we have a TT matmul, lift the transpose. We may be able to fold it
769     // into a consuming multiply.
770  for (BasicBlock &BB : Func) {
771  for (BasicBlock::iterator II = BB.begin(); II != BB.end();) {
772  Instruction *I = &*II;
773  // We may remove I.
774  ++II;
775  Value *A, *B, *AT, *BT;
776  ConstantInt *R, *K, *C;
777         // A^t * B^t -> (B * A)^t
778         //  RxK  KxC     CxK  KxR
779  m_Value(A), m_Value(B), m_ConstantInt(R),
780  m_ConstantInt(K), m_ConstantInt(C))) &&
781  match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
782  match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
783  IRBuilder<> IB(&*I);
784           MatrixBuilder<IRBuilder<>> Builder(IB);
785           Value *M = Builder.CreateMatrixMultiply(
786  BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
787  setShapeInfo(M, {C, R});
788  Instruction *NewInst = Builder.CreateMatrixTranspose(
789  M, C->getZExtValue(), R->getZExtValue());
790  ReplaceAllUsesWith(*I, NewInst);
791  if (I->use_empty())
792  I->eraseFromParent();
793  if (A->use_empty())
794  cast<Instruction>(A)->eraseFromParent();
795  if (A != B && B->use_empty())
796  cast<Instruction>(B)->eraseFromParent();
797  }
798  }
799  }
800  }
801 
802  bool Visit() {
803     SmallVector<Instruction *, 32> WorkList;
804 
805  // Initially only the shape of matrix intrinsics is known.
806  // Initialize the work list with ops carrying shape information.
807  for (BasicBlock &BB : Func)
808  for (Instruction &Inst : BB) {
809  IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
810  if (!II)
811  continue;
812 
813  switch (II->getIntrinsicID()) {
814  case Intrinsic::matrix_multiply:
815  case Intrinsic::matrix_transpose:
816  case Intrinsic::matrix_column_major_load:
817  case Intrinsic::matrix_column_major_store:
818  WorkList.push_back(&Inst);
819  break;
820  default:
821  break;
822  }
823  }
824 
825  // Avoid unnecessary work if there are no matrix intrinsics in the function.
826  if (WorkList.empty())
827  return false;
828 
829  // Propagate shapes until nothing changes any longer.
830  while (!WorkList.empty()) {
831  WorkList = propagateShapeForward(WorkList);
832  WorkList = propagateShapeBackward(WorkList);
833  }
834 
835  if (!isMinimal()) {
836  optimizeTransposes();
837  LLVM_DEBUG({
838  dbgs() << "Dump after matrix transpose optimization:\n";
839  Func.dump();
840  });
841  }
842 
843  bool Changed = false;
844  SmallVector<CallInst *, 16> MaybeFusableInsts;
845  SmallVector<Instruction *, 16> MatrixInsts;
846 
847  // First, collect all instructions with shape information and candidates for
848  // fusion (currently only matrix multiplies).
849     ReversePostOrderTraversal<Function *> RPOT(&Func);
850     for (auto *BB : RPOT)
851  for (Instruction &I : *BB) {
852  if (ShapeMap.find(&I) == ShapeMap.end())
853  continue;
854  if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
855  MaybeFusableInsts.push_back(cast<CallInst>(&I));
856  MatrixInsts.push_back(&I);
857  }
858 
859  // Second, try to fuse candidates.
860     SmallPtrSet<Instruction *, 16> FusedInsts;
861     for (CallInst *CI : MaybeFusableInsts)
862  LowerMatrixMultiplyFused(CI, FusedInsts);
863  Changed = !FusedInsts.empty();
864 
865  // Third, lower remaining instructions with shape information.
866  for (Instruction *Inst : MatrixInsts) {
867  if (FusedInsts.count(Inst))
868  continue;
869 
870  IRBuilder<> Builder(Inst);
871 
872  if (CallInst *CInst = dyn_cast<CallInst>(Inst))
873  Changed |= VisitCallInst(CInst);
874 
875  Value *Op1;
876  Value *Op2;
877  if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
878  Changed |= VisitBinaryOperator(BinOp);
879  if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
880  Changed |= VisitUnaryOperator(UnOp);
881  if (match(Inst, m_Load(m_Value(Op1))))
882  Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
883  else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
884  Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
885  }
886 
887  if (ORE) {
888  RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
889  RemarkGen.emitRemarks();
890  }
891 
892     // Delete the instructions backwards, as this reduces the likelihood of
893     // having to update as many def-use and use-def chains.
894  //
895  // Because we add to ToRemove during fusion we can't guarantee that defs
896  // are before uses. Change uses to undef temporarily as these should get
897  // removed as well.
898  //
899  // For verification, we keep track of where we changed uses to undefs in
900  // UndefedInsts and then check that we in fact remove them.
901  SmallSet<Instruction *, 16> UndefedInsts;
902  for (auto *Inst : reverse(ToRemove)) {
903  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
904  if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
905  UndefedInsts.insert(Undefed);
906  U.set(UndefValue::get(Inst->getType()));
907  }
908  Inst->eraseFromParent();
909  UndefedInsts.erase(Inst);
910  }
911  if (!UndefedInsts.empty()) {
912  // If we didn't remove all undefed instructions, it's a hard error.
913  dbgs() << "Undefed but present instructions:\n";
914  for (auto *I : UndefedInsts)
915  dbgs() << *I << "\n";
916  llvm_unreachable("Undefed but instruction not removed");
917  }
918 
919  return Changed;
920  }
921 
922  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
923  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
924  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
925  Type *EltPtrType = PointerType::get(EltType, AS);
926  return Builder.CreatePointerCast(BasePtr, EltPtrType);
927  }
928 
929  /// Replace intrinsic calls
930  bool VisitCallInst(CallInst *Inst) {
931  if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
932  return false;
933 
934  switch (Inst->getCalledFunction()->getIntrinsicID()) {
935  case Intrinsic::matrix_multiply:
936  LowerMultiply(Inst);
937  break;
938  case Intrinsic::matrix_transpose:
939  LowerTranspose(Inst);
940  break;
941  case Intrinsic::matrix_column_major_load:
942  LowerColumnMajorLoad(Inst);
943  break;
944  case Intrinsic::matrix_column_major_store:
945  LowerColumnMajorStore(Inst);
946  break;
947  default:
948  return false;
949  }
950  return true;
951  }
952 
953  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
954  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
955  /// ConstantInt, reduce the initial alignment based on the byte offset. For
956  /// non-ConstantInt strides, return the common alignment of the initial
957  /// alignment and the element size in bytes.
958  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
959  MaybeAlign A) const {
960  Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
961  if (Idx == 0)
962  return InitialAlign;
963 
964  TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
965  if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
966  uint64_t StrideInBytes =
967  ConstStride->getZExtValue() * ElementSizeInBits / 8;
968  return commonAlignment(InitialAlign, Idx * StrideInBytes);
969  }
970  return commonAlignment(InitialAlign, ElementSizeInBits / 8);
971  }
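  // Worked example (illustrative): for double elements (8 bytes), a constant
  // stride of 3 and an initial alignment of 16, column Idx = 1 starts 24 bytes
  // into the matrix, so the returned alignment is commonAlignment(16, 24) = 8.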
972 
973  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
974  /// vectors.
975  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
976  bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
977  auto *VType = cast<VectorType>(Ty);
978  Type *EltTy = VType->getElementType();
979  Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
980  Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
981  MatrixTy Result;
982  for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
983  Value *GEP = computeVectorAddr(
984  EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
985  Stride, Shape.getStride(), EltTy, Builder);
986  Value *Vector = Builder.CreateAlignedLoad(
987  VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
988  IsVolatile, "col.load");
989 
990  Result.addVector(Vector);
991  }
992  return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
993  Result.getNumVectors());
994  }
995 
996   /// Loads a sub-matrix with shape \p ResultShape from a matrix with shape
997   /// \p MatrixShape, starting at \p MatrixPtr[I][J].
998  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
999  ShapeInfo MatrixShape, Value *I, Value *J,
1000  ShapeInfo ResultShape, Type *EltTy,
1001  IRBuilder<> &Builder) {
1002 
1003  Value *Offset = Builder.CreateAdd(
1004  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1005 
1006  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1007  Value *EltPtr =
1008  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1009  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1010  auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1011  ResultShape.NumColumns);
1012  Type *TilePtrTy = PointerType::get(TileTy, AS);
1013  Value *TilePtr =
1014  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1015 
1016  return loadMatrix(TileTy, TilePtr, Align,
1017  Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1018  ResultShape, Builder);
1019  }
1020 
1021  /// Lower a load instruction with shape information.
1022  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1023  bool IsVolatile, ShapeInfo Shape) {
1024  IRBuilder<> Builder(Inst);
1025  finalizeLowering(Inst,
1026  loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1027  Shape, Builder),
1028  Builder);
1029  }
1030 
1031  /// Lowers llvm.matrix.column.major.load.
1032  ///
1033  /// The intrinsic loads a matrix from memory using a stride between columns.
1034  void LowerColumnMajorLoad(CallInst *Inst) {
1036  "Intrinsic only supports column-major layout!");
1037  Value *Ptr = Inst->getArgOperand(0);
1038  Value *Stride = Inst->getArgOperand(1);
1039  LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1040  cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1041  {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1042  }
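  // Illustrative sketch (the exact name mangling is an assumption): loading a
  // 3x2 double matrix with a dynamic stride looks roughly like
  //   %m = call <6 x double> @llvm.matrix.column.major.load.v6f64.i64(
  //            double* %ptr, i64 %stride, i1 false, i32 3, i32 2)
  // which LowerLoad turns into two strided <3 x double> loads.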
1043 
1044   /// Stores a sub-matrix \p StoreVal into the matrix with shape \p MatrixShape,
1045   /// starting at \p MatrixPtr[I][J].
1046  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1047  MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1048  Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1049  Value *Offset = Builder.CreateAdd(
1050  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1051 
1052  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1053  Value *EltPtr =
1054  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1055  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1056  auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1057  StoreVal.getNumColumns());
1058  Type *TilePtrTy = PointerType::get(TileTy, AS);
1059  Value *TilePtr =
1060  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1061 
1062  storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
1063  Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1064  }
1065 
1066  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1067  /// vectors.
1068  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1069  MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1070  IRBuilder<> &Builder) {
1071  auto VType = cast<VectorType>(Ty);
1072  Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
1073  for (auto Vec : enumerate(StoreVal.vectors())) {
1074  Value *GEP = computeVectorAddr(
1075  EltPtr,
1076  Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1077  Vec.index()),
1078  Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1079  Builder.CreateAlignedStore(Vec.value(), GEP,
1080  getAlignForIndex(Vec.index(), Stride,
1081  VType->getElementType(),
1082  MAlign),
1083  IsVolatile);
1084  }
1085  return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1086  StoreVal.getNumVectors());
1087  }
1088 
1089  /// Lower a store instruction with shape information.
1090  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1091  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1092  IRBuilder<> Builder(Inst);
1093  auto StoreVal = getMatrix(Matrix, Shape, Builder);
1094  finalizeLowering(Inst,
1095  storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1096  IsVolatile, Builder),
1097  Builder);
1098  }
1099 
1100  /// Lowers llvm.matrix.column.major.store.
1101  ///
1102   /// The intrinsic stores a matrix back to memory, using a stride between columns.
1103  void LowerColumnMajorStore(CallInst *Inst) {
1105  "Intrinsic only supports column-major layout!");
1106  Value *Matrix = Inst->getArgOperand(0);
1107  Value *Ptr = Inst->getArgOperand(1);
1108  Value *Stride = Inst->getArgOperand(2);
1109  LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1110  cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1111  {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1112  }
1113 
1114   // Set elements I..I+BlockNumElts-1 of Col to Block.
1115  Value *insertVector(Value *Col, unsigned I, Value *Block,
1116  IRBuilder<> &Builder) {
1117 
1118  // First, bring Block to the same size as Col
1119  unsigned BlockNumElts =
1120  cast<FixedVectorType>(Block->getType())->getNumElements();
1121  unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1122  assert(NumElts >= BlockNumElts && "Too few elements for current block");
1123 
1124  Block = Builder.CreateShuffleVector(
1125  Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1126 
1127  // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1128  // 8, 4, 5, 6
1129     SmallVector<int, 16> Mask;
1130     unsigned i;
1131  for (i = 0; i < I; i++)
1132  Mask.push_back(i);
1133 
1134  unsigned VecNumElts =
1135  cast<FixedVectorType>(Col->getType())->getNumElements();
1136  for (; i < I + BlockNumElts; i++)
1137  Mask.push_back(i - I + VecNumElts);
1138 
1139  for (; i < VecNumElts; i++)
1140  Mask.push_back(i);
1141 
1142  return Builder.CreateShuffleVector(Col, Block, Mask);
1143  }
1144 
1145  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1146  IRBuilder<> &Builder, bool AllowContraction,
1147  unsigned &NumComputeOps) {
1148  NumComputeOps += getNumOps(A->getType());
1149  if (!Sum)
1150  return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1151 
1152  if (UseFPOp) {
1153  if (AllowContraction) {
1154  // Use fmuladd for floating point operations and let the backend decide
1155  // if that's profitable.
1156         Function *FMulAdd = Intrinsic::getDeclaration(
1157             Func.getParent(), Intrinsic::fmuladd, A->getType());
1158  return Builder.CreateCall(FMulAdd, {A, B, Sum});
1159  }
1160  NumComputeOps += getNumOps(A->getType());
1161  Value *Mul = Builder.CreateFMul(A, B);
1162  return Builder.CreateFAdd(Sum, Mul);
1163  }
1164 
1165  NumComputeOps += getNumOps(A->getType());
1166  Value *Mul = Builder.CreateMul(A, B);
1167  return Builder.CreateAdd(Sum, Mul);
1168  }
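  // Illustrative sketch (not from the original source): with contraction allowed
  // and float operands, the accumulation step becomes a single call such as
  //   %sum.next = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a,
  //                                                    <4 x float> %b,
  //                                                    <4 x float> %sum)
  // instead of a separate fmul followed by fadd.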
1169 
1170  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1171  /// users with shape information, there's nothing to do: they will use the
1172  /// cached value when they are lowered. For other users, \p Matrix is
1173  /// flattened and the uses are updated to use it. Also marks \p Inst for
1174  /// deletion.
1175  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1176  IRBuilder<> &Builder) {
1177  auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1178  (void)inserted;
1179  assert(inserted.second && "multiple matrix lowering mapping");
1180 
1181  ToRemove.push_back(Inst);
1182  Value *Flattened = nullptr;
1183  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1184  if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1185  if (!Flattened)
1186  Flattened = Matrix.embedInVector(Builder);
1187  U.set(Flattened);
1188  }
1189  }
1190  }
1191 
1192  /// Compute \p Result += \p A * \p B for input matrices with left-associating
1193  /// addition.
1194  ///
1195  /// We can fold a transpose into the operand that is used to extract scalars.
1196   /// This is the first operand with row-major layout and the second with
1197  /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1198  /// operand is transposed.
1199  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1200  const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1201  bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1202  const unsigned VF = std::max<unsigned>(
1203         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1204             .getFixedSize() /
1205  Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
1206  1U);
1207  unsigned R = Result.getNumRows();
1208  unsigned C = Result.getNumColumns();
1209  unsigned M = A.getNumColumns();
1210 
1211  bool IsFP = Result.getElementType()->isFloatingPointTy();
1212  assert(A.isColumnMajor() == B.isColumnMajor() &&
1213  Result.isColumnMajor() == A.isColumnMajor() &&
1214  "operands must agree on matrix layout");
1215  unsigned NumComputeOps = 0;
1216 
1217  Builder.setFastMathFlags(FMF);
1218 
1219  if (A.isColumnMajor()) {
1220  // Multiply columns from the first operand with scalars from the second
1221  // operand. Then move along the K axes and accumulate the columns. With
1222  // this the adds can be vectorized without reassociation.
1223  for (unsigned J = 0; J < C; ++J) {
1224  unsigned BlockSize = VF;
1225  // If Result is zero, we don't need to accumulate in the K==0 iteration.
1226  bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1227 
1228  for (unsigned I = 0; I < R; I += BlockSize) {
1229  // Gradually lower the vectorization factor to cover the remainder.
1230  while (I + BlockSize > R)
1231  BlockSize /= 2;
1232 
1233  Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1234  : nullptr;
1235  for (unsigned K = 0; K < M; ++K) {
1236  Value *L = A.extractVector(I, K, BlockSize, Builder);
1237  Value *RH = Builder.CreateExtractElement(
1238  B.getColumn(IsScalarMatrixTransposed ? K : J),
1239  IsScalarMatrixTransposed ? J : K);
1240  Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1241  Sum =
1242  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1243  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1244  }
1245  Result.setVector(J,
1246  insertVector(Result.getVector(J), I, Sum, Builder));
1247  }
1248  }
1249  } else {
1250  // Multiply rows from the second operand with scalars from the first
1251  // operand. Then move along the K axes and accumulate the rows. With this
1252  // the adds can be vectorized without reassociation.
1253  for (unsigned I = 0; I < R; ++I) {
1254  unsigned BlockSize = VF;
1255  bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1256  for (unsigned J = 0; J < C; J += BlockSize) {
1257  // Gradually lower the vectorization factor to cover the remainder.
1258  while (J + BlockSize > C)
1259  BlockSize /= 2;
1260 
1261  Value *Sum = nullptr;
1262  for (unsigned K = 0; K < M; ++K) {
1263  Value *R = B.extractVector(K, J, BlockSize, Builder);
1264  Value *LH = Builder.CreateExtractElement(
1265  A.getVector(IsScalarMatrixTransposed ? K : I),
1266  IsScalarMatrixTransposed ? I : K);
1267  Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1268  Sum =
1269  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1270  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1271  }
1272  Result.setVector(I,
1273  insertVector(Result.getVector(I), J, Sum, Builder));
1274  }
1275  }
1276  }
1277  Result.addNumComputeOps(NumComputeOps);
1278  }
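  // For illustration (assumption, not original source text): in the column-major
  // case with a 2x2 * 2x2 product, column J of the result is built as
  //   Result.col(J) = A.col(0) * splat(B[0][J]) + A.col(1) * splat(B[1][J])
  // i.e. each scalar of B's column J is splatted and multiplied with a whole
  // column of A, so the adds vectorize without reassociation.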
1279 
1280  /// Ensure that the memory in \p Load does not alias \p Store by potentially
1281   /// copying it to a new location. The new location, or otherwise the original
1282   /// one, is returned.
1283  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1284  CallInst *MatMul) {
1285     MemoryLocation StoreLoc = MemoryLocation::get(Store);
1286     MemoryLocation LoadLoc = MemoryLocation::get(Load);
1287 
1288  // If we can statically determine noalias we're good.
1289  if (AA->isNoAlias(LoadLoc, StoreLoc))
1290  return Load->getPointerOperand();
1291 
1292  // Create code to check if the memory locations of the Load and Store
1293  // overlap and if they do, copy Load's operand to a new buffer.
1294 
1295     // First, create new blocks for the two parts of the check and the copy.
1296  BasicBlock *Check0 = MatMul->getParent();
1297  // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1298  // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1299  // as we adjust Check0 and Check1's branches.
1300     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1301     for (BasicBlock *Succ : successors(Check0))
1302  DTUpdates.push_back({DT->Delete, Check0, Succ});
1303 
1304  BasicBlock *Check1 =
1305  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1306  nullptr, "alias_cont");
1307  BasicBlock *Copy =
1308  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1309  nullptr, "copy");
1310  BasicBlock *Fusion =
1311  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1312  nullptr, "no_alias");
1313 
1314  // Check if the loaded memory location begins before the end of the store
1315  // location. If the condition holds, they might overlap, otherwise they are
1316  // guaranteed to not overlap.
1317  IRBuilder<> Builder(MatMul);
1318  Check0->getTerminator()->eraseFromParent();
1319  Builder.SetInsertPoint(Check0);
1320  Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1321  Value *StoreBegin = Builder.CreatePtrToInt(
1322  const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1323  Value *StoreEnd = Builder.CreateAdd(
1324  StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1325  "store.end", true, true);
1326  Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1327  IntPtrTy, "load.begin");
1328  Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1329  Fusion);
1330 
1331  // Check if the store begins before the end of the load location. If the
1332  // condition holds, they alias, otherwise they are guaranteed to not
1333  // overlap.
1334  Check1->getTerminator()->eraseFromParent();
1335  Builder.SetInsertPoint(Check1, Check1->begin());
1336  Value *LoadEnd = Builder.CreateAdd(
1337  LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1338  "load.end", true, true);
1339  Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1340  Fusion);
1341 
1342  // Copy load operand to new alloca.
1343  Builder.SetInsertPoint(Copy, Copy->begin());
1344  AllocaInst *NewLd =
1345  Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
1346  Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
1347  Load->getPointerOperand(), Load->getAlign(),
1348  LoadLoc.Size.getValue());
1349  Builder.SetInsertPoint(Fusion, Fusion->begin());
1350  PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1351  PHI->addIncoming(Load->getPointerOperand(), Check0);
1352  PHI->addIncoming(Load->getPointerOperand(), Check1);
1353  PHI->addIncoming(NewLd, Copy);
1354 
1355  // Adjust DT.
1356  DTUpdates.push_back({DT->Insert, Check0, Check1});
1357  DTUpdates.push_back({DT->Insert, Check0, Fusion});
1358  DTUpdates.push_back({DT->Insert, Check1, Copy});
1359  DTUpdates.push_back({DT->Insert, Check1, Fusion});
1360  DT->applyUpdates(DTUpdates);
1361  return PHI;
1362  }
1363 
1364  bool isFusionProfitable(CallInst *MatMul) {
1365  if (ForceFusion)
1366  return true;
1367 
1368  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1369  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1370 
1371  const unsigned R = LShape.NumRows;
1372  const unsigned C = RShape.NumColumns;
1373  const unsigned M = LShape.NumColumns;
1374  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1375 
1376  const unsigned VF = std::max<unsigned>(
1377         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1378             .getFixedSize() /
1379  EltType->getPrimitiveSizeInBits().getFixedSize(),
1380  1U);
1381 
1382  // Cost model for tiling
1383  //
1384  // For tiling to be beneficial, we need reuse either along the R or
1385  // the C axis. We vectorize along the R axis so that means at least
1386  // 3 elements.
1387  // TODO: Also consider cost of copying if operands alias.
1388  if (R <= VF && C == 1)
1389  return false;
1390  // Then we need enough elements to exceed the number of vector
1391  // registers we have. Note that this is an oversimplification since
1392  // fusing also takes some extra loads which may exceed the number of
1393  // reloads necessary.
1394  unsigned Op0Regs = (R + VF - 1) / VF * M;
1395  unsigned Op1Regs = (M + VF - 1) / VF * C;
1396  return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true);
1397  }
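  // Worked example (illustrative; the register count and 128-bit width are
  // assumptions about the target): for R = C = M = 8 with double elements,
  // VF = 2, so Op0Regs = ceil(8 / 2) * 8 = 32 and Op1Regs = 32. With 32 vector
  // registers, 32 + 32 > 32 holds and fusion is considered profitable.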
1398 
1399  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1400  MatrixTy Res;
1401  auto *ColumType = FixedVectorType::get(EltType, R);
1402  for (unsigned I = 0; I < C; ++I)
1403  Res.addVector(ConstantAggregateZero::get(ColumType));
1404  return Res;
1405  }
1406 
1407  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1408  Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1409  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1410 
1411  // Create the main tiling loop nest.
1412  TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1413     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1414     Instruction *InsertI = cast<Instruction>(MatMul);
1415  BasicBlock *Start = InsertI->getParent();
1416  BasicBlock *End =
1417  SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1418  IRBuilder<> Builder(MatMul);
1419  BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1420 
1421  Type *TileVecTy =
1422         FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1423     MatrixTy TileResult;
1424  // Insert in the inner loop header.
1425  Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
1426  // Create PHI nodes for the result columns to accumulate across iterations.
1427  SmallVector<PHINode *, 4> ColumnPhis;
1428  for (unsigned I = 0; I < TileSize; I++) {
1429  auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1430  Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1431  TI.RowLoopHeader->getSingleSuccessor());
1432  TileResult.addVector(Phi);
1433  ColumnPhis.push_back(Phi);
1434  }
1435 
1436  // Insert in the inner loop body, which computes
1437  // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1438  Builder.SetInsertPoint(InnerBody->getTerminator());
1439  // Load tiles of the operands.
1440  MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
1441  {TileSize, TileSize}, EltType, Builder);
1442  MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
1443  {TileSize, TileSize}, EltType, Builder);
1444  emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1445  getFastMathFlags(MatMul));
1446  // Store result after the inner loop is done.
1447  Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
1448  storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1449  Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1450  TI.CurrentRow, TI.CurrentCol, EltType, Builder);
1451 
1452  for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1453  ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);
1454 
1455  // Force unrolling of a few iterations of the inner loop, to make sure there
1456  // is enough work per iteration.
1457  // FIXME: The unroller should make this decision directly instead, but
1458  // currently the cost-model is not up to the task.
1459  unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1460  addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
1461  "llvm.loop.unroll.count", InnerLoopUnrollCount);
1462  }
1463 
1464  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1465  StoreInst *Store,
1466  SmallPtrSetImpl<Instruction *> &FusedInsts) {
1468  "Tiling only supported for column-major matrixes at the moment!");
1469  if (!isFusionProfitable(MatMul))
1470  return;
1471 
1472  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1473  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1474 
1475  const unsigned R = LShape.NumRows;
1476  const unsigned C = RShape.NumColumns;
1477  const unsigned M = LShape.NumColumns;
1478  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1479 
1480  Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1481  Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1482  Value *CPtr = Store->getPointerOperand();
1483 
1484  if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1485  createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1486  else {
1487       IRBuilder<> Builder(Store);
1488       for (unsigned J = 0; J < C; J += TileSize)
1489  for (unsigned I = 0; I < R; I += TileSize) {
1490  const unsigned TileR = std::min(R - I, unsigned(TileSize));
1491  const unsigned TileC = std::min(C - J, unsigned(TileSize));
1492  MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1493 
1494  for (unsigned K = 0; K < M; K += TileSize) {
1495  const unsigned TileM = std::min(M - K, unsigned(TileSize));
1496  MatrixTy A =
1497  loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1498  LShape, Builder.getInt64(I), Builder.getInt64(K),
1499  {TileR, TileM}, EltType, Builder);
1500  MatrixTy B =
1501  loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1502  RShape, Builder.getInt64(K), Builder.getInt64(J),
1503  {TileM, TileC}, EltType, Builder);
1504  emitMatrixMultiply(Res, A, B, Builder, true, false,
1505  getFastMathFlags(MatMul));
1506  }
1507  storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1508  Builder.getInt64(I), Builder.getInt64(J), EltType,
1509  Builder);
1510  }
1511  }
1512 
1513  // Mark eliminated instructions as fused and remove them.
1514  FusedInsts.insert(Store);
1515  FusedInsts.insert(MatMul);
1516  Store->eraseFromParent();
1517  MatMul->eraseFromParent();
1518  if (LoadOp0->hasNUses(0)) {
1519  FusedInsts.insert(LoadOp0);
1520  LoadOp0->eraseFromParent();
1521  }
1522  if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1523  FusedInsts.insert(LoadOp1);
1524  LoadOp1->eraseFromParent();
1525  }
1526  }
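// Illustrative note, not part of the pass: when a dimension is not a multiple
// of TileSize, the straight-line path above clamps the tile extents with
// std::min. For example, with R = 6, C = 5 and TileSize = 4 the (TileR x TileC)
// combinations visited are 4x4, 2x4, 4x1 and 2x1. A hypothetical helper for a
// single dimension:
static unsigned tileExtent(unsigned Dim, unsigned Offset, unsigned TS) {
  return Dim - Offset < TS ? Dim - Offset : TS; // == std::min(Dim - Offset, TS)
}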
1527 
1528  /// Try to lower matrix multiply chains by fusing operations.
1529  ///
1530  /// Call finalizeLowering on lowered instructions. Instructions that are
1531  /// completely eliminated by fusion are added to \p FusedInsts.
1532  void LowerMatrixMultiplyFused(CallInst *MatMul,
1533  SmallPtrSetImpl<Instruction *> &FusedInsts) {
1534  if (!FuseMatrix || !DT)
1535  return;
1536 
1537  assert(AA && LI && "Analyses should be available");
1538 
1539  Value *A = MatMul->getArgOperand(0);
1540  Value *B = MatMul->getArgOperand(1);
1541 
1542  // We can fold the transpose into the operand that is used to fetch scalars.
1543  Value *T;
 1544  if (MatrixLayout == MatrixLayoutTy::ColumnMajor
 1545  ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1546  : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1547  IRBuilder<> Builder(MatMul);
1548  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1549  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1550  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1551  const unsigned R = LShape.NumRows;
1552  const unsigned M = LShape.NumColumns;
1553  const unsigned C = RShape.NumColumns;
1554 
1555  MatrixTy MA;
1556  MatrixTy MB;
1557 
1558  Value *Transpose;
 1559  if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
 1560  MA = getMatrix(A, ShapeInfo(R, M), Builder);
1561  MB = getMatrix(T, ShapeInfo(C, M), Builder);
1562  Transpose = B;
1563  } else {
1564  MA = getMatrix(T, ShapeInfo(R, M), Builder);
1565  MB = getMatrix(B, ShapeInfo(C, M), Builder);
1566  Transpose = A;
1567  }
1568 
1569  // Initialize the output
1570  MatrixTy Result(R, C, EltType);
1571 
1572  emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1573  getFastMathFlags(MatMul));
1574 
1575  FusedInsts.insert(MatMul);
1576  if (Transpose->hasOneUse()) {
1577  FusedInsts.insert(cast<Instruction>(Transpose));
1578  ToRemove.push_back(cast<Instruction>(Transpose));
1579  // TODO: add a fake entry for the folded instruction so that this is
1580  // included in the expression in the remark.
1581  Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1582  }
1583  finalizeLowering(MatMul, Result, Builder);
1584  return;
1585  }
1586 
1587  if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1588  return;
1589 
1590  // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1591  // since the single store user will be lowered as part of this.
1592  auto *LoadOp0 = dyn_cast<LoadInst>(A);
1593  auto *LoadOp1 = dyn_cast<LoadInst>(B);
1594  auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1595  if (LoadOp0 && LoadOp1 && Store) {
1596  // The store address must dominate the MatMul instruction, otherwise
1597  // we create invalid IR.
1598  SetVector<Value *> WorkList;
1599  WorkList.insert(Store->getOperand(1));
 1600  SmallVector<Instruction *> ToHoist;
 1601  for (unsigned I = 0; I != WorkList.size(); ++I) {
1602  Value *Current = WorkList[I];
1603  auto *CurrI = dyn_cast<Instruction>(Current);
1604  if (!CurrI)
1605  continue;
1606  if (isa<PHINode>(CurrI))
1607  return;
1608  if (DT->dominates(CurrI, MatMul))
1609  continue;
1610  if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1611  return;
1612  ToHoist.push_back(CurrI);
1613  WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1614  }
1615 
1616  sort(ToHoist, [this](Instruction *A, Instruction *B) {
1617  return DT->dominates(A, B);
1618  });
1619  for (Instruction *I : ToHoist)
1620  I->moveBefore(MatMul);
1621 
1622  emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1623  return;
1624  }
1625  }
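// Illustrative sketch, not part of the pass: the identity behind the transpose
// folding above. The operand from which emitMatrixMultiply fetches individual
// scalars (the second one in the column-major case, as matched above) can be
// read directly from T, because
//   (A * transpose(T))[r][c] = sum over k of A[r][k] * T[c][k]
// so the transposed matrix is never materialized. Hypothetical scalar
// reference, all matrices flattened column-major (A is R x M, T is C x M,
// Res is R x C):
static void multiplyByTransposeReference(const double *A, const double *T,
                                         double *Res, unsigned R, unsigned M,
                                         unsigned C) {
  for (unsigned c = 0; c < C; ++c)
    for (unsigned r = 0; r < R; ++r) {
      double Sum = 0;
      for (unsigned k = 0; k < M; ++k)
        Sum += A[r + k * R] * T[c + k * C];
      Res[r + c * R] = Sum;
    }
}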
1626 
1627  /// Lowers llvm.matrix.multiply.
1628  void LowerMultiply(CallInst *MatMul) {
1629  IRBuilder<> Builder(MatMul);
1630  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1631  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1632  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1633 
1634  const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1635  const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1636  assert(Lhs.getElementType() == Rhs.getElementType() &&
1637  "Matrix multiply argument element types do not match.");
1638 
1639  const unsigned R = LShape.NumRows;
1640  const unsigned C = RShape.NumColumns;
1641  assert(LShape.NumColumns == RShape.NumRows);
1642 
1643  // Initialize the output
1644  MatrixTy Result(R, C, EltType);
1645  assert(Lhs.getElementType() == Result.getElementType() &&
1646  "Matrix multiply result element type does not match arguments.");
1647 
1648  emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1649  getFastMathFlags(MatMul));
1650  finalizeLowering(MatMul, Result, Builder);
1651  }
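// Illustrative example, not part of the pass: the kind of call LowerMultiply
// handles. For a 2x4 times 4x2 multiply of doubles the IR looks roughly like
// (the three immediates are rows(Lhs), columns(Lhs) == rows(Rhs) and
// columns(Rhs), matching the ShapeInfo construction above):
//
//   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v8f64.v8f64(
//            <8 x double> %a, <8 x double> %b, i32 2, i32 4, i32 2)
//
// getMatrix splits the flattened operands into column vectors, then
// emitMatrixMultiply combines them column by column, and finalizeLowering
// makes the flattened 2x2 result available to any remaining users.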
1652 
1653  /// Lowers llvm.matrix.transpose.
1654  void LowerTranspose(CallInst *Inst) {
1655  MatrixTy Result;
1656  IRBuilder<> Builder(Inst);
1657  Value *InputVal = Inst->getArgOperand(0);
1658  VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1659  ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1660  MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1661 
1662  const unsigned NewNumVecs =
1663  InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1664  const unsigned NewNumElts =
1665  InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1666 
1667  for (unsigned I = 0; I < NewNumVecs; ++I) {
1668  // Build a single result vector. First initialize it.
1669  Value *ResultVector = UndefValue::get(
1670  FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
 1671  // Go through the old elements and insert them into the resulting vector.
1672  for (auto J : enumerate(InputMatrix.vectors())) {
1673  Value *Elt = Builder.CreateExtractElement(J.value(), I);
1674  // Row and column indices are transposed.
1675  ResultVector =
1676  Builder.CreateInsertElement(ResultVector, Elt, J.index());
1677  }
1678  Result.addVector(ResultVector);
1679  }
1680 
1681  // TODO: Improve estimate of operations needed for transposes. Currently we
1682  // just count the insertelement/extractelement instructions, but do not
1683  // account for later simplifications/combines.
1684  finalizeLowering(
1685  Inst,
1686  Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
1687  .addNumExposedTransposes(1),
1688  Builder);
1689  }
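// Illustrative sketch, not part of the pass: what the extract/insert loop in
// LowerTranspose computes in the column-major case. Result vector I is built
// by taking element I of every input column. Hypothetical scalar reference
// (In is R x C, Out is C x R, both flattened column-major):
static void transposeReference(const double *In, double *Out, unsigned R,
                               unsigned C) {
  for (unsigned I = 0; I < R; ++I)   // one result vector per input row
    for (unsigned J = 0; J < C; ++J) // element J is taken from input column J
      Out[J + I * C] = In[I + J * R];
}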
1690 
1691  /// Lower load instructions, if shape information is available.
1692  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
1693  auto I = ShapeMap.find(Inst);
1694  if (I == ShapeMap.end())
1695  return false;
1696 
1697  LowerLoad(Inst, Ptr, Inst->getAlign(),
1698  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1699  I->second);
1700  return true;
1701  }
1702 
1703  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
1704  IRBuilder<> &Builder) {
1705  auto I = ShapeMap.find(StoredVal);
1706  if (I == ShapeMap.end())
1707  return false;
1708 
1709  LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
1710  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1711  I->second);
1712  return true;
1713  }
1714 
1715  /// Lower binary operators, if shape information is available.
1716  bool VisitBinaryOperator(BinaryOperator *Inst) {
1717  auto I = ShapeMap.find(Inst);
1718  if (I == ShapeMap.end())
1719  return false;
1720 
1721  Value *Lhs = Inst->getOperand(0);
1722  Value *Rhs = Inst->getOperand(1);
1723 
1724  IRBuilder<> Builder(Inst);
1725  ShapeInfo &Shape = I->second;
1726 
1727  MatrixTy Result;
1728  MatrixTy A = getMatrix(Lhs, Shape, Builder);
1729  MatrixTy B = getMatrix(Rhs, Shape, Builder);
1730  assert(A.isColumnMajor() == B.isColumnMajor() &&
1731  Result.isColumnMajor() == A.isColumnMajor() &&
1732  "operands must agree on matrix layout");
1733 
1734  Builder.setFastMathFlags(getFastMathFlags(Inst));
1735 
1736  // Helper to perform binary op on vectors.
1737  auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
1738  switch (Inst->getOpcode()) {
1739  case Instruction::Add:
1740  return Builder.CreateAdd(LHS, RHS);
1741  case Instruction::Mul:
1742  return Builder.CreateMul(LHS, RHS);
1743  case Instruction::Sub:
1744  return Builder.CreateSub(LHS, RHS);
1745  case Instruction::FAdd:
1746  return Builder.CreateFAdd(LHS, RHS);
1747  case Instruction::FMul:
1748  return Builder.CreateFMul(LHS, RHS);
1749  case Instruction::FSub:
1750  return Builder.CreateFSub(LHS, RHS);
1751  default:
1752  llvm_unreachable("Unsupported binary operator for matrix");
1753  }
1754  };
1755 
1756  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1757  Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
1758 
1759  finalizeLowering(Inst,
1760  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1761  Result.getNumVectors()),
1762  Builder);
1763  return true;
1764  }
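// Illustrative note, not part of the pass: since both operands share one
// ShapeInfo, the loop above simply applies the scalar opcode to matching
// stored vectors. For an R x C fadd in column-major form this amounts to one
// vector-wide add per column:
//
//   for (unsigned J = 0; J < Shape.NumColumns; ++J)
//     ResultColumn[J] = LhsColumn[J] + RhsColumn[J]; // one <R x ty> fadd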
1765 
1766  /// Lower unary operators, if shape information is available.
1767  bool VisitUnaryOperator(UnaryOperator *Inst) {
1768  auto I = ShapeMap.find(Inst);
1769  if (I == ShapeMap.end())
1770  return false;
1771 
1772  Value *Op = Inst->getOperand(0);
1773 
1774  IRBuilder<> Builder(Inst);
1775  ShapeInfo &Shape = I->second;
1776 
1777  MatrixTy Result;
1778  MatrixTy M = getMatrix(Op, Shape, Builder);
1779 
1780  Builder.setFastMathFlags(getFastMathFlags(Inst));
1781 
1782  // Helper to perform unary op on vectors.
1783  auto BuildVectorOp = [&Builder, Inst](Value *Op) {
1784  switch (Inst->getOpcode()) {
1785  case Instruction::FNeg:
1786  return Builder.CreateFNeg(Op);
1787  default:
1788  llvm_unreachable("Unsupported unary operator for matrix");
1789  }
1790  };
1791 
1792  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1793  Result.addVector(BuildVectorOp(M.getVector(I)));
1794 
1795  finalizeLowering(Inst,
1796  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1797  Result.getNumVectors()),
1798  Builder);
1799  return true;
1800  }
1801 
1802  /// Helper to linearize a matrix expression tree into a string. Currently
 1803  /// matrix expressions are linearized by starting at an expression leaf and
1804  /// linearizing bottom up.
1805  struct ExprLinearizer {
1806  unsigned LengthToBreak = 100;
1807  std::string Str;
1808  raw_string_ostream Stream;
1809  unsigned LineLength = 0;
1810  const DataLayout &DL;
1811 
1812  /// Mapping from instructions to matrixes. It is used to identify
1813  /// matrix instructions.
1814  const MapVector<Value *, MatrixTy> &Inst2Matrix;
1815 
1816  /// Mapping from values to the leaves of all expressions that the value is
1817  /// part of.
 1818  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
 1819 
1820  /// Set of matrix expressions in the scope of a given DISubprogram.
1821  const SmallSetVector<Value *, 32> &ExprsInSubprogram;
1822 
1823  /// Leaf node of the expression to linearize.
1824  Value *Leaf;
1825 
1826  /// Used to keep track of sub-expressions that get reused while linearizing
1827  /// the expression. Re-used sub-expressions are marked as (reused).
1828  SmallPtrSet<Value *, 8> ReusedExprs;
1829 
1830  ExprLinearizer(const DataLayout &DL,
1831  const MapVector<Value *, MatrixTy> &Inst2Matrix,
1832  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
1833  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
1834  Value *Leaf)
1835  : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
1836  ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
1837 
1838  void indent(unsigned N) {
1839  LineLength += N;
1840  for (unsigned i = 0; i < N; i++)
1841  Stream << " ";
1842  }
1843 
1844  void lineBreak() {
1845  Stream << "\n";
1846  LineLength = 0;
1847  }
1848 
1849  void maybeIndent(unsigned Indent) {
1850  if (LineLength >= LengthToBreak)
1851  lineBreak();
1852 
1853  if (LineLength == 0)
1854  indent(Indent);
1855  }
1856 
1857  void write(StringRef S) {
1858  LineLength += S.size();
1859  Stream << S;
1860  }
1861 
1862  Value *getUnderlyingObjectThroughLoads(Value *V) {
1863  if (Value *Ptr = getPointerOperand(V))
1864  return getUnderlyingObjectThroughLoads(Ptr);
1865  else if (V->getType()->isPointerTy())
1866  return getUnderlyingObject(V);
1867  return V;
1868  }
1869 
1870  /// Returns true if \p V is a matrix value in the given subprogram.
1871  bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
1872 
 1873  /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
1874  /// \p SS.
1875  void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
1876  auto M = Inst2Matrix.find(V);
1877  if (M == Inst2Matrix.end())
1878  SS << "unknown";
1879  else {
1880  SS << M->second.getNumRows();
1881  SS << "x";
1882  SS << M->second.getNumColumns();
1883  }
1884  }
1885 
1886  /// Write the called function name. Handles calls to llvm.matrix.*
1887  /// specially: we write the name, followed by the dimensions of the input
1888  /// matrixes, followed by the scalar type name.
1889  void writeFnName(CallInst *CI) {
1890  if (!CI->getCalledFunction())
1891  write("<no called fn>");
1892  else {
 1893  StringRef Name = CI->getCalledFunction()->getName();
 1894  if (!Name.startswith("llvm.matrix")) {
1895  write(Name);
1896  return;
1897  }
1898  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
 1899  write(Intrinsic::getBaseName(II->getIntrinsicID())
 1900  .drop_front(StringRef("llvm.matrix.").size()));
1901  write(".");
1902  std::string Tmp;
1903  raw_string_ostream SS(Tmp);
1904 
1905  switch (II->getIntrinsicID()) {
1906  case Intrinsic::matrix_multiply:
1907  prettyPrintMatrixType(II->getOperand(0), SS);
1908  SS << ".";
1909  prettyPrintMatrixType(II->getOperand(1), SS);
1910  SS << "." << *II->getType()->getScalarType();
1911  break;
1912  case Intrinsic::matrix_transpose:
1913  prettyPrintMatrixType(II->getOperand(0), SS);
1914  SS << "." << *II->getType()->getScalarType();
1915  break;
1916  case Intrinsic::matrix_column_major_load:
1917  prettyPrintMatrixType(II, SS);
1918  SS << "." << *II->getType()->getScalarType();
1919  break;
1920  case Intrinsic::matrix_column_major_store:
1921  prettyPrintMatrixType(II->getOperand(0), SS);
1922  SS << "." << *II->getOperand(0)->getType()->getScalarType();
1923  break;
1924  default:
1925  llvm_unreachable("Unhandled case");
1926  }
1927  SS.flush();
1928  write(Tmp);
1929  }
1930  }
1931 
1932  unsigned getNumShapeArgs(CallInst *CI) const {
1933  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
1934  switch (II->getIntrinsicID()) {
1935  case Intrinsic::matrix_multiply:
1936  return 3;
1937  case Intrinsic::matrix_transpose:
1938  return 2;
1939  case Intrinsic::matrix_column_major_load:
1940  case Intrinsic::matrix_column_major_store:
1941  return 3;
1942  default:
1943  return 0;
1944  }
1945  }
1946  return 0;
1947  }
1948 
 1949  /// Special printing for values: for pointers, we print whether they refer
 1950  /// to a (function-)external address or a stack address; for other values we
 1951  /// print the constant, or "scalar"/"matrix" otherwise.
1952  void write(Value *V) {
1953  V = getUnderlyingObjectThroughLoads(V);
1954  if (V->getType()->isPointerTy()) {
1955  if (isa<AllocaInst>(V)) {
1956  Stream << "stack addr";
1957  LineLength += StringRef("stack addr").size();
1958  } else {
1959  Stream << "addr";
1960  LineLength += StringRef("addr").size();
1961  }
1962  if (!V->getName().empty()) {
1963  Stream << " %" << V->getName() << "";
1964  LineLength += V->getName().size() + 2;
1965  }
1966  return;
1967  }
1968 
1969  std::string Tmp;
1970  raw_string_ostream TmpStream(Tmp);
1971 
1972  if (auto *CI = dyn_cast<ConstantInt>(V))
1973  TmpStream << CI->getValue();
1974  else if (isa<Constant>(V))
1975  TmpStream << "constant";
1976  else {
1977  if (isMatrix(V))
1978  TmpStream << "matrix";
1979  else
1980  TmpStream << "scalar";
1981  }
1982  TmpStream.flush();
1983  Tmp = std::string(StringRef(Tmp).trim());
1984  LineLength += Tmp.size();
1985  Stream << Tmp;
1986  }
1987 
1988  /// Linearize expression \p Expr starting at an indentation of \p Indent.
1989  /// Expressions that are re-used multiple times are prefixed with (reused)
1990  /// at the re-used root instruction.
1991  void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
1992  bool ParentShared) {
1993  auto *I = cast<Instruction>(Expr);
1994  maybeIndent(Indent);
 1995  SmallVector<Value *, 8> Ops;
 1996 
1997  // Is Expr shared with other expression leaves?
1998  bool ExprShared = false;
1999 
2000  // Deal with shared subtrees. Mark them as shared, if required.
2001  if (!ParentShared) {
2002  auto SI = Shared.find(Expr);
2003  assert(SI != Shared.end() && SI->second.count(Leaf));
2004 
2005  for (Value *S : SI->second) {
2006  if (S == Leaf)
2007  continue;
2008  DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2009  write("shared with remark at line " + std::to_string(DL.getLine()) +
2010  " column " + std::to_string(DL.getCol()) + " (");
2011  }
2012  ExprShared = SI->second.size() > 1;
2013  }
2014 
2015  bool Reused = !ReusedExprs.insert(Expr).second;
2016  if (Reused && !ParentReused)
2017  write("(reused) ");
2018 
2019  if (auto *CI = dyn_cast<CallInst>(I)) {
2020  writeFnName(CI);
2021 
2022  Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2023  } else if (isa<BitCastInst>(Expr)) {
2024  // Special case bitcasts, which are used to materialize matrixes from
2025  // non-matrix ops.
2026  write("matrix");
2027  return;
2028  } else {
2029  Ops.append(I->value_op_begin(), I->value_op_end());
2030  write(std::string(I->getOpcodeName()));
2031  }
2032 
2033  write(std::string("("));
2034 
2035  unsigned NumOpsToBreak = 1;
2036  if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2037  NumOpsToBreak = 2;
2038 
2039  for (Value *Op : Ops) {
2040  if (Ops.size() > NumOpsToBreak)
2041  lineBreak();
2042 
2043  maybeIndent(Indent + 1);
2044  if (isMatrix(Op))
2045  linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2046  else
2047  write(Op);
2048  if (Op != Ops.back())
2049  write(", ");
2050  }
2051 
2052  write(")");
2053  }
2054 
2055  const std::string &getResult() {
2056  Stream.flush();
2057  return Str;
2058  }
2059  };
2060 
2061  /// Generate remarks for matrix operations in a function. To generate remarks
2062  /// for matrix expressions, the following approach is used:
2063  /// 1. Use the inlined-at debug information to group matrix operations to the
2064  /// DISubprograms they are contained in.
2065  /// 2. Collect leaves of matrix expressions (done in
2066  /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
 2067  /// mapping. Leaves are lowered matrix instructions without other matrix
 2068  /// users (like stores) in the current subprogram.
 2069  /// 3. For each leaf, create a remark containing a linearized version of the
2070  /// matrix expression. The expression is linearized by a recursive
2071  /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2072  /// that multiple leaves can share sub-expressions. Shared subexpressions
2073  /// are explicitly marked as shared().
2074  struct RemarkGenerator {
2075  const MapVector<Value *, MatrixTy> &Inst2Matrix;
 2076  OptimizationRemarkEmitter &ORE;
 2077  Function &Func;
2078  const DataLayout &DL;
2079 
2080  RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2081  OptimizationRemarkEmitter &ORE, Function &Func)
2082  : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2083  DL(Func.getParent()->getDataLayout()) {}
2084 
2085  /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2086  /// instructions in Inst2Matrix returning void or without any users in
2087  /// \p ExprsInSubprogram. Currently that should only include stores.
 2088  SmallVector<Value *, 4>
 2089  getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2090  SmallVector<Value *, 4> Leaves;
2091  for (auto *Expr : ExprsInSubprogram)
2092  if (Expr->getType()->isVoidTy() ||
2093  !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2094  return ExprsInSubprogram.count(U);
2095  }))
2096  Leaves.push_back(Expr);
2097  return Leaves;
2098  }
2099 
2100  /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2101  /// to all visited expressions in \p Shared. Limit the matrix operations to
2102  /// the ones in \p ExprsInSubprogram.
2103  void collectSharedInfo(Value *Leaf, Value *V,
2104  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2105  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2106 
2107  if (!ExprsInSubprogram.count(V))
2108  return;
2109 
2110  auto I = Shared.insert({V, {}});
2111  I.first->second.insert(Leaf);
2112 
2113  for (Value *Op : cast<Instruction>(V)->operand_values())
2114  collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2115  }
2116 
2117  /// Calculate the number of exclusive and shared op counts for expression
2118  /// starting at \p V. Expressions used multiple times are counted once.
2119  /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2120  std::pair<OpInfoTy, OpInfoTy>
2121  sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2122  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2123  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2124  if (!ExprsInSubprogram.count(Root))
2125  return {};
2126 
2127  // Already counted this expression. Stop.
2128  if (!ReusedExprs.insert(Root).second)
2129  return {};
2130 
2131  OpInfoTy SharedCount;
2132  OpInfoTy Count;
2133 
2134  auto I = Shared.find(Root);
2135  auto CM = Inst2Matrix.find(Root);
2136  if (I->second.size() == 1)
2137  Count = CM->second.getOpInfo();
2138  else
2139  SharedCount = CM->second.getOpInfo();
2140 
2141  for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2142  auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2143  Count += C.first;
2144  SharedCount += C.second;
2145  }
2146  return {Count, SharedCount};
2147  }
2148 
2149  void emitRemarks() {
2150  if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2151  return;
2152 
 2153  // Map matrix operations to their containing subprograms, by traversing
2154  // the inlinedAt chain. If the function does not have a DISubprogram, we
2155  // only map them to the containing function.
 2156  MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
 2157  for (auto &KV : Inst2Matrix) {
2158  if (Func.getSubprogram()) {
2159  auto *I = cast<Instruction>(KV.first);
2160  DILocation *Context = I->getDebugLoc();
2161  while (Context) {
2162  auto I =
2163  Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2164  I.first->second.push_back(KV.first);
 2165  Context = DebugLoc(Context).getInlinedAt();
 2166  }
2167  } else {
2168  auto I = Subprog2Exprs.insert({nullptr, {}});
2169  I.first->second.push_back(KV.first);
2170  }
2171  }
2172  for (auto &KV : Subprog2Exprs) {
2173  SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2174  KV.second.end());
2175  auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2176 
 2177  DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
 2178  for (Value *Leaf : Leaves)
2179  collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2180 
2181  // Generate remarks for each leaf.
2182  for (auto *L : Leaves) {
2183 
2184  DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2185  DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2186  while (Context) {
2187  if (getSubprogram(Context->getScope()) == KV.first) {
2188  Loc = Context;
2189  break;
2190  }
 2191  Context = DebugLoc(Context).getInlinedAt();
 2192  }
2193 
2194  SmallPtrSet<Value *, 8> ReusedExprs;
2195  OpInfoTy Counts, SharedCounts;
2196  std::tie(Counts, SharedCounts) =
2197  sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2198 
2199  OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2200  cast<Instruction>(L)->getParent());
2201 
2202  Rem << "Lowered with ";
2203  Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2204  << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2205  << ore::NV("NumComputeOps", Counts.NumComputeOps)
2206  << " compute ops, "
2207  << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2208  << " exposed transposes";
2209 
2210  if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2211  SharedCounts.NumComputeOps > 0) {
2212  Rem << ",\nadditionally "
2213  << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2214  << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2215  << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2216  << " compute ops"
2217  << " are shared with other expressions";
2218  }
2219 
2220  Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2221  ORE.emit(Rem);
2222  }
2223  }
2224  }
2225 
2226  std::string
2227  linearize(Value *L,
2228  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2229  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2230  const DataLayout &DL) {
2231  ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2232  Lin.linearizeExpr(L, 0, false, false);
2233  return Lin.getResult();
2234  }
2235  };
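// Illustrative example, not part of the pass: with optimization remarks
// enabled for this pass (and debug info available so leaves can be grouped by
// subprogram), the generator above produces per-leaf remarks roughly of the
// form
//
//   remark: ...: Lowered with <n> stores, <n> loads, <n> compute ops,
//   <n> exposed transposes
//   store(
//    multiply.2x4.4x2.double(
//     load(addr %A),
//     load(addr %B)),
//    addr %C)
//
// where the counts come from sumOpInfos and the linearized expression from
// ExprLinearizer; the exact numbers depend on the shapes involved and on
// which operations were fused.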
2236 };
2237 } // namespace
2238 
 2239 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
 2240  FunctionAnalysisManager &AM) {
 2241  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2242  OptimizationRemarkEmitter *ORE = nullptr;
2243  AAResults *AA = nullptr;
2244  DominatorTree *DT = nullptr;
2245  LoopInfo *LI = nullptr;
2246 
2247  if (!Minimal) {
 2248  ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 2249  AA = &AM.getResult<AAManager>(F);
2250  DT = &AM.getResult<DominatorTreeAnalysis>(F);
2251  LI = &AM.getResult<LoopAnalysis>(F);
2252  }
2253 
2254  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2255  if (LMT.Visit()) {
2256  PreservedAnalyses PA;
2257  if (!Minimal) {
2258  PA.preserve<LoopAnalysis>();
 2259  PA.preserve<DominatorTreeAnalysis>();
 2260  }
2261  return PA;
2262  }
2263  return PreservedAnalyses::all();
2264 }
2265 
 2266 void LowerMatrixIntrinsicsPass::printPipeline(
 2267  raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
 2268  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
 2269  OS, MapClassName2PassName);
2270  OS << "<";
2271  if (Minimal)
2272  OS << "minimal";
2273  OS << ">";
2274 }
2275 
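// Illustrative note, not part of the pass: the printPipeline output above
// corresponds to how the pass is named in a new pass manager pipeline, e.g.
//
//   opt -passes=lower-matrix-intrinsics -S in.ll
//
// with "minimal" as the parameter for the lightweight TTI-only variant. The
// legacy pass registered below uses DEBUG_TYPE, i.e. the same
// "lower-matrix-intrinsics" name, for the old pass manager.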
2276 namespace {
2277 
2278 class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
2279 public:
2280  static char ID;
2281 
2282  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
 2283  initializeLowerMatrixIntrinsicsLegacyPassPass(
 2284  *PassRegistry::getPassRegistry());
 2285  }
2286 
2287  bool runOnFunction(Function &F) override {
2288  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2289  auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
2290  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
2291  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2292  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
2293  LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
2294  bool C = LMT.Visit();
2295  return C;
2296  }
2297 
2298  void getAnalysisUsage(AnalysisUsage &AU) const override {
 2299  AU.addRequired<TargetTransformInfoWrapperPass>();
 2300  AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
 2301  AU.addRequired<AAResultsWrapperPass>();
 2302  AU.addRequired<DominatorTreeWrapperPass>();
 2303  AU.addPreserved<DominatorTreeWrapperPass>();
 2304  AU.addRequired<LoopInfoWrapperPass>();
 2305  AU.addPreserved<LoopInfoWrapperPass>();
 2306  }
2307 };
2308 } // namespace
2309 
2310 static const char pass_name[] = "Lower the matrix intrinsics";
 2311 char LowerMatrixIntrinsicsLegacyPass::ID = 0;
 2312 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2313  false, false)
2318 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
 2319  false, false)
 2320 
 2321 Pass *llvm::createLowerMatrixIntrinsicsPass() {
 2322  return new LowerMatrixIntrinsicsLegacyPass();
2323 }
2324 
2325 namespace {
2326 
2327 /// A lightweight version of the matrix lowering pass that only requires TTI.
2328 /// Advanced features that require DT, AA or ORE like tiling are disabled. This
2329 /// is used to lower matrix intrinsics if the main lowering pass is not run, for
2330 /// example with -O0.
2331 class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
2332 public:
2333  static char ID;
2334 
2335  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
 2336  initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
 2337  *PassRegistry::getPassRegistry());
 2338  }
2339 
2340  bool runOnFunction(Function &F) override {
2341  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2342  LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
2343  bool C = LMT.Visit();
2344  return C;
2345  }
2346 
2347  void getAnalysisUsage(AnalysisUsage &AU) const override {
 2348  AU.addRequired<TargetTransformInfoWrapperPass>();
 2349  AU.setPreservesCFG();
2350  }
2351 };
2352 } // namespace
2353 
2354 static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
 2355 char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
 2356 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
2357  "lower-matrix-intrinsics-minimal", pass_name_minimal,
2358  false, false)
2359 INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
 2360  "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
 2361  false)
2362 
 2363 Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
 2364  return new LowerMatrixIntrinsicsMinimalLegacyPass();
2365 }
llvm::Intrinsic::getBaseName
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition: Function.cpp:874
i
i
Definition: README.txt:29
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:155
pass_name
static const char pass_name[]
Definition: LowerMatrixIntrinsics.cpp:2310
llvm::AAManager
A manager for alias analyses.
Definition: AliasAnalysis.h:1288
llvm::TargetIRAnalysis
Analysis pass providing the TargetTransformInfo.
Definition: TargetTransformInfo.h:2331
llvm::cast
std::enable_if_t<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type > cast(const Y &Val)
Definition: Casting.h:254
llvm::Function::isIntrinsic
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition: Function.h:212
llvm::MemoryLocation::get
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
Definition: MemoryLocation.cpp:37
llvm
This file implements support for optimizing divisions by a constant.
Definition: AllocatorList.h:23
BlockSize
static const int BlockSize
Definition: TarWriter.cpp:33
llvm::DILocalScope::getSubprogram
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Definition: DebugInfoMetadata.cpp:814
M
We currently emits eax Perhaps this is what we really should generate is Is imull three or four cycles eax eax The current instruction priority is based on pattern complexity The former is more complex because it folds a load so the latter will not be emitted Perhaps we should use AddedComplexity to give LEA32r a higher priority We should always try to match LEA first since the LEA matching code does some estimate to determine whether the match is profitable if we care more about code then imull is better It s two bytes shorter than movl leal On a Pentium M
Definition: README.txt:252
llvm::StringRef::empty
LLVM_NODISCARD bool empty() const
empty - Check if the string is empty.
Definition: StringRef.h:153
llvm::make_range
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
Definition: iterator_range.h:53
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:113
llvm::Value::hasOneUse
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
llvm::Intrinsic::getDeclaration
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1379
llvm::User::operands
op_range operands()
Definition: User.h:242
llvm::BasicBlock::iterator
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:90
llvm::AllocaInst::getAlign
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition: Instructions.h:120
IntrinsicInst.h
llvm::Type::isPointerTy
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:217
ceil
We have fiadd patterns now but the followings have the same cost and complexity We need a way to specify the later is more profitable def def The FP stackifier should handle simple permutates to reduce number of shuffle e g ceil
Definition: README-FPStack.txt:54
llvm::AnalysisManager::getResult
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:783
DebugInfoMetadata.h
llvm::MemoryLocation::Ptr
const Value * Ptr
The address of the start of the location.
Definition: MemoryLocation.h:217
llvm::ValueMap::end
iterator end()
Definition: ValueMap.h:136
Scalar.h
llvm::PassInfoMixin< LowerMatrixIntrinsicsPass >
LowerLoad
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition: X86ISelLowering.cpp:25035
llvm::TypeSize::getFixedSize
ScalarTy getFixedSize() const
Definition: TypeSize.h:426
T
llvm::Function
Definition: Function.h:62
Pass.h
llvm::TargetTransformInfo::getRegisterBitWidth
TypeSize getRegisterBitWidth(RegisterKind K) const
Definition: TargetTransformInfo.cpp:594
llvm::IntrinsicInst::getIntrinsicID
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:52
LowerStore
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition: X86ISelLowering.cpp:24950
llvm::raw_string_ostream
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:625
llvm::write
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs)
Definition: DWP.cpp:535
llvm::PointerType::get
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Definition: Type.cpp:729
llvm::SetVector::size
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:77
llvm::Type::getScalarType
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:308
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1168
llvm::enumerate
detail::enumerator< R > enumerate(R &&TheRange)
Given an input range, returns a new range whose values are are pair (A,B) such that A is the 0-based ...
Definition: STLExtras.h:1981
llvm::PatternMatch::m_Load
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
Definition: PatternMatch.h:1573
llvm::TargetTransformInfo
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
Definition: TargetTransformInfo.h:168
insertVector
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition: SROA.cpp:2181
ToRemove
ReachingDefAnalysis InstSet & ToRemove
Definition: ARMLowOverheadLoops.cpp:546
llvm::IRBuilder<>
llvm::SPII::Store
@ Store
Definition: SparcInstrInfo.h:33
llvm::AAResults::isNoAlias
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
Definition: AliasAnalysis.h:562
DomTreeUpdater.h
ValueTracking.h
OptimizationRemarkEmitter.h
llvm::DominatorTree
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:151
llvm::OptimizationRemarkEmitter::allowExtraAnalysis
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
Definition: OptimizationRemarkEmitter.h:98
llvm::cl::Hidden
@ Hidden
Definition: CommandLine.h:143
llvm::DILocation
Debug location.
Definition: DebugInfoMetadata.h:1580
llvm::X86II::TA
@ TA
Definition: X86BaseInfo.h:803
ForceFusion
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
llvm::MemoryLocation::Size
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
Definition: MemoryLocation.h:226
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
llvm::DebugLoc::getInlinedAt
DILocation * getInlinedAt() const
Definition: DebugLoc.cpp:40
llvm::reverse
auto reverse(ContainerTy &&C, std::enable_if_t< has_rbegin< ContainerTy >::value > *=nullptr)
Definition: STLExtras.h:333
llvm::sys::path::end
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:233
llvm::sys::path::begin
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:224
llvm::LoopInfoWrapperPass
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:1268
llvm::DominatorTreeBase::Insert
static constexpr UpdateKind Insert
Definition: GenericDomTree.h:242
T1
#define T1
Definition: Mips16ISelLowering.cpp:340
llvm::SmallSet
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:134
llvm::operator!=
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:1983
vectors
hexagon Hexagon specific predictive commoning for HVX vectors
Definition: HexagonVectorLoopCarriedReuse.cpp:221
Offset
uint64_t Offset
Definition: ELFObjHandler.cpp:81
llvm::MapVector
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:37
llvm::SmallPtrSet< Instruction *, 16 >
llvm::ore::NV
DiagnosticInfoOptimizationBase::Argument NV
Definition: OptimizationRemarkEmitter.h:136
llvm::VectorType::getElementType
Type * getElementType() const
Definition: DerivedTypes.h:422
llvm::Value::user_begin
user_iterator user_begin()
Definition: Value.h:397
llvm::CallBase::arg_begin
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1303
llvm::SmallVectorImpl::pop_back_val
LLVM_NODISCARD T pop_back_val()
Definition: SmallVector.h:635
llvm::UnaryOperator
Definition: InstrTypes.h:102
llvm::successors
succ_range successors(Instruction *I)
Definition: CFG.h:262
llvm::FastMathFlags
Convenience struct for specifying and reasoning about fast-math flags.
Definition: Operator.h:161
llvm::initializeLowerMatrixIntrinsicsMinimalLegacyPassPass
void initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(PassRegistry &)
llvm::LoadInst::getAlign
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:223
llvm::BitmaskEnumDetail::Mask
std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:80
llvm::DominatorTreeBase::applyUpdates
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
Definition: GenericDomTree.h:544
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:101
llvm::NVPTX::PTXLdStInstCode::VecType
VecType
Definition: NVPTX.h:121
F
#define F(x, y, z)
Definition: MD5.cpp:56
llvm::RISCVFenceField::R
@ R
Definition: RISCVBaseInfo.h:198
llvm::DomTreeUpdater::UpdateStrategy::Lazy
@ Lazy
llvm::TileInfo
A helper struct to create IR loop nests for tiling in IR of the following form: for CurrentColumn = 0...
Definition: MatrixUtils.h:31
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
getSubprogram
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
Definition: LowerMatrixIntrinsics.cpp:85
AliasAnalysis.h
MatrixBuilder.h
Context
LLVMContext & Context
Definition: NVVMIntrRange.cpp:66
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
GraphTraits.h
llvm::DominatorTree::dominates
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:115
CommandLine.h
llvm::UnaryOperator::getOpcode
UnaryOps getOpcode() const
Definition: InstrTypes.h:172
extractVector
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition: SROA.cpp:2156
llvm::ConstantInt
This is the shared class of boolean and integer constants.
Definition: Constants.h:79
llvm::Intrinsic::getType
FunctionType * getType(LLVMContext &Context, ID id, ArrayRef< Type * > Tys=None)
Return the function type for an intrinsic.
Definition: Function.cpp:1335
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:31
MatrixUtils.h
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
isZero
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:519
llvm::AAResults
Definition: AliasAnalysis.h:508
E
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
llvm::SmallVectorImpl::append
void append(in_iter in_start, in_iter in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:648
llvm::createLowerMatrixIntrinsicsPass
Pass * createLowerMatrixIntrinsicsPass()
Definition: LowerMatrixIntrinsics.cpp:2321
llvm::User
Definition: User.h:44
C
(vector float) vec_cmpeq(*A, *B) C
Definition: README_ALTIVEC.txt:86
llvm::ARM_PROC::A
@ A
Definition: ARMBaseInfo.h:34
TileSize
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
llvm::CallBase::getCalledFunction
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation.
Definition: InstrTypes.h:1383
llvm::BasicBlock::begin
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:296
llvm::operator+=
std::string & operator+=(std::string &buffer, StringRef string)
Definition: StringRef.h:941
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
llvm::LocationSize::getValue
uint64_t getValue() const
Definition: MemoryLocation.h:158
llvm::ms_demangle::QualifierMangleMode::Result
@ Result
llvm::BitTracker
Definition: BitTracker.h:35
llvm::Value::uses
iterator_range< use_iterator > uses()
Definition: Value.h:376
false
Definition: StackSlotColoring.cpp:142
llvm::MaybeAlign
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:109
B
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
llvm::BinaryOperator::getOpcode
BinaryOps getOpcode() const
Definition: InstrTypes.h:393
llvm::PatternMatch::m_ConstantInt
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:145
llvm::Instruction
Definition: Instruction.h:45
llvm::Type::getScalarSizeInBits
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition: Type.cpp:191
llvm::DominatorTreeWrapperPass
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:287
llvm::raw_ostream
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:53
llvm::LowerMatrixIntrinsicsPass::run
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition: LowerMatrixIntrinsics.cpp:2239
llvm::codeview::EncodedFramePtrReg::BasePtr
@ BasePtr
llvm::raw_ostream::flush
void flush()
Definition: raw_ostream.h:186
llvm::UndefValue::get
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1796
llvm::DomTreeUpdater
Definition: DomTreeUpdater.h:28
llvm::ConstantInt::get
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:925
LoopUtils.h
llvm::getUnderlyingObject
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments and pointer casts from the specified value,...
Definition: ValueTracking.cpp:4389
PatternMatch.h
llvm::TargetTransformInfo::RGK_FixedWidthVector
@ RGK_FixedWidthVector
Definition: TargetTransformInfo.h:907
llvm::FixedVectorType::get
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:686
llvm::StoreInst::getAlign
Align getAlign() const
Definition: Instructions.h:353
llvm::Align
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
llvm::FastMathFlags::setAllowContract
void setAllowContract(bool B=true)
Definition: Operator.h:232
llvm::ValueMap::count
size_type count(const KeyT &Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: ValueMap.h:152
llvm::Value::use_empty
bool use_empty() const
Definition: Value.h:344
MatrixLayoutTy
MatrixLayoutTy
Definition: LowerMatrixIntrinsics.cpp:73
INITIALIZE_PASS_END
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:58
CFG.h
llvm::function_ref
An efficient, type-erasing, non-owning reference to a callable.
Definition: STLExtras.h:168
llvm::VectorType
Base class of all SIMD vector types.
Definition: DerivedTypes.h:389
VectorUtils.h
llvm::cl::opt< bool >
llvm::StoreInst::isVolatile
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:340
llvm::StoreInst
An instruction for storing to memory.
Definition: Instructions.h:304
llvm::cl::values
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
Definition: CommandLine.h:697
llvm::Instruction::eraseFromParent
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:78
llvm::LowerMatrixIntrinsicsPass::printPipeline
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
Definition: LowerMatrixIntrinsics.cpp:2266
llvm::MapVector::find
iterator find(const KeyT &Key)
Definition: MapVector.h:147
llvm::getPointerOperand
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
Definition: Instructions.h:5313
uint64_t
llvm::TargetTransformInfoWrapperPass
Wrapper pass for TargetTransformInfo.
Definition: TargetTransformInfo.h:2387
D
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
intrinsics
expand Expand reduction intrinsics
Definition: ExpandReductions.cpp:200
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
INITIALIZE_PASS_DEPENDENCY
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
llvm::PreservedAnalyses::preserve
void preserve()
Mark an analysis as preserved.
Definition: PassManager.h:176
llvm::PHINode::addIncoming
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Definition: Instructions.h:2783
llvm::addStringMetadataToLoop
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
Definition: LoopUtils.cpp:222
llvm::DenseMap
Definition: DenseMap.h:714
llvm::PatternMatch::m_Store
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
Definition: PatternMatch.h:1580
llvm::codeview::FrameCookieKind::Copy
@ Copy
llvm::ValueMap::erase
bool erase(const KeyT &Val)
Definition: ValueMap.h:191
llvm::LoopInfoBase::getLoopFor
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Definition: LoopInfo.h:967
I
#define I(x, y, z)
Definition: MD5.cpp:59
llvm::cl::init
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:441
llvm::make_early_inc_range
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:576
llvm::concatenateVectors
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
Definition: VectorUtils.cpp:875
IRBuilder.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI
StandardInstrumentations SI(Debug, VerifyEach)
FuseMatrix
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
llvm::OptimizationRemarkEmitter::emit
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Definition: OptimizationRemarkEmitter.cpp:77
llvm::ValueMap::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: ValueMap.h:173
llvm::operator==
bool operator==(uint64_t V1, const APInt &V2)
Definition: APInt.h:1981
llvm::TTI
TargetTransformInfo TTI
Definition: TargetTransformInfo.h:163
llvm::Type::isVoidTy
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:138
MatrixLayoutTy::ColumnMajor
@ ColumnMajor
LowerMatrixIntrinsics.h
llvm::CallBase::arg_end
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1309
llvm::SmallPtrSetImpl::count
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:382
llvm::SmallSet::erase
bool erase(const T &V)
Definition: SmallSet.h:207
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:650
llvm::PatternMatch::m_Value
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
llvm::SetVector::insert
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:141
llvm::size
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1532
llvm::Function::getIntrinsicID
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition: Function.h:207
llvm::ArrayRef
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: APInt.h:32
llvm::LoopInfo
Definition: LoopInfo.h:1083
llvm::Instruction::getFastMathFlags
FastMathFlags getFastMathFlags() const
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
Definition: Instruction.cpp:280
llvm::BinaryOperator
Definition: InstrTypes.h:189
llvm::OptimizationRemarkEmitter
The optimization diagnostic interface.
Definition: OptimizationRemarkEmitter.h:33
Matrix
Live Register Matrix
Definition: LiveRegMatrix.cpp:44
llvm::min
Expected< ExpressionValue > min(const ExpressionValue &Lhs, const ExpressionValue &Rhs)
Definition: FileCheck.cpp:357
llvm::MapVector::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: MapVector.h:117
llvm::any_of
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1558
llvm::CallBase::getParamAlign
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:1714
DataLayout.h
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:253
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:58
llvm_unreachable
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Definition: ErrorHandling.h:134
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::AnalysisUsage::addPreserved
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
Definition: PassAnalysisSupport.h:98
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:532
getParent
static const Function * getParent(const Value *V)
Definition: BasicAliasAnalysis.cpp:890
llvm::ms_demangle::IntrinsicFunctionKind::New
@ New
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition: AArch64SLSHardening.cpp:76
clEnumValN
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition: CommandLine.h:672
llvm::MatrixBuilder
Definition: MatrixBuilder.h:33
S
add sub stmia L5 ldr r0 bl L_printf $stub Instead of a and a wouldn t it be better to do three moves *Return an aggregate type is even return S
Definition: README.txt:210
llvm::RecurKind::Mul
@ Mul
Product of integers.
pass_name_minimal
static const char pass_name_minimal[]
Definition: LowerMatrixIntrinsics.cpp:2354
llvm::TargetTransformInfo::getNumberOfRegisters
unsigned getNumberOfRegisters(unsigned ClassID) const
Definition: TargetTransformInfo.cpp:581
llvm::AMDGPU::HSAMD::Kernel::Arg::Key::IsVolatile
constexpr char IsVolatile[]
Key for Kernel::Arg::Metadata::mIsVolatile.
Definition: AMDGPUMetadata.h:194
llvm::SmallSet::insert
std::pair< NoneType, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:180
llvm::ifs::IFSSymbolType::Func
@ Func
llvm::Value::getName
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
llvm::ValueMap
See the file comment.
Definition: ValueMap.h:85
llvm::LoadInst
An instruction for reading from memory.
Definition: Instructions.h:175
llvm::BasicBlock::getTerminator
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:152
llvm::SPII::Load
@ Load
Definition: SparcInstrInfo.h:32
users
iv users
Definition: IVUsers.cpp:52
llvm::MapVector::end
iterator end()
Definition: MapVector.h:71
llvm::ConstantInt::getZExtValue
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:142
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:69
llvm::Twine
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:83
Alignment.h
llvm::commonAlignment
Align commonAlignment(Align A, Align B)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:211
llvm::OptimizationRemarkEmitterWrapperPass
OptimizationRemarkEmitter legacy analysis pass.
Definition: OptimizationRemarkEmitter.h:146
MatrixLayout
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
llvm::GraphProgram::Name
Name
Definition: GraphWriter.h:52
AllowContractEnabled
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
llvm::StringRef::drop_front
LLVM_NODISCARD StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition: StringRef.h:653
llvm::DIScope
Base class for scope-like contexts.
Definition: DebugInfoMetadata.h:476
llvm::AMDGPU::SendMsg::Op
Op
Definition: SIDefines.h:324
llvm::PreservedAnalyses::all
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:161
MatrixLayoutTy::RowMajor
@ RowMajor
llvm::TypeSize
Definition: TypeSize.h:417
Function.h
llvm::Value::hasNUses
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:149
llvm::sort
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1492
llvm::SetVector< T, SmallVector< T, N >, SmallDenseSet< T, N > >::count
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:215
llvm::ValueMap::find
iterator find(const KeyT &Val)
Definition: ValueMap.h:156
llvm::MCID::Add
@ Add
Definition: MCInstrDesc.h:183
llvm::ReversePostOrderTraversal
Definition: PostOrderIterator.h:290
llvm::IntrinsicInst
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:45
llvm::DominatorTreeAnalysis
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:252
llvm::BasicBlock::reverse_iterator
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:92
llvm::OptimizationRemark
Diagnostic information for applied optimization remarks.
Definition: DiagnosticInfo.h:685
llvm::Pass
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:91
llvm::X86AS::SS
@ SS
Definition: X86.h:189
Instructions.h
PostOrderIterator.h
llvm::LoadInst::isVolatile
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition: Instructions.h:212
llvm::initializeLowerMatrixIntrinsicsLegacyPassPass
void initializeLowerMatrixIntrinsicsLegacyPassPass(PassRegistry &)
SmallVector.h
llvm::CallBase::getArgOperand
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1328
llvm::SmallPtrSetImplBase::empty
LLVM_NODISCARD bool empty() const
Definition: SmallPtrSet.h:91
llvm::AAResultsWrapperPass
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
Definition: AliasAnalysis.h:1336
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:94
llvm::to_string
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:63
TargetTransformInfo.h
llvm::iterator_range
A range adaptor for a pair of iterators.
Definition: iterator_range.h:30
llvm::PHINode
Definition: Instructions.h:2633
DEBUG_TYPE
#define DEBUG_TYPE
Definition: LowerMatrixIntrinsics.cpp:52
minimal
lower matrix intrinsics minimal
Definition: LowerMatrixIntrinsics.cpp:2360
llvm::DISubprogram
Subprogram description.
Definition: DebugInfoMetadata.h:1820
llvm::SmallSet::empty
LLVM_NODISCARD bool empty() const
Definition: SmallSet.h:155
llvm::SmallVectorImpl< Instruction * >
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, false, false) INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass
llvm::SmallPtrSetImpl< Instruction * >
llvm::SmallSetVector
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:307
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition: InstructionSimplify.h:44
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:298
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1475
llvm::AnalysisUsage::addRequired
AnalysisUsage & addRequired()
Definition: PassAnalysisSupport.h:75
TileUseLoops
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
llvm::DebugLoc
A debug info location.
Definition: DebugLoc.h:33
llvm::AllocaInst
An instruction to allocate memory on the stack.
Definition: Instructions.h:62
llvm::FastMathFlags::allowContract
bool allowContract() const
Definition: Operator.h:211
llvm::User::getOperand
Value * getOperand(unsigned i) const
Definition: User.h:169
llvm::cl::desc
Definition: CommandLine.h:412
llvm::createLowerMatrixIntrinsicsMinimalPass
Pass * createLowerMatrixIntrinsicsMinimalPass()
Definition: LowerMatrixIntrinsics.cpp:2363
llvm::createSequentialMask
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
Definition: VectorUtils.cpp:820
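A hedged sketch of createSequentialMask (declared in llvm/Analysis/VectorUtils.h; the chosen values are illustrative only):
  // Produces {2, 3, 4, 5, -1}: four sequential lanes starting at 2, then one undef lane.
  llvm::SmallVector<int, 16> Mask =
      llvm::createSequentialMask(/*Start=*/2, /*NumInts=*/4, /*NumUndefs=*/1);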
llvm::StringRef::size
LLVM_NODISCARD size_t size() const
size - Get the string size.
Definition: StringRef.h:157
llvm::SetVector< Value * >
llvm::ConstantAggregateZero::get
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1675
BasicBlockUtils.h
llvm::SplitBlock
BasicBlock * SplitBlock(BasicBlock *Old, Instruction *SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Definition: BasicBlockUtils.cpp:814
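A hedged usage sketch of SplitBlock (BB, SplitPt, and DT are assumed to come from the surrounding pass):
  #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  // Everything from SplitPt onwards is moved into the returned successor block.
  llvm::BasicBlock *Tail = llvm::SplitBlock(BB, SplitPt, DT);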
llvm::pdb::PDB_SymType::Block
@ Block
InitializePasses.h
llvm::OptimizationRemarkEmitterAnalysis
Definition: OptimizationRemarkEmitter.h:164
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
Debug.h
llvm::Value::users
iterator_range< user_iterator > users()
Definition: Value.h:421
llvm::MemoryLocation
Representation for a specific memory location.
Definition: MemoryLocation.h:209
llvm::LoopAnalysis
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:1243
llvm::DominatorTreeBase::Delete
static constexpr UpdateKind Delete
Definition: GenericDomTree.h:243
llvm::Use
A Use represents the edge between a Value definition and its users.
Definition: Use.h:44
llvm::Type::getPrimitiveSizeInBits
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition: Type.cpp:166
llvm::SmallPtrSetImpl::insert
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:364
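A minimal sketch of SmallPtrSetImpl::insert (I is an assumed llvm::Instruction* from surrounding code):
  #include "llvm/ADT/SmallPtrSet.h"
  llvm::SmallPtrSet<llvm::Instruction *, 8> Visited;
  if (Visited.insert(I).second) {
    // first visit of I; insert returned {iterator, true}
  }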
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:37