1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 // * Improve fusion:
13 // * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14 // transposed.
15 // * Improve cost-modeling, e.g. choose different numbers of rows/columns
16 //   for tiles, consider cost of copies on alias.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Analysis/AliasAnalysis.h"
24 #include "llvm/Analysis/DomTreeUpdater.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
27 #include "llvm/Analysis/TargetTransformInfo.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Analysis/VectorUtils.h"
30 #include "llvm/IR/CFG.h"
31 #include "llvm/IR/DataLayout.h"
32 #include "llvm/IR/DebugInfoMetadata.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/Instructions.h"
36 #include "llvm/IR/IntrinsicInst.h"
37 #include "llvm/IR/MatrixBuilder.h"
38 #include "llvm/IR/PatternMatch.h"
39 #include "llvm/InitializePasses.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Alignment.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Transforms/Scalar.h"
45 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
46 #include "llvm/Transforms/Utils/LoopUtils.h"
47 #include "llvm/Transforms/Utils/MatrixUtils.h"
48 
49 using namespace llvm;
50 using namespace PatternMatch;
51 
52 #define DEBUG_TYPE "lower-matrix-intrinsics"
53 
54 static cl::opt<bool>
55  FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
56  cl::desc("Enable/disable fusing matrix instructions."));
57 // TODO: Allow and use non-square tiles.
59  "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
60  cl::desc(
61  "Tile size for matrix instruction fusion using square-shaped tiles."));
62 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
63  cl::Hidden,
64  cl::desc("Generate loop nest for tiling."));
66  "force-fuse-matrix", cl::init(false), cl::Hidden,
67  cl::desc("Force matrix instruction fusion even if not profitable."));
69  "matrix-allow-contract", cl::init(false), cl::Hidden,
70  cl::desc("Allow the use of FMAs if available and profitable. This may "
71  "result in different results, due to less rounding error."));
72 
73 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
74 
76  "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
77  cl::desc("Sets the default matrix layout"),
79  "Use column-major layout"),
81  "Use row-major layout")));
82 
83 /// Helper function to either return \p Scope, if it is a subprogram, or the
84 /// subprogram attached to a local scope.
85 static DISubprogram *getSubprogram(DIScope *Scope) {
86  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
87  return Subprogram;
88  return cast<DILocalScope>(Scope)->getSubprogram();
89 }
90 
91 namespace {
92 
93 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
94 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
95 // assuming \p Stride elements between the starts of two consecutive vectors.
96 // \p Stride must be >= \p NumElements.
97 // For column-major matrixes, the function computes the address of a column
98 // vectors and \p NumElements must be set to the number of elements in a column
99 // (= number of rows of the matrix). For row-major matrixes, the function
100 // computes the address of a row vector and \p NumElements must be set to the
101 // number of elements in a row (= number of columns of the matrix).
102 //
103 // Consider a 4x4 matrix in column-major layout like below:
104 //
105 //         0      1      2      3
106 // 0   v_0_0  v_0_1  v_0_2  v_0_3
107 // 1   v_1_0  v_1_1  v_1_2  v_1_3
108 // 2   v_2_0  v_2_1  v_2_2  v_2_3
109 // 3   v_3_0  v_3_1  v_3_2  v_3_3
110 
111 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
112 // we need a pointer to the first element of the submatrix as base pointer.
113 // Then we can use computeVectorAddr to compute the addresses for the columns
114 // of the sub-matrix.
115 //
116 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
117 // -> just returns Base
118 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
119 // -> returns Base + (1 * 4)
120 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
121 // -> returns Base + (2 * 4)
122 //
123 // The graphic below illustrates the number of elements in a column (marked
124 // with |) and the number of skipped elements (marked with {).
125 //
126 // v_0_0  v_0_1 {v_0_2 {v_0_3
127 //        Base   Col 1  Col 2
128 //          |      |      |
129 // v_1_0 |v_1_1 |v_1_2 |v_1_3
130 // v_2_0 |v_2_1 |v_2_2 |v_2_3
131 // v_3_0 {v_3_1 {v_3_2  v_3_3
132 //
133 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
134  unsigned NumElements, Type *EltType,
135  IRBuilder<> &Builder) {
136 
137  assert((!isa<ConstantInt>(Stride) ||
138  cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
139  "Stride must be >= the number of elements in the result vector.");
140  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
141 
142  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
143  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
144 
145  // Get pointer to the start of the selected vector. Skip GEP creation,
146  // if we select vector 0.
147  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
148  VecStart = BasePtr;
149  else
150  VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
151 
152  // Cast elementwise vector start pointer to a pointer to a vector
153  // (EltType x NumElements)*.
154  auto *VecType = FixedVectorType::get(EltType, NumElements);
155  Type *VecPtrType = PointerType::get(VecType, AS);
156  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
157 }
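// Illustrative sketch (not verbatim pass output): for a non-constant %vecidx
// with Stride 4 and NumElements 2 over double elements, a call to
// computeVectorAddr emits roughly the following IR:
//   %vec.start = mul i64 %vecidx, 4
//   %vec.gep   = getelementptr double, double* %base, i64 %vec.start
//   %vec.cast  = bitcast double* %vec.gep to <2 x double>*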
158 
159 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
160 ///
161 /// Currently, the lowering for each matrix intrinsic is done as follows:
162 /// 1. Propagate the shape information from intrinsics to connected
163 /// instructions.
164 /// 2. Lower instructions with shape information (assuming column-major layout).
165 /// The lowering works similarly using row-major layout.
166 /// 2.1. Get column vectors for each argument. If we already lowered the
167 /// definition of an argument, use the produced column vectors directly.
168 /// If not, split the operand vector containing an embedded matrix into
169 /// a set of column vectors.
170 /// 2.2. Lower the instruction in terms of column major operations, which
171 /// yields a set of column vectors containing the result matrix. Note that we
172 /// lower all instructions that have shape information. Besides the
173 /// intrinsics, this includes stores for example.
174 /// 2.3. Update uses of the lowered instruction. If we have shape information
175 /// for a user, there is nothing to do, as we will look up the result
176 /// column matrix when lowering the user. For other uses, we embed the
177 /// result matrix in a flat vector and update the use.
178 /// 2.4. Cache the result column matrix for the instruction we lowered.
179 /// 3. After we lowered all instructions in a function, remove the now
180 /// obsolete instructions.
181 ///
182 class LowerMatrixIntrinsics {
183  Function &Func;
184  const DataLayout &DL;
185  const TargetTransformInfo &TTI;
186  AliasAnalysis *AA;
187  DominatorTree *DT;
188  LoopInfo *LI;
189  OptimizationRemarkEmitter *ORE;
190 
191  /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
192  struct OpInfoTy {
193  /// Number of stores emitted to generate this matrix.
194  unsigned NumStores = 0;
195  /// Number of loads emitted to generate this matrix.
196  unsigned NumLoads = 0;
197  /// Number of compute operations emitted to generate this matrix.
198  unsigned NumComputeOps = 0;
199  /// Most of the time transposes can be fused with matrix multiplies or can
200  /// be folded away via algebraic simplifications. This is the number of
201  /// transposes that we failed to make "free" via such optimizations.
202  unsigned NumExposedTransposes = 0;
203 
204  OpInfoTy &operator+=(const OpInfoTy &RHS) {
205  NumStores += RHS.NumStores;
206  NumLoads += RHS.NumLoads;
207  NumComputeOps += RHS.NumComputeOps;
208  NumExposedTransposes += RHS.NumExposedTransposes;
209  return *this;
210  }
211  };
212 
213  /// Wrapper class representing a matrix as a set of vectors, either in row or
214  /// column major layout. All vectors must have the same vector type.
215  class MatrixTy {
216  SmallVector<Value *, 16> Vectors;
217 
218  OpInfoTy OpInfo;
219 
220  bool IsColumnMajor = true;
221 
222  public:
223  MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
224  MatrixTy(ArrayRef<Value *> Vectors)
225  : Vectors(Vectors.begin(), Vectors.end()),
226  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
227  MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
228  : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
229 
230  unsigned D = isColumnMajor() ? NumColumns : NumRows;
231  for (unsigned J = 0; J < D; ++J)
232  addVector(UndefValue::get(FixedVectorType::get(
233  EltTy, isColumnMajor() ? NumRows : NumColumns)));
234  }
235 
236  Value *getVector(unsigned i) const { return Vectors[i]; }
237  Value *getColumn(unsigned i) const {
238  assert(isColumnMajor() && "only supported for column-major matrixes");
239  return Vectors[i];
240  }
241  Value *getRow(unsigned i) const {
242  assert(!isColumnMajor() && "only supported for row-major matrixes");
243  return Vectors[i];
244  }
245 
246  void setVector(unsigned i, Value *V) { Vectors[i] = V; }
247 
248  Type *getElementType() const { return getVectorTy()->getElementType(); }
249 
250  unsigned getNumVectors() const {
251  if (isColumnMajor())
252  return getNumColumns();
253  return getNumRows();
254  }
255 
256  unsigned getNumColumns() const {
257  if (isColumnMajor())
258  return Vectors.size();
259  else {
260  assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
261  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
262  }
263  }
264  unsigned getNumRows() const {
265  if (isColumnMajor()) {
266  assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
267  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
268  } else
269  return Vectors.size();
270  }
271 
272  void addVector(Value *V) { Vectors.push_back(V); }
273  VectorType *getColumnTy() {
274  assert(isColumnMajor() && "only supported for column-major matrixes");
275  return getVectorTy();
276  }
277 
278  VectorType *getVectorTy() const {
279  return cast<VectorType>(Vectors[0]->getType());
280  }
281 
282  iterator_range<SmallVector<Value *, 8>::iterator> columns() {
283  assert(isColumnMajor() &&
284  "columns() only supported for column-major matrixes");
285  return make_range(Vectors.begin(), Vectors.end());
286  }
287 
288  iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
289  return make_range(Vectors.begin(), Vectors.end());
290  }
291 
292  /// Embed the vectors of the matrix into a flat vector by concatenating
293  /// them.
294  Value *embedInVector(IRBuilder<> &Builder) const {
295  return Vectors.size() == 1 ? Vectors[0]
296  : concatenateVectors(Builder, Vectors);
297  }
298 
299  MatrixTy &addNumLoads(unsigned N) {
300  OpInfo.NumLoads += N;
301  return *this;
302  }
303 
304  void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
305 
306  MatrixTy &addNumStores(unsigned N) {
307  OpInfo.NumStores += N;
308  return *this;
309  }
310 
311  MatrixTy &addNumExposedTransposes(unsigned N) {
312  OpInfo.NumExposedTransposes += N;
313  return *this;
314  }
315 
316  MatrixTy &addNumComputeOps(unsigned N) {
317  OpInfo.NumComputeOps += N;
318  return *this;
319  }
320 
321  unsigned getNumStores() const { return OpInfo.NumStores; }
322  unsigned getNumLoads() const { return OpInfo.NumLoads; }
323  unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
324 
325  const OpInfoTy &getOpInfo() const { return OpInfo; }
326 
327  bool isColumnMajor() const { return IsColumnMajor; }
328 
329  unsigned getStride() const {
330  if (isColumnMajor())
331  return getNumRows();
332  return getNumColumns();
333  }
334 
335  /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
336  /// matrix is column-major, the result vector is extracted from a column
337  /// vector, otherwise from a row vector.
338  Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
339  IRBuilder<> &Builder) const {
340  Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
341  return Builder.CreateShuffleVector(
342  Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
343  "block");
344  }
345  };
346 
347  struct ShapeInfo {
348  unsigned NumRows;
349  unsigned NumColumns;
350 
351  bool IsColumnMajor;
352 
353  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
354  : NumRows(NumRows), NumColumns(NumColumns),
355  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
356 
357  ShapeInfo(Value *NumRows, Value *NumColumns)
358  : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
359  cast<ConstantInt>(NumColumns)->getZExtValue()) {}
360 
361  bool operator==(const ShapeInfo &other) {
362  return NumRows == other.NumRows && NumColumns == other.NumColumns;
363  }
364  bool operator!=(const ShapeInfo &other) { return !(*this == other); }
365 
366  /// Returns true if shape-information is defined, meaning both dimensions
367  /// are != 0.
368  operator bool() const {
369  assert(NumRows == 0 || NumColumns != 0);
370  return NumRows != 0;
371  }
372 
373  unsigned getStride() const {
374  if (IsColumnMajor)
375  return NumRows;
376  return NumColumns;
377  }
378 
379  unsigned getNumVectors() const {
380  if (IsColumnMajor)
381  return NumColumns;
382  return NumRows;
383  }
384  };
385 
386  /// Maps instructions to their shape information. The shape information
387  /// describes the shape to be used while lowering. This matches the shape of
388  /// the result value of the instruction, with the only exceptions being store
389  /// instructions and the matrix_column_major_store intrinsics. For those, the
390  /// shape information indicates that those instructions should be lowered
391  /// using shape information as well. A ValueMap is used so that when
392  /// sub-passes like optimizeTransposes perform RAUW the map stays
393  /// up-to-date.
394  ValueMap<Value *, ShapeInfo> ShapeMap;
395 
396  /// List of instructions to remove. While lowering, we do not replace all
397  /// users of a lowered instruction if shape information is available; those
398  /// users need to be removed after we finished lowering.
399  SmallVector<Instruction *, 16> ToRemove;
400 
401  /// Map from instructions to their produced column matrix.
402  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
403 
404 private:
405  static FastMathFlags getFastMathFlags(Instruction *Inst) {
406  FastMathFlags FMF;
407 
408  if (isa<FPMathOperator>(*Inst))
409  FMF = Inst->getFastMathFlags();
410 
411  FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
412 
413  return FMF;
414  }
415 
416 public:
417  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
418  AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
419  OptimizationRemarkEmitter *ORE)
420  : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
421  LI(LI), ORE(ORE) {}
422 
423  unsigned getNumOps(Type *VT) {
424  assert(isa<VectorType>(VT) && "Expected vector type");
425  return getNumOps(VT->getScalarType(),
426  cast<FixedVectorType>(VT)->getNumElements());
427  }
428 
429  /// Is this the minimal version executed in the backend pipelines.
430  bool isMinimal() const {
431  return !DT;
432  }
433 
434  /// Return the estimated number of vector ops required for an operation on
435  /// \p VT * N.
436  unsigned getNumOps(Type *ST, unsigned N) {
437  return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
438  double(TTI.getRegisterBitWidth(
439  TargetTransformInfo::RGK_FixedWidthVector)
440  .getFixedSize()));
441  }
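 // Worked example (illustrative, assuming a target with 128-bit fixed-width
 // vector registers): for a <16 x float> value, getNumOps(float, 16) covers
 // 16 * 32 = 512 bits and returns ceil(512 / 128) = 4 vector ops.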
442 
443  /// Return the set of vectors that a matrix value is lowered to.
444  ///
445  /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
446  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
447  /// into vectors.
448  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
449  IRBuilder<> &Builder) {
450  VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
451  assert(VType && "MatrixVal must be a vector type");
452  assert(cast<FixedVectorType>(VType)->getNumElements() ==
453  SI.NumRows * SI.NumColumns &&
454  "The vector size must match the number of matrix elements");
455 
456  // Check if we lowered MatrixVal using shape information. In that case,
457  // return the existing matrix, if it matches the requested shape
458  // information. If there is a mis-match, embed the result in a flat
459  // vector and split it later.
460  auto Found = Inst2ColumnMatrix.find(MatrixVal);
461  if (Found != Inst2ColumnMatrix.end()) {
462  MatrixTy &M = Found->second;
463  // Return the found matrix, if its shape matches the requested shape
464  // information
465  if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
466  return M;
467 
468  MatrixVal = M.embedInVector(Builder);
469  }
470 
471  // Otherwise split MatrixVal.
472  SmallVector<Value *, 16> SplitVecs;
473  for (unsigned MaskStart = 0;
474  MaskStart < cast<FixedVectorType>(VType)->getNumElements();
475  MaskStart += SI.getStride()) {
476  Value *V = Builder.CreateShuffleVector(
477  MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
478  "split");
479  SplitVecs.push_back(V);
480  }
481 
482  return {SplitVecs};
483  }
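 // Illustrative sketch (hypothetical value names, not verbatim pass output):
 // splitting a flat <6 x double> %m holding a column-major 3 x 2 matrix
 // yields two shuffles with stride-sized sequential masks:
 //   %split  = shufflevector <6 x double> %m, <6 x double> poison,
 //                           <3 x i32> <i32 0, i32 1, i32 2>
 //   %split1 = shufflevector <6 x double> %m, <6 x double> poison,
 //                           <3 x i32> <i32 3, i32 4, i32 5>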
484 
485  /// If \p V already has a known shape return false. Otherwise set the shape
486  /// for instructions that support it.
487  bool setShapeInfo(Value *V, ShapeInfo Shape) {
488  assert(Shape && "Shape not set");
489  if (isa<UndefValue>(V) || !supportsShapeInfo(V))
490  return false;
491 
492  auto SIter = ShapeMap.find(V);
493  if (SIter != ShapeMap.end()) {
494  LLVM_DEBUG(dbgs() << " not overriding existing shape: "
495  << SIter->second.NumRows << " "
496  << SIter->second.NumColumns << " for " << *V << "\n");
497  return false;
498  }
499 
500  ShapeMap.insert({V, Shape});
501  LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
502  << " for " << *V << "\n");
503  return true;
504  }
505 
506  bool isUniformShape(Value *V) {
507  Instruction *I = dyn_cast<Instruction>(V);
508  if (!I)
509  return true;
510 
511  switch (I->getOpcode()) {
512  case Instruction::FAdd:
513  case Instruction::FSub:
514  case Instruction::FMul: // Scalar multiply.
515  case Instruction::FNeg:
516  case Instruction::Add:
517  case Instruction::Mul:
518  case Instruction::Sub:
519  return true;
520  default:
521  return false;
522  }
523  }
524 
525  /// Returns true if shape information can be used for \p V. The supported
526  /// instructions must match the instructions that can be lowered by this pass.
527  bool supportsShapeInfo(Value *V) {
528  Instruction *Inst = dyn_cast<Instruction>(V);
529  if (!Inst)
530  return false;
531 
532  IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
533  if (II)
534  switch (II->getIntrinsicID()) {
535  case Intrinsic::matrix_multiply:
536  case Intrinsic::matrix_transpose:
537  case Intrinsic::matrix_column_major_load:
538  case Intrinsic::matrix_column_major_store:
539  return true;
540  default:
541  return false;
542  }
543  return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
544  }
545 
546  /// Propagate the shape information of instructions to their users.
547  /// The work list contains instructions for which we can compute the shape,
548  /// either based on the information provided by matrix intrinsics or known
549  /// shapes of operands.
550  SmallVector<Instruction *, 32>
551  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
552  SmallVector<Instruction *, 32> NewWorkList;
553  // Pop an element for which we are guaranteed to have at least one of the
554  // operand shapes. Add the shape for this and then add users to the work
555  // list.
556  LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
557  while (!WorkList.empty()) {
558  Instruction *Inst = WorkList.pop_back_val();
559 
560  // New entry, set the value and insert operands
561  bool Propagate = false;
562 
563  Value *MatrixA;
564  Value *MatrixB;
565  Value *M;
566  Value *N;
567  Value *K;
568  if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
569  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
570  m_Value(N), m_Value(K)))) {
571  Propagate = setShapeInfo(Inst, {M, K});
572  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
573  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
574  // Flip dimensions.
575  Propagate = setShapeInfo(Inst, {N, M});
576  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
577  m_Value(MatrixA), m_Value(), m_Value(),
578  m_Value(), m_Value(M), m_Value(N)))) {
579  Propagate = setShapeInfo(Inst, {N, M});
580  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
581  m_Value(), m_Value(), m_Value(), m_Value(M),
582  m_Value(N)))) {
583  Propagate = setShapeInfo(Inst, {M, N});
584  } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
585  auto OpShape = ShapeMap.find(MatrixA);
586  if (OpShape != ShapeMap.end())
587  setShapeInfo(Inst, OpShape->second);
588  continue;
589  } else if (isUniformShape(Inst)) {
590  // Find the first operand that has a known shape and use that.
591  for (auto &Op : Inst->operands()) {
592  auto OpShape = ShapeMap.find(Op.get());
593  if (OpShape != ShapeMap.end()) {
594  Propagate |= setShapeInfo(Inst, OpShape->second);
595  break;
596  }
597  }
598  }
599 
600  if (Propagate) {
601  NewWorkList.push_back(Inst);
602  for (auto *User : Inst->users())
603  if (ShapeMap.count(User) == 0)
604  WorkList.push_back(cast<Instruction>(User));
605  }
606  }
607 
608  return NewWorkList;
609  }
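 // Illustrative example (a sketch with assumed shapes): given
 //   %c = call <6 x double> @llvm.matrix.multiply(..., i32 2, i32 4, i32 3)
 //   %d = fadd <6 x double> %c, %e
 // the multiply seeds shape 2 x 3 for %c, and forward propagation then
 // records the same 2 x 3 shape for the fadd %d, since it has uniform shape.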
610 
611  /// Propagate the shape to operands of instructions with shape information.
612  /// \p Worklist contains the instruction for which we already know the shape.
613  SmallVector<Instruction *, 32>
614  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
615  SmallVector<Instruction *, 32> NewWorkList;
616 
617  auto pushInstruction = [](Value *V,
618  SmallVectorImpl<Instruction *> &WorkList) {
619  Instruction *I = dyn_cast<Instruction>(V);
620  if (I)
621  WorkList.push_back(I);
622  };
623  // Pop an element with known shape. Traverse the operands; if their shape
624  // derives from the result shape and is still unknown, set it and add them
625  // to the worklist.
626  LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
627  while (!WorkList.empty()) {
628  Value *V = WorkList.pop_back_val();
629 
630  size_t BeforeProcessingV = WorkList.size();
631  if (!isa<Instruction>(V))
632  continue;
633 
634  Value *MatrixA;
635  Value *MatrixB;
636  Value *M;
637  Value *N;
638  Value *K;
639  if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
640  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
641  m_Value(N), m_Value(K)))) {
642  if (setShapeInfo(MatrixA, {M, N}))
643  pushInstruction(MatrixA, WorkList);
644 
645  if (setShapeInfo(MatrixB, {N, K}))
646  pushInstruction(MatrixB, WorkList);
647 
648  } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
649  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
650  // Flip dimensions.
651  if (setShapeInfo(MatrixA, {M, N}))
652  pushInstruction(MatrixA, WorkList);
653  } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
654  m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
655  m_Value(M), m_Value(N)))) {
656  if (setShapeInfo(MatrixA, {M, N})) {
657  pushInstruction(MatrixA, WorkList);
658  }
659  } else if (isa<LoadInst>(V) ||
660  match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
661  // Nothing to do, no matrix input.
662  } else if (isa<StoreInst>(V)) {
663  // Nothing to do. We forward-propagated to this so we would just
664  // backward propagate to an instruction with an already known shape.
665  } else if (isUniformShape(V)) {
666  // Propagate to all operands.
667  ShapeInfo Shape = ShapeMap[V];
668  for (Use &U : cast<Instruction>(V)->operands()) {
669  if (setShapeInfo(U.get(), Shape))
670  pushInstruction(U.get(), WorkList);
671  }
672  }
673  // After we discovered new shape info for new instructions in the
674  // worklist, we use their users as seeds for the next round of forward
675  // propagation.
676  for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
677  for (User *U : WorkList[I]->users())
678  if (isa<Instruction>(U) && V != U)
679  NewWorkList.push_back(cast<Instruction>(U));
680  }
681  return NewWorkList;
682  }
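 // Illustrative example (a sketch): for a multiply
 //   %c = call ... @llvm.matrix.multiply(%a, %b, i32 2, i32 4, i32 3)
 // whose own shape is already known, backward propagation assigns shape
 // 2 x 4 to %a and 4 x 3 to %b, then re-seeds forward propagation from any
 // users of %a and %b that gained a shape this way.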
683 
684  /// Try moving transposes in order to fold them away or into multiplies.
685  void optimizeTransposes() {
686  auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
687  // We need to remove Old from the ShapeMap otherwise RAUW will replace it
688  // with New. We should only add New if it supportsShapeInfo, so we insert
689  // it conditionally instead.
690  auto S = ShapeMap.find(&Old);
691  if (S != ShapeMap.end()) {
692  ShapeMap.erase(S);
693  if (supportsShapeInfo(New))
694  ShapeMap.insert({New, S->second});
695  }
696  Old.replaceAllUsesWith(New);
697  };
698 
699  // First sink all transposes inside matmuls, hoping that we end up with NN,
700  // NT or TN variants.
701  for (BasicBlock &BB : reverse(Func)) {
702  for (auto II = BB.rbegin(); II != BB.rend();) {
703  Instruction &I = *II;
704  // We may remove II. By default continue on the next/prev instruction.
705  ++II;
706  // If we were to erase II, move again.
707  auto EraseFromParent = [&II](Value *V) {
708  auto *Inst = cast<Instruction>(V);
709  if (Inst->use_empty()) {
710  if (Inst == &*II) {
711  ++II;
712  }
713  Inst->eraseFromParent();
714  }
715  };
716 
717  // If we're creating a new instruction, continue from there.
718  Instruction *NewInst = nullptr;
719 
720  IRBuilder<> IB(&I);
721  MatrixBuilder Builder(IB);
722 
723  Value *TA, *TAMA, *TAMB;
724  ConstantInt *R, *K, *C;
725  if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
726 
727  // Transpose of a transpose is a nop
728  Value *TATA;
729  if (match(TA,
730  m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
731  ReplaceAllUsesWith(I, TATA);
732  EraseFromParent(&I);
733  EraseFromParent(TA);
734  }
735 
736  // (A * B)^t -> B^t * A^t
737  // RxK KxC CxK KxR
738  else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
739  m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
740  m_ConstantInt(K), m_ConstantInt(C)))) {
741  Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
742  C->getZExtValue(),
743  TAMB->getName() + "_t");
744  // We are being run after shape prop, add shape for newly created
745  // instructions so that we lower them later.
746  setShapeInfo(T0, {C, K});
747  Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
748  K->getZExtValue(),
749  TAMA->getName() + "_t");
750  setShapeInfo(T1, {K, R});
751  NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
752  K->getZExtValue(),
753  R->getZExtValue(), "mmul");
754  ReplaceAllUsesWith(I, NewInst);
755  EraseFromParent(&I);
756  EraseFromParent(TA);
757  }
758  }
759 
760  // If we replaced I with a new instruction, continue from there.
761  if (NewInst)
762  II = std::next(BasicBlock::reverse_iterator(NewInst));
763  }
764  }
765 
766  // If we have a TT matmul, lift the transpose. We may be able to fold it
767  // into a consuming multiply.
768  for (BasicBlock &BB : Func) {
769  for (BasicBlock::iterator II = BB.begin(); II != BB.end();) {
770  Instruction *I = &*II;
771  // We may remove I.
772  ++II;
773  Value *A, *B, *AT, *BT;
774  ConstantInt *R, *K, *C;
775  // A^t * B^t -> (B * A)^t
776  if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
777  m_Value(A), m_Value(B), m_ConstantInt(R),
778  m_ConstantInt(K), m_ConstantInt(C))) &&
779  match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
780  match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
781  IRBuilder<> IB(&*I);
782  MatrixBuilder Builder(IB);
783  Value *M = Builder.CreateMatrixMultiply(
784  BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
785  setShapeInfo(M, {C, R});
786  Instruction *NewInst = Builder.CreateMatrixTranspose(
787  M, C->getZExtValue(), R->getZExtValue());
788  ReplaceAllUsesWith(*I, NewInst);
789  if (I->use_empty())
790  I->eraseFromParent();
791  if (A->use_empty())
792  cast<Instruction>(A)->eraseFromParent();
793  if (A != B && B->use_empty())
794  cast<Instruction>(B)->eraseFromParent();
795  }
796  }
797  }
798  }
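 // Illustrative before/after sketch of the (A * B)^t rewrite above, with
 // hypothetical value names:
 //   before: %ab  = matrix.multiply(%a, %b, R, K, C)
 //           %abt = matrix.transpose(%ab, R, C)
 //   after:  %bt   = matrix.transpose(%b, K, C)
 //           %at   = matrix.transpose(%a, R, K)
 //           %mmul = matrix.multiply(%bt, %at, C, K, R)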
799 
800  bool Visit() {
801  SmallVector<Instruction *> WorkList;
802 
803  // Initially only the shape of matrix intrinsics is known.
804  // Initialize the work list with ops carrying shape information.
805  for (BasicBlock &BB : Func)
806  for (Instruction &Inst : BB) {
807  IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
808  if (!II)
809  continue;
810 
811  switch (II->getIntrinsicID()) {
812  case Intrinsic::matrix_multiply:
813  case Intrinsic::matrix_transpose:
814  case Intrinsic::matrix_column_major_load:
815  case Intrinsic::matrix_column_major_store:
816  WorkList.push_back(&Inst);
817  break;
818  default:
819  break;
820  }
821  }
822 
823  // Avoid unnecessary work if there are no matrix intrinsics in the function.
824  if (WorkList.empty())
825  return false;
826 
827  // Propagate shapes until nothing changes any longer.
828  while (!WorkList.empty()) {
829  WorkList = propagateShapeForward(WorkList);
830  WorkList = propagateShapeBackward(WorkList);
831  }
832 
833  if (!isMinimal()) {
834  optimizeTransposes();
835  LLVM_DEBUG({
836  dbgs() << "Dump after matrix transpose optimization:\n";
837  Func.dump();
838  });
839  }
840 
841  bool Changed = false;
842  SmallVector<CallInst *, 16> MaybeFusableInsts;
843  SmallVector<Instruction *, 16> MatrixInsts;
844 
845  // First, collect all instructions with shape information and candidates for
846  // fusion (currently only matrix multiplies).
847  ReversePostOrderTraversal<Function *> RPOT(&Func);
848  for (auto *BB : RPOT)
849  for (Instruction &I : *BB) {
850  if (ShapeMap.find(&I) == ShapeMap.end())
851  continue;
852  if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
853  MaybeFusableInsts.push_back(cast<CallInst>(&I));
854  MatrixInsts.push_back(&I);
855  }
856 
857  // Second, try to fuse candidates.
858  SmallPtrSet<Instruction *, 16> FusedInsts;
859  for (CallInst *CI : MaybeFusableInsts)
860  LowerMatrixMultiplyFused(CI, FusedInsts);
861  Changed = !FusedInsts.empty();
862 
863  // Third, lower remaining instructions with shape information.
864  for (Instruction *Inst : MatrixInsts) {
865  if (FusedInsts.count(Inst))
866  continue;
867 
868  IRBuilder<> Builder(Inst);
869 
870  if (CallInst *CInst = dyn_cast<CallInst>(Inst))
871  Changed |= VisitCallInst(CInst);
872 
873  Value *Op1;
874  Value *Op2;
875  if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
876  Changed |= VisitBinaryOperator(BinOp);
877  if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
878  Changed |= VisitUnaryOperator(UnOp);
879  if (match(Inst, m_Load(m_Value(Op1))))
880  Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
881  else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
882  Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
883  }
884 
885  if (ORE) {
886  RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
887  RemarkGen.emitRemarks();
888  }
889 
890  // Delete the instructions backwards, as it has a reduced likelihood of
891  // having to update as many def-use and use-def chains.
892  //
893  // Because we add to ToRemove during fusion we can't guarantee that defs
894  // are before uses. Change uses to undef temporarily as these should get
895  // removed as well.
896  //
897  // For verification, we keep track of where we changed uses to undefs in
898  // UndefedInsts and then check that we in fact remove them.
899  SmallSet<Instruction *, 16> UndefedInsts;
900  for (auto *Inst : reverse(ToRemove)) {
901  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
902  if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
903  UndefedInsts.insert(Undefed);
904  U.set(UndefValue::get(Inst->getType()));
905  }
906  Inst->eraseFromParent();
907  UndefedInsts.erase(Inst);
908  }
909  if (!UndefedInsts.empty()) {
910  // If we didn't remove all undefed instructions, it's a hard error.
911  dbgs() << "Undefed but present instructions:\n";
912  for (auto *I : UndefedInsts)
913  dbgs() << *I << "\n";
914  llvm_unreachable("Undefed but instruction not removed");
915  }
916 
917  return Changed;
918  }
919 
920  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
921  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
922  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
923  Type *EltPtrType = PointerType::get(EltType, AS);
924  return Builder.CreatePointerCast(BasePtr, EltPtrType);
925  }
926 
927  /// Replace intrinsic calls
928  bool VisitCallInst(CallInst *Inst) {
929  if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
930  return false;
931 
932  switch (Inst->getCalledFunction()->getIntrinsicID()) {
933  case Intrinsic::matrix_multiply:
934  LowerMultiply(Inst);
935  break;
936  case Intrinsic::matrix_transpose:
937  LowerTranspose(Inst);
938  break;
939  case Intrinsic::matrix_column_major_load:
940  LowerColumnMajorLoad(Inst);
941  break;
942  case Intrinsic::matrix_column_major_store:
943  LowerColumnMajorStore(Inst);
944  break;
945  default:
946  return false;
947  }
948  return true;
949  }
950 
951  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
952  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
953  /// ConstantInt, reduce the initial alignment based on the byte offset. For
954  /// non-ConstantInt strides, return the common alignment of the initial
955  /// alignment and the element size in bytes.
956  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
957  MaybeAlign A) const {
958  Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
959  if (Idx == 0)
960  return InitialAlign;
961 
962  TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
963  if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
964  uint64_t StrideInBytes =
965  ConstStride->getZExtValue() * ElementSizeInBits / 8;
966  return commonAlignment(InitialAlign, Idx * StrideInBytes);
967  }
968  return commonAlignment(InitialAlign, ElementSizeInBits / 8);
969  }
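 // Worked example (illustrative numbers): with an initial alignment of 16,
 // double elements (8 bytes) and a constant stride of 6, the vector at
 // Idx == 1 starts at byte offset 48, so commonAlignment(16, 48) keeps
 // alignment 16; with a stride of 3 the offset is 24 and the result drops
 // to alignment 8.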
970 
971  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
972  /// vectors.
973  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
974  bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
975  auto *VType = cast<VectorType>(Ty);
976  Type *EltTy = VType->getElementType();
977  Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
978  Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
979  MatrixTy Result;
980  for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
981  Value *GEP = computeVectorAddr(
982  EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
983  Stride, Shape.getStride(), EltTy, Builder);
984  Value *Vector = Builder.CreateAlignedLoad(
985  VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
986  IsVolatile, "col.load");
987 
988  Result.addVector(Vector);
989  }
990  return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
991  Result.getNumVectors());
992  }
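 // Illustrative sketch (hypothetical names, column-major 4 x 2 matrix of
 // double with stride 4): the loop above emits one load per column, e.g.
 //   %col.load  = load <4 x double>, <4 x double>* %vec.cast,  align 8
 //   %col.load1 = load <4 x double>, <4 x double>* %vec.cast1, align 8
 // where each %vec.cast points at the start of a column as computed by
 // computeVectorAddr.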
993 
994  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
995  /// starting at \p MatrixPtr[I][J].
996  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
997  ShapeInfo MatrixShape, Value *I, Value *J,
998  ShapeInfo ResultShape, Type *EltTy,
999  IRBuilder<> &Builder) {
1000 
1001  Value *Offset = Builder.CreateAdd(
1002  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1003 
1004  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1005  Value *EltPtr =
1006  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1007  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1008  auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1009  ResultShape.NumColumns);
1010  Type *TilePtrTy = PointerType::get(TileTy, AS);
1011  Value *TilePtr =
1012  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1013 
1014  return loadMatrix(TileTy, TilePtr, Align,
1015  Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1016  ResultShape, Builder);
1017  }
1018 
1019  /// Lower a load instruction with shape information.
1020  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1021  bool IsVolatile, ShapeInfo Shape) {
1022  IRBuilder<> Builder(Inst);
1023  finalizeLowering(Inst,
1024  loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1025  Shape, Builder),
1026  Builder);
1027  }
1028 
1029  /// Lowers llvm.matrix.column.major.load.
1030  ///
1031  /// The intrinsic loads a matrix from memory using a stride between columns.
1032  void LowerColumnMajorLoad(CallInst *Inst) {
1034  "Intrinsic only supports column-major layout!");
1035  Value *Ptr = Inst->getArgOperand(0);
1036  Value *Stride = Inst->getArgOperand(1);
1037  LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1038  cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1039  {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1040  }
1041 
1042  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1043  /// MatrixPtr[I][J].
1044  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1045  MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1046  Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1047  Value *Offset = Builder.CreateAdd(
1048  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1049 
1050  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1051  Value *EltPtr =
1052  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1053  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1054  auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1055  StoreVal.getNumColumns());
1056  Type *TilePtrTy = PointerType::get(TileTy, AS);
1057  Value *TilePtr =
1058  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1059 
1060  storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
1061  Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1062  }
1063 
1064  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1065  /// vectors.
1066  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1067  MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1068  IRBuilder<> &Builder) {
1069  auto VType = cast<VectorType>(Ty);
1070  Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
1071  for (auto Vec : enumerate(StoreVal.vectors())) {
1072  Value *GEP = computeVectorAddr(
1073  EltPtr,
1074  Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1075  Vec.index()),
1076  Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1077  Builder.CreateAlignedStore(Vec.value(), GEP,
1078  getAlignForIndex(Vec.index(), Stride,
1079  VType->getElementType(),
1080  MAlign),
1081  IsVolatile);
1082  }
1083  return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1084  StoreVal.getNumVectors());
1085  }
1086 
1087  /// Lower a store instruction with shape information.
1088  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1089  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1090  IRBuilder<> Builder(Inst);
1091  auto StoreVal = getMatrix(Matrix, Shape, Builder);
1092  finalizeLowering(Inst,
1093  storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1094  IsVolatile, Builder),
1095  Builder);
1096  }
1097 
1098  /// Lowers llvm.matrix.column.major.store.
1099  ///
1100  /// The intrinsic stores a matrix back to memory using a stride between columns.
1101  void LowerColumnMajorStore(CallInst *Inst) {
1103  "Intrinsic only supports column-major layout!");
1104  Value *Matrix = Inst->getArgOperand(0);
1105  Value *Ptr = Inst->getArgOperand(1);
1106  Value *Stride = Inst->getArgOperand(2);
1107  LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1108  cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1109  {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1110  }
1111 
1112  // Set elements I..I+NumElts-1 to Block
1113  Value *insertVector(Value *Col, unsigned I, Value *Block,
1114  IRBuilder<> &Builder) {
1115 
1116  // First, bring Block to the same size as Col
1117  unsigned BlockNumElts =
1118  cast<FixedVectorType>(Block->getType())->getNumElements();
1119  unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1120  assert(NumElts >= BlockNumElts && "Too few elements for current block");
1121 
1122  Block = Builder.CreateShuffleVector(
1123  Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1124 
1125  // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1126  // 8, 4, 5, 6
1127  SmallVector<int, 16> Mask;
1128  unsigned i;
1129  for (i = 0; i < I; i++)
1130  Mask.push_back(i);
1131 
1132  unsigned VecNumElts =
1133  cast<FixedVectorType>(Col->getType())->getNumElements();
1134  for (; i < I + BlockNumElts; i++)
1135  Mask.push_back(i - I + VecNumElts);
1136 
1137  for (; i < VecNumElts; i++)
1138  Mask.push_back(i);
1139 
1140  return Builder.CreateShuffleVector(Col, Block, Mask);
1141  }
1142 
1143  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1144  IRBuilder<> &Builder, bool AllowContraction,
1145  unsigned &NumComputeOps) {
1146  NumComputeOps += getNumOps(A->getType());
1147  if (!Sum)
1148  return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1149 
1150  if (UseFPOp) {
1151  if (AllowContraction) {
1152  // Use fmuladd for floating point operations and let the backend decide
1153  // if that's profitable.
1154  Function *FMulAdd = Intrinsic::getDeclaration(
1155  Func.getParent(), Intrinsic::fmuladd, A->getType());
1156  return Builder.CreateCall(FMulAdd, {A, B, Sum});
1157  }
1158  NumComputeOps += getNumOps(A->getType());
1159  Value *Mul = Builder.CreateFMul(A, B);
1160  return Builder.CreateFAdd(Sum, Mul);
1161  }
1162 
1163  NumComputeOps += getNumOps(A->getType());
1164  Value *Mul = Builder.CreateMul(A, B);
1165  return Builder.CreateAdd(Sum, Mul);
1166  }
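 // Illustrative sketch of the two floating-point forms (hypothetical values):
 //   with contraction allowed:
 //     %sum = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a,
 //                                                 <4 x float> %b,
 //                                                 <4 x float> %acc)
 //   without contraction:
 //     %mul = fmul <4 x float> %a, %b
 //     %sum = fadd <4 x float> %acc, %mul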
1167 
1168  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1169  /// users with shape information, there's nothing to do: they will use the
1170  /// cached value when they are lowered. For other users, \p Matrix is
1171  /// flattened and the uses are updated to use it. Also marks \p Inst for
1172  /// deletion.
1173  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1174  IRBuilder<> &Builder) {
1175  auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1176  (void)inserted;
1177  assert(inserted.second && "multiple matrix lowering mapping");
1178 
1179  ToRemove.push_back(Inst);
1180  Value *Flattened = nullptr;
1181  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1182  if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1183  if (!Flattened)
1184  Flattened = Matrix.embedInVector(Builder);
1185  U.set(Flattened);
1186  }
1187  }
1188  }
1189 
1190  /// Compute \p Result += \p A * \p B for input matrices with left-associating
1191  /// addition.
1192  ///
1193  /// We can fold a transpose into the operand that is used to extract scalars.
1194  /// This is the first operand with row-major layout and the second with
1195  /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1196  /// operand is transposed.
1197  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1198  const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1199  bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1200  const unsigned VF = std::max<unsigned>(
1201  TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1202  .getFixedSize() /
1203  Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
1204  1U);
1205  unsigned R = Result.getNumRows();
1206  unsigned C = Result.getNumColumns();
1207  unsigned M = A.getNumColumns();
1208 
1209  bool IsFP = Result.getElementType()->isFloatingPointTy();
1210  assert(A.isColumnMajor() == B.isColumnMajor() &&
1211  Result.isColumnMajor() == A.isColumnMajor() &&
1212  "operands must agree on matrix layout");
1213  unsigned NumComputeOps = 0;
1214 
1215  Builder.setFastMathFlags(FMF);
1216 
1217  if (A.isColumnMajor()) {
1218  // Multiply columns from the first operand with scalars from the second
1219  // operand. Then move along the K axes and accumulate the columns. With
1220  // this the adds can be vectorized without reassociation.
1221  for (unsigned J = 0; J < C; ++J) {
1222  unsigned BlockSize = VF;
1223  // If Result is zero, we don't need to accumulate in the K==0 iteration.
1224  bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1225 
1226  for (unsigned I = 0; I < R; I += BlockSize) {
1227  // Gradually lower the vectorization factor to cover the remainder.
1228  while (I + BlockSize > R)
1229  BlockSize /= 2;
1230 
1231  Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1232  : nullptr;
1233  for (unsigned K = 0; K < M; ++K) {
1234  Value *L = A.extractVector(I, K, BlockSize, Builder);
1235  Value *RH = Builder.CreateExtractElement(
1236  B.getColumn(IsScalarMatrixTransposed ? K : J),
1237  IsScalarMatrixTransposed ? J : K);
1238  Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1239  Sum =
1240  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1241  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1242  }
1243  Result.setVector(J,
1244  insertVector(Result.getVector(J), I, Sum, Builder));
1245  }
1246  }
1247  } else {
1248  // Multiply rows from the second operand with scalars from the first
1249  // operand. Then move along the K axes and accumulate the rows. With this
1250  // the adds can be vectorized without reassociation.
1251  for (unsigned I = 0; I < R; ++I) {
1252  unsigned BlockSize = VF;
1253  bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1254  for (unsigned J = 0; J < C; J += BlockSize) {
1255  // Gradually lower the vectorization factor to cover the remainder.
1256  while (J + BlockSize > C)
1257  BlockSize /= 2;
1258 
1259  Value *Sum = nullptr;
1260  for (unsigned K = 0; K < M; ++K) {
1261  Value *R = B.extractVector(K, J, BlockSize, Builder);
1262  Value *LH = Builder.CreateExtractElement(
1263  A.getVector(IsScalarMatrixTransposed ? K : I),
1264  IsScalarMatrixTransposed ? I : K);
1265  Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1266  Sum =
1267  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1268  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1269  }
1270  Result.setVector(I,
1271  insertVector(Result.getVector(I), J, Sum, Builder));
1272  }
1273  }
1274  }
1275  Result.addNumComputeOps(NumComputeOps);
1276  }
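 // Illustrative walk-through (column-major, R = C = M = 4, VF = 4; all
 // numbers assumed for the example): each result column J is built as
 //   Result[:, J] = sum over K of A[:, K] * splat(B[K][J])
 // so the inner loop emits one extract/splat per K plus a multiply-add on a
 // <4 x ...> block, and the adds vectorize without any reassociation.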
1277 
1278  /// Ensure that the memory in \p Load does not alias \p Store by potentially
1279  /// copying it to a new location. The new location, or otherwise the original
1280  /// one, is returned.
1281  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1282  CallInst *MatMul) {
1283  MemoryLocation StoreLoc = MemoryLocation::get(Store);
1284  MemoryLocation LoadLoc = MemoryLocation::get(Load);
1285 
1286  // If we can statically determine noalias we're good.
1287  if (AA->isNoAlias(LoadLoc, StoreLoc))
1288  return Load->getPointerOperand();
1289 
1290  // Create code to check if the memory locations of the Load and Store
1291  // overlap and if they do, copy Load's operand to a new buffer.
1292 
1293  // First, create new blocks for the 2nd part of the check and the copy.
1294  BasicBlock *Check0 = MatMul->getParent();
1295  // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1296  // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1297  // as we adjust Check0 and Check1's branches.
1298  SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1299  for (BasicBlock *Succ : successors(Check0))
1300  DTUpdates.push_back({DT->Delete, Check0, Succ});
1301 
1302  BasicBlock *Check1 =
1303  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1304  nullptr, "alias_cont");
1305  BasicBlock *Copy =
1306  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1307  nullptr, "copy");
1308  BasicBlock *Fusion =
1309  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1310  nullptr, "no_alias");
1311 
1312  // Check if the loaded memory location begins before the end of the store
1313  // location. If the condition holds, they might overlap, otherwise they are
1314  // guaranteed to not overlap.
1315  IRBuilder<> Builder(MatMul);
1316  Check0->getTerminator()->eraseFromParent();
1317  Builder.SetInsertPoint(Check0);
1318  Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1319  Value *StoreBegin = Builder.CreatePtrToInt(
1320  const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1321  Value *StoreEnd = Builder.CreateAdd(
1322  StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1323  "store.end", true, true);
1324  Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1325  IntPtrTy, "load.begin");
1326  Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1327  Fusion);
1328 
1329  // Check if the store begins before the end of the load location. If the
1330  // condition holds, they alias, otherwise they are guaranteed to not
1331  // overlap.
1332  Check1->getTerminator()->eraseFromParent();
1333  Builder.SetInsertPoint(Check1, Check1->begin());
1334  Value *LoadEnd = Builder.CreateAdd(
1335  LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1336  "load.end", true, true);
1337  Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1338  Fusion);
1339 
1340  // Copy load operand to new alloca.
1341  Builder.SetInsertPoint(Copy, Copy->begin());
1342  auto *VT = cast<FixedVectorType>(Load->getType());
1343  // Use an array type for the alloca, to avoid potentially huge alignment
1344  // requirements for large vector types.
1345  auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1346  AllocaInst *Alloca =
1347  Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1348  Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
1349 
1350  Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
1351  Load->getAlign(), LoadLoc.Size.getValue());
1352  Builder.SetInsertPoint(Fusion, Fusion->begin());
1353  PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1354  PHI->addIncoming(Load->getPointerOperand(), Check0);
1355  PHI->addIncoming(Load->getPointerOperand(), Check1);
1356  PHI->addIncoming(BC, Copy);
1357 
1358  // Adjust DT.
1359  DTUpdates.push_back({DT->Insert, Check0, Check1});
1360  DTUpdates.push_back({DT->Insert, Check0, Fusion});
1361  DTUpdates.push_back({DT->Insert, Check1, Copy});
1362  DTUpdates.push_back({DT->Insert, Check1, Fusion});
1363  DT->applyUpdates(DTUpdates);
1364  return PHI;
1365  }
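 // Sketch of the control flow created above (block names as in the code):
 //   Check0 -> Check1 -> Copy -> Fusion
 // with early exits from Check0 and Check1 straight to Fusion when the
 // ranges provably do not overlap; the PHI in Fusion then selects either the
 // original load pointer or the freshly copied buffer.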
1366 
1367  bool isFusionProfitable(CallInst *MatMul) {
1368  if (ForceFusion)
1369  return true;
1370 
1371  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1372  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1373 
1374  const unsigned R = LShape.NumRows;
1375  const unsigned C = RShape.NumColumns;
1376  const unsigned M = LShape.NumColumns;
1377  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1378 
1379  const unsigned VF = std::max<unsigned>(
1380  TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1381  .getFixedSize() /
1382  EltType->getPrimitiveSizeInBits().getFixedSize(),
1383  1U);
1384 
1385  // Cost model for tiling
1386  //
1387  // For tiling to be beneficial, we need reuse either along the R or
1388  // the C axis. We vectorize along the R axis so that means at least
1389  // 3 elements.
1390  // TODO: Also consider cost of copying if operands alias.
1391  if (R <= VF && C == 1)
1392  return false;
1393  // Then we need enough elements to exceed the number of vector
1394  // registers we have. Note that this is an oversimplification since
1395  // fusing also takes some extra loads which may exceed the number of
1396  // reloads necessary.
1397  unsigned Op0Regs = (R + VF - 1) / VF * M;
1398  unsigned Op1Regs = (M + VF - 1) / VF * C;
1399  return Op0Regs + Op1Regs >
1400  TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1401  }
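 // Worked example (illustrative numbers): for R = C = M = 8 with VF = 4,
 // Op0Regs = (8 + 3) / 4 * 8 = 16 and Op1Regs = 16, so fusion is considered
 // profitable on a target reporting fewer than 32 vector registers.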
1402 
1403  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1404  MatrixTy Res;
1405  auto *ColumType = FixedVectorType::get(EltType, R);
1406  for (unsigned I = 0; I < C; ++I)
1407  Res.addVector(ConstantAggregateZero::get(ColumType));
1408  return Res;
1409  }
1410 
1411  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1412  Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1413  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1414 
1415  // Create the main tiling loop nest.
1416  TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1417  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1418  Instruction *InsertI = cast<Instruction>(MatMul);
1419  BasicBlock *Start = InsertI->getParent();
1420  BasicBlock *End =
1421  SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1422  IRBuilder<> Builder(MatMul);
1423  BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1424 
1425  Type *TileVecTy =
1426  FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1427  MatrixTy TileResult;
1428  // Insert in the inner loop header.
1429  Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
1430  // Create PHI nodes for the result columns to accumulate across iterations.
1431  SmallVector<PHINode *, 4> ColumnPhis;
1432  for (unsigned I = 0; I < TileSize; I++) {
1433  auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1434  Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1435  TI.RowLoopHeader->getSingleSuccessor());
1436  TileResult.addVector(Phi);
1437  ColumnPhis.push_back(Phi);
1438  }
1439 
1440  // Insert in the inner loop body, which computes
1441  // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1442  Builder.SetInsertPoint(InnerBody->getTerminator());
1443  // Load tiles of the operands.
1444  MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
1445  {TileSize, TileSize}, EltType, Builder);
1446  MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
1447  {TileSize, TileSize}, EltType, Builder);
1448  emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1449  getFastMathFlags(MatMul));
1450  // Store result after the inner loop is done.
1451  Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
1452  storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1453  Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1454  TI.CurrentRow, TI.CurrentCol, EltType, Builder);
1455 
1456  for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1457  ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);
1458 
1459  // Force unrolling of a few iterations of the inner loop, to make sure there
1460  // is enough work per iteration.
1461  // FIXME: The unroller should make this decision directly instead, but
1462  // currently the cost-model is not up to the task.
1463  unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1464  addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
1465  "llvm.loop.unroll.count", InnerLoopUnrollCount);
1466  }
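 // Illustrative shape of the generated nest (a sketch, for R = C = M = 8 and
 // TileSize = 4): three loops over columns, rows and the K dimension, each
 // stepping by 4, with a 4 x 4 accumulator kept in PHI nodes across the K
 // iterations and stored back once per row/column tile.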
1467 
1468  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1469  StoreInst *Store,
1470  SmallPtrSetImpl<Instruction *> &FusedInsts) {
1472  "Tiling only supported for column-major matrixes at the moment!");
1473  if (!isFusionProfitable(MatMul))
1474  return;
1475 
1476  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1477  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1478 
1479  const unsigned R = LShape.NumRows;
1480  const unsigned C = RShape.NumColumns;
1481  const unsigned M = LShape.NumColumns;
1482  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1483 
1484  Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1485  Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1486  Value *CPtr = Store->getPointerOperand();
1487 
1488  if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1489  createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1490  else {
1491  IRBuilder<> Builder(Store);
1492  for (unsigned J = 0; J < C; J += TileSize)
1493  for (unsigned I = 0; I < R; I += TileSize) {
1494  const unsigned TileR = std::min(R - I, unsigned(TileSize));
1495  const unsigned TileC = std::min(C - J, unsigned(TileSize));
1496  MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1497 
1498  for (unsigned K = 0; K < M; K += TileSize) {
1499  const unsigned TileM = std::min(M - K, unsigned(TileSize));
1500  MatrixTy A =
1501  loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1502  LShape, Builder.getInt64(I), Builder.getInt64(K),
1503  {TileR, TileM}, EltType, Builder);
1504  MatrixTy B =
1505  loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1506  RShape, Builder.getInt64(K), Builder.getInt64(J),
1507  {TileM, TileC}, EltType, Builder);
1508  emitMatrixMultiply(Res, A, B, Builder, true, false,
1509  getFastMathFlags(MatMul));
1510  }
1511  storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1512  Builder.getInt64(I), Builder.getInt64(J), EltType,
1513  Builder);
1514  }
1515  }
1516 
1517  // Mark eliminated instructions as fused and remove them.
1518  FusedInsts.insert(Store);
1519  FusedInsts.insert(MatMul);
1520  Store->eraseFromParent();
1521  MatMul->eraseFromParent();
1522  if (LoadOp0->hasNUses(0)) {
1523  FusedInsts.insert(LoadOp0);
1524  LoadOp0->eraseFromParent();
1525  }
1526  if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1527  FusedInsts.insert(LoadOp1);
1528  LoadOp1->eraseFromParent();
1529  }
1530  }
1531 
1532  /// Try to lower matrix multiply chains by fusing operations.
1533  ///
1534  /// Call finalizeLowering on lowered instructions. Instructions that are
1535  /// completely eliminated by fusion are added to \p FusedInsts.
1536  void LowerMatrixMultiplyFused(CallInst *MatMul,
1537  SmallPtrSetImpl<Instruction *> &FusedInsts) {
1538  if (!FuseMatrix || !DT)
1539  return;
1540 
1541  assert(AA && LI && "Analyses should be available");
1542 
1543  Value *A = MatMul->getArgOperand(0);
1544  Value *B = MatMul->getArgOperand(1);
1545 
1546  // We can fold the transpose into the operand that is used to fetch scalars.
1547  Value *T;
1548  if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1549  ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1550  : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1551  IRBuilder<> Builder(MatMul);
1552  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1553  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1554  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1555  const unsigned R = LShape.NumRows;
1556  const unsigned M = LShape.NumColumns;
1557  const unsigned C = RShape.NumColumns;
1558 
1559  MatrixTy MA;
1560  MatrixTy MB;
1561 
1562  Value *Transpose;
1563  if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1564  MA = getMatrix(A, ShapeInfo(R, M), Builder);
1565  MB = getMatrix(T, ShapeInfo(C, M), Builder);
1566  Transpose = B;
1567  } else {
1568  MA = getMatrix(T, ShapeInfo(R, M), Builder);
1569  MB = getMatrix(B, ShapeInfo(C, M), Builder);
1570  Transpose = A;
1571  }
1572 
1573  // Initialize the output
1574  MatrixTy Result(R, C, EltType);
1575 
1576  emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1577  getFastMathFlags(MatMul));
1578 
1579  FusedInsts.insert(MatMul);
1580  if (Transpose->hasOneUse()) {
1581  FusedInsts.insert(cast<Instruction>(Transpose));
1582  ToRemove.push_back(cast<Instruction>(Transpose));
1583  // TODO: add a fake entry for the folded instruction so that this is
1584  // included in the expression in the remark.
1585  Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1586  }
1587  finalizeLowering(MatMul, Result, Builder);
1588  return;
1589  }
1590 
1591  if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1592  return;
1593 
1594  // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1595  // since the single store user will be lowered as part of this.
1596  auto *LoadOp0 = dyn_cast<LoadInst>(A);
1597  auto *LoadOp1 = dyn_cast<LoadInst>(B);
1598  auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1599  if (LoadOp0 && LoadOp1 && Store) {
1600  // The store address must dominate the MatMul instruction, otherwise
1601  // we create invalid IR.
1602  SetVector<Value *> WorkList;
1603  WorkList.insert(Store->getOperand(1));
1604  SmallVector<Instruction *> ToHoist;
1605  for (unsigned I = 0; I != WorkList.size(); ++I) {
1606  Value *Current = WorkList[I];
1607  auto *CurrI = dyn_cast<Instruction>(Current);
1608  if (!CurrI)
1609  continue;
1610  if (isa<PHINode>(CurrI))
1611  return;
1612  if (DT->dominates(CurrI, MatMul))
1613  continue;
1614  if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1615  return;
1616  ToHoist.push_back(CurrI);
1617  WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1618  }
1619 
1620  sort(ToHoist, [this](Instruction *A, Instruction *B) {
1621  return DT->dominates(A, B);
1622  });
1623  for (Instruction *I : ToHoist)
1624  I->moveBefore(MatMul);
1625 
1626  emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1627  return;
1628  }
1629  }
1630 
1631  /// Lowers llvm.matrix.multiply.
1632  void LowerMultiply(CallInst *MatMul) {
1633  IRBuilder<> Builder(MatMul);
1634  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1635  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1636  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1637 
1638  const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1639  const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1640  assert(Lhs.getElementType() == Rhs.getElementType() &&
1641  "Matrix multiply argument element types do not match.");
1642 
1643  const unsigned R = LShape.NumRows;
1644  const unsigned C = RShape.NumColumns;
1645  assert(LShape.NumColumns == RShape.NumRows);
1646 
1647  // Initialize the output
1648  MatrixTy Result(R, C, EltType);
1649  assert(Lhs.getElementType() == Result.getElementType() &&
1650  "Matrix multiply result element type does not match arguments.");
1651 
1652  emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1653  getFastMathFlags(MatMul));
1654  finalizeLowering(MatMul, Result, Builder);
1655  }
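  // A minimal usage sketch (assumed example IR, not part of this file): a
  // 2x2 single-precision multiply
  //
  //   %c = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(
  //            <4 x float> %a, <4 x float> %b, i32 2, i32 2, i32 2)
  //
  // is lowered by LowerMultiply/emitMatrixMultiply into per-column fmul/fadd
  // chains (or fused multiply-adds when contraction is allowed), producing one
  // <2 x float> vector per result column in the default column-major layout.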
1656 
1657  /// Lowers llvm.matrix.transpose.
1658  void LowerTranspose(CallInst *Inst) {
1659  MatrixTy Result;
1660  IRBuilder<> Builder(Inst);
1661  Value *InputVal = Inst->getArgOperand(0);
1662  VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1663  ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1664  MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1665 
1666  const unsigned NewNumVecs =
1667  InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1668  const unsigned NewNumElts =
1669  InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1670 
1671  for (unsigned I = 0; I < NewNumVecs; ++I) {
1672  // Build a single result vector. First initialize it.
1673  Value *ResultVector = UndefValue::get(
1674  FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
1675  // Go through the old elements and insert them into the resulting vector.
1676  for (auto J : enumerate(InputMatrix.vectors())) {
1677  Value *Elt = Builder.CreateExtractElement(J.value(), I);
1678  // Row and column indices are transposed.
1679  ResultVector =
1680  Builder.CreateInsertElement(ResultVector, Elt, J.index());
1681  }
1682  Result.addVector(ResultVector);
1683  }
1684 
1685  // TODO: Improve estimate of operations needed for transposes. Currently we
1686  // just count the insertelement/extractelement instructions, but do not
1687  // account for later simplifications/combines.
1688  finalizeLowering(
1689  Inst,
1690  Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
1691  .addNumExposedTransposes(1),
1692  Builder);
1693  }
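  // A minimal usage sketch (assumed example IR, not part of this file):
  //
  //   %t = call <6 x double> @llvm.matrix.transpose.v6f64(
  //            <6 x double> %m, i32 2, i32 3)
  //
  // treats %m as a 2x3 column-major matrix; the loop above builds the 3x2
  // result one vector at a time from extractelement/insertelement pairs.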
1694 
1695  /// Lower load instructions, if shape information is available.
1696  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
1697  auto I = ShapeMap.find(Inst);
1698  if (I == ShapeMap.end())
1699  return false;
1700 
1701  LowerLoad(Inst, Ptr, Inst->getAlign(),
1702  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1703  I->second);
1704  return true;
1705  }
1706 
1707  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
1708  IRBuilder<> &Builder) {
1709  auto I = ShapeMap.find(StoredVal);
1710  if (I == ShapeMap.end())
1711  return false;
1712 
1713  LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
1714  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1715  I->second);
1716  return true;
1717  }
1718 
1719  /// Lower binary operators, if shape information is available.
1720  bool VisitBinaryOperator(BinaryOperator *Inst) {
1721  auto I = ShapeMap.find(Inst);
1722  if (I == ShapeMap.end())
1723  return false;
1724 
1725  Value *Lhs = Inst->getOperand(0);
1726  Value *Rhs = Inst->getOperand(1);
1727 
1728  IRBuilder<> Builder(Inst);
1729  ShapeInfo &Shape = I->second;
1730 
1731  MatrixTy Result;
1732  MatrixTy A = getMatrix(Lhs, Shape, Builder);
1733  MatrixTy B = getMatrix(Rhs, Shape, Builder);
1734  assert(A.isColumnMajor() == B.isColumnMajor() &&
1735  Result.isColumnMajor() == A.isColumnMajor() &&
1736  "operands must agree on matrix layout");
1737 
1738  Builder.setFastMathFlags(getFastMathFlags(Inst));
1739 
1740  // Helper to perform binary op on vectors.
1741  auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
1742  switch (Inst->getOpcode()) {
1743  case Instruction::Add:
1744  return Builder.CreateAdd(LHS, RHS);
1745  case Instruction::Mul:
1746  return Builder.CreateMul(LHS, RHS);
1747  case Instruction::Sub:
1748  return Builder.CreateSub(LHS, RHS);
1749  case Instruction::FAdd:
1750  return Builder.CreateFAdd(LHS, RHS);
1751  case Instruction::FMul:
1752  return Builder.CreateFMul(LHS, RHS);
1753  case Instruction::FSub:
1754  return Builder.CreateFSub(LHS, RHS);
1755  default:
1756  llvm_unreachable("Unsupported binary operator for matrix");
1757  }
1758  };
1759 
1760  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1761  Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
1762 
1763  finalizeLowering(Inst,
1764  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1765  Result.getNumVectors()),
1766  Builder);
1767  return true;
1768  }
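  // A minimal sketch (assumed example): an fadd whose operands carry a 4x2
  // shape is split by the loop above into two <4 x float> fadds, one per
  // column vector, rather than a single operation on the flattened
  // <8 x float> value.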
1769 
1770  /// Lower unary operators, if shape information is available.
1771  bool VisitUnaryOperator(UnaryOperator *Inst) {
1772  auto I = ShapeMap.find(Inst);
1773  if (I == ShapeMap.end())
1774  return false;
1775 
1776  Value *Op = Inst->getOperand(0);
1777 
1778  IRBuilder<> Builder(Inst);
1779  ShapeInfo &Shape = I->second;
1780 
1781  MatrixTy Result;
1782  MatrixTy M = getMatrix(Op, Shape, Builder);
1783 
1784  Builder.setFastMathFlags(getFastMathFlags(Inst));
1785 
1786  // Helper to perform unary op on vectors.
1787  auto BuildVectorOp = [&Builder, Inst](Value *Op) {
1788  switch (Inst->getOpcode()) {
1789  case Instruction::FNeg:
1790  return Builder.CreateFNeg(Op);
1791  default:
1792  llvm_unreachable("Unsupported unary operator for matrix");
1793  }
1794  };
1795 
1796  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1797  Result.addVector(BuildVectorOp(M.getVector(I)));
1798 
1799  finalizeLowering(Inst,
1800  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1801  Result.getNumVectors()),
1802  Builder);
1803  return true;
1804  }
1805 
1806  /// Helper to linearize a matrix expression tree into a string. Currently
1807  /// matrix expressions are linearized by starting at an expression leaf and
1808  /// linearizing bottom up.
1809  struct ExprLinearizer {
1810  unsigned LengthToBreak = 100;
1811  std::string Str;
1812  raw_string_ostream Stream;
1813  unsigned LineLength = 0;
1814  const DataLayout &DL;
1815 
1816  /// Mapping from instructions to matrixes. It is used to identify
1817  /// matrix instructions.
1818  const MapVector<Value *, MatrixTy> &Inst2Matrix;
1819 
1820  /// Mapping from values to the leaves of all expressions that the value is
1821  /// part of.
1822  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
1823 
1824  /// Set of matrix expressions in the scope of a given DISubprogram.
1825  const SmallSetVector<Value *, 32> &ExprsInSubprogram;
1826 
1827  /// Leaf node of the expression to linearize.
1828  Value *Leaf;
1829 
1830  /// Used to keep track of sub-expressions that get reused while linearizing
1831  /// the expression. Re-used sub-expressions are marked as (reused).
1832  SmallPtrSet<Value *, 8> ReusedExprs;
1833 
1834  ExprLinearizer(const DataLayout &DL,
1835  const MapVector<Value *, MatrixTy> &Inst2Matrix,
1836  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
1837  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
1838  Value *Leaf)
1839  : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
1840  ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
1841 
1842  void indent(unsigned N) {
1843  LineLength += N;
1844  for (unsigned i = 0; i < N; i++)
1845  Stream << " ";
1846  }
1847 
1848  void lineBreak() {
1849  Stream << "\n";
1850  LineLength = 0;
1851  }
1852 
1853  void maybeIndent(unsigned Indent) {
1854  if (LineLength >= LengthToBreak)
1855  lineBreak();
1856 
1857  if (LineLength == 0)
1858  indent(Indent);
1859  }
1860 
1861  void write(StringRef S) {
1862  LineLength += S.size();
1863  Stream << S;
1864  }
1865 
1866  Value *getUnderlyingObjectThroughLoads(Value *V) {
1867  if (Value *Ptr = getPointerOperand(V))
1868  return getUnderlyingObjectThroughLoads(Ptr);
1869  else if (V->getType()->isPointerTy())
1870  return getUnderlyingObject(V);
1871  return V;
1872  }
1873 
1874  /// Returns true if \p V is a matrix value in the given subprogram.
1875  bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
1876 
1877  /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
1878  /// \p SS.
1879  void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
1880  auto M = Inst2Matrix.find(V);
1881  if (M == Inst2Matrix.end())
1882  SS << "unknown";
1883  else {
1884  SS << M->second.getNumRows();
1885  SS << "x";
1886  SS << M->second.getNumColumns();
1887  }
1888  }
1889 
1890  /// Write the called function name. Handles calls to llvm.matrix.*
1891  /// specially: we write the name, followed by the dimensions of the input
1892  /// matrixes, followed by the scalar type name.
1893  void writeFnName(CallInst *CI) {
1894  if (!CI->getCalledFunction())
1895  write("<no called fn>");
1896  else {
1897  StringRef Name = CI->getCalledFunction()->getName();
1898  if (!Name.startswith("llvm.matrix")) {
1899  write(Name);
1900  return;
1901  }
1902  auto *II = cast<IntrinsicInst>(CI);
1903  write(Intrinsic::getBaseName(II->getIntrinsicID())
1904  .drop_front(StringRef("llvm.matrix.").size()));
1905  write(".");
1906  std::string Tmp;
1907  raw_string_ostream SS(Tmp);
1908 
1909  switch (II->getIntrinsicID()) {
1910  case Intrinsic::matrix_multiply:
1911  prettyPrintMatrixType(II->getOperand(0), SS);
1912  SS << ".";
1913  prettyPrintMatrixType(II->getOperand(1), SS);
1914  SS << "." << *II->getType()->getScalarType();
1915  break;
1916  case Intrinsic::matrix_transpose:
1917  prettyPrintMatrixType(II->getOperand(0), SS);
1918  SS << "." << *II->getType()->getScalarType();
1919  break;
1920  case Intrinsic::matrix_column_major_load:
1921  prettyPrintMatrixType(II, SS);
1922  SS << "." << *II->getType()->getScalarType();
1923  break;
1924  case Intrinsic::matrix_column_major_store:
1925  prettyPrintMatrixType(II->getOperand(0), SS);
1926  SS << "." << *II->getOperand(0)->getType()->getScalarType();
1927  break;
1928  default:
1929  llvm_unreachable("Unhandled case");
1930  }
1931  SS.flush();
1932  write(Tmp);
1933  }
1934  }
1935 
1936  unsigned getNumShapeArgs(CallInst *CI) const {
1937  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
1938  switch (II->getIntrinsicID()) {
1939  case Intrinsic::matrix_multiply:
1940  return 3;
1941  case Intrinsic::matrix_transpose:
1942  return 2;
1943  case Intrinsic::matrix_column_major_load:
1944  case Intrinsic::matrix_column_major_store:
1945  return 3;
1946  default:
1947  return 0;
1948  }
1949  }
1950  return 0;
1951  }
1952 
1953  /// Special printing for values: for pointers, we print whether they refer to
1954  /// a (function-)external address or a stack address; for other values we
1955  /// either print the constant or "scalar"/"matrix".
1956  void write(Value *V) {
1957  V = getUnderlyingObjectThroughLoads(V);
1958  if (V->getType()->isPointerTy()) {
1959  if (isa<AllocaInst>(V)) {
1960  Stream << "stack addr";
1961  LineLength += StringRef("stack addr").size();
1962  } else {
1963  Stream << "addr";
1964  LineLength += StringRef("addr").size();
1965  }
1966  if (!V->getName().empty()) {
1967  Stream << " %" << V->getName() << "";
1968  LineLength += V->getName().size() + 2;
1969  }
1970  return;
1971  }
1972 
1973  std::string Tmp;
1974  raw_string_ostream TmpStream(Tmp);
1975 
1976  if (auto *CI = dyn_cast<ConstantInt>(V))
1977  TmpStream << CI->getValue();
1978  else if (isa<Constant>(V))
1979  TmpStream << "constant";
1980  else {
1981  if (isMatrix(V))
1982  TmpStream << "matrix";
1983  else
1984  TmpStream << "scalar";
1985  }
1986  TmpStream.flush();
1987  Tmp = std::string(StringRef(Tmp).trim());
1988  LineLength += Tmp.size();
1989  Stream << Tmp;
1990  }
1991 
1992  /// Linearize expression \p Expr starting at an indentation of \p Indent.
1993  /// Expressions that are re-used multiple times are prefixed with (reused)
1994  /// at the re-used root instruction.
1995  void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
1996  bool ParentShared) {
1997  auto *I = cast<Instruction>(Expr);
1998  maybeIndent(Indent);
1999  SmallVector<Value *, 8> Ops;
2000 
2001  // Is Expr shared with other expression leaves?
2002  bool ExprShared = false;
2003 
2004  // Deal with shared subtrees. Mark them as shared, if required.
2005  if (!ParentShared) {
2006  auto SI = Shared.find(Expr);
2007  assert(SI != Shared.end() && SI->second.count(Leaf));
2008 
2009  for (Value *S : SI->second) {
2010  if (S == Leaf)
2011  continue;
2012  DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2013  write("shared with remark at line " + std::to_string(DL.getLine()) +
2014  " column " + std::to_string(DL.getCol()) + " (");
2015  }
2016  ExprShared = SI->second.size() > 1;
2017  }
2018 
2019  bool Reused = !ReusedExprs.insert(Expr).second;
2020  if (Reused && !ParentReused)
2021  write("(reused) ");
2022 
2023  if (auto *CI = dyn_cast<CallInst>(I)) {
2024  writeFnName(CI);
2025 
2026  Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2027  } else if (isa<BitCastInst>(Expr)) {
2028  // Special case bitcasts, which are used to materialize matrixes from
2029  // non-matrix ops.
2030  write("matrix");
2031  return;
2032  } else {
2033  Ops.append(I->value_op_begin(), I->value_op_end());
2034  write(std::string(I->getOpcodeName()));
2035  }
2036 
2037  write(std::string("("));
2038 
2039  unsigned NumOpsToBreak = 1;
2040  if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2041  NumOpsToBreak = 2;
2042 
2043  for (Value *Op : Ops) {
2044  if (Ops.size() > NumOpsToBreak)
2045  lineBreak();
2046 
2047  maybeIndent(Indent + 1);
2048  if (isMatrix(Op))
2049  linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2050  else
2051  write(Op);
2052  if (Op != Ops.back())
2053  write(", ");
2054  }
2055 
2056  write(")");
2057  }
2058 
2059  const std::string &getResult() {
2060  Stream.flush();
2061  return Str;
2062  }
2063  };
2064 
2065  /// Generate remarks for matrix operations in a function. To generate remarks
2066  /// for matrix expressions, the following approach is used:
2067  /// 1. Use the inlined-at debug information to group matrix operations to the
2068  /// DISubprograms they are contained in.
2069  /// 2. Collect leaves of matrix expressions (done in
2070  /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2071  ///    mapping. Leaves are lowered matrix instructions without other matrix
2072  ///    users (like stores) in the current subprogram.
2073  /// 3. For each leaf, create a remark containing a linearized version of the
2074  /// matrix expression. The expression is linearized by a recursive
2075  /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2076  /// that multiple leaves can share sub-expressions. Shared subexpressions
2077  /// are explicitly marked as shared().
2078  struct RemarkGenerator {
2079  const MapVector<Value *, MatrixTy> &Inst2Matrix;
2080  OptimizationRemarkEmitter &ORE;
2081  Function &Func;
2082  const DataLayout &DL;
2083 
2084  RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2085  OptimizationRemarkEmitter &ORE, Function &Func)
2086  : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2087  DL(Func.getParent()->getDataLayout()) {}
2088 
2089  /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2090  /// instructions in Inst2Matrix returning void or without any users in
2091  /// \p ExprsInSubprogram. Currently that should only include stores.
2092  SmallVector<Value *, 4>
2093  getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2094  SmallVector<Value *, 4> Leaves;
2095  for (auto *Expr : ExprsInSubprogram)
2096  if (Expr->getType()->isVoidTy() ||
2097  !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2098  return ExprsInSubprogram.count(U);
2099  }))
2100  Leaves.push_back(Expr);
2101  return Leaves;
2102  }
2103 
2104  /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2105  /// to all visited expressions in \p Shared. Limit the matrix operations to
2106  /// the ones in \p ExprsInSubprogram.
2107  void collectSharedInfo(Value *Leaf, Value *V,
2108  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2109  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2110 
2111  if (!ExprsInSubprogram.count(V))
2112  return;
2113 
2114  auto I = Shared.insert({V, {}});
2115  I.first->second.insert(Leaf);
2116 
2117  for (Value *Op : cast<Instruction>(V)->operand_values())
2118  collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2119  }
2120 
2121  /// Calculate the number of exclusive and shared op counts for expression
2122  /// starting at \p V. Expressions used multiple times are counted once.
2123  /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2124  std::pair<OpInfoTy, OpInfoTy>
2125  sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2126  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2127  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2128  if (!ExprsInSubprogram.count(Root))
2129  return {};
2130 
2131  // Already counted this expression. Stop.
2132  if (!ReusedExprs.insert(Root).second)
2133  return {};
2134 
2135  OpInfoTy SharedCount;
2136  OpInfoTy Count;
2137 
2138  auto I = Shared.find(Root);
2139  auto CM = Inst2Matrix.find(Root);
2140  if (I->second.size() == 1)
2141  Count = CM->second.getOpInfo();
2142  else
2143  SharedCount = CM->second.getOpInfo();
2144 
2145  for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2146  auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2147  Count += C.first;
2148  SharedCount += C.second;
2149  }
2150  return {Count, SharedCount};
2151  }
2152 
2153  void emitRemarks() {
2154  if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2155  return;
2156 
2157  // Map matrix operations to their containing subprograms, by traversing
2158  // the inlinedAt chain. If the function does not have a DISubprogram, we
2159  // only map them to the containing function.
2160  MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2161  for (auto &KV : Inst2Matrix) {
2162  if (Func.getSubprogram()) {
2163  auto *I = cast<Instruction>(KV.first);
2164  DILocation *Context = I->getDebugLoc();
2165  while (Context) {
2166  auto I =
2167  Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2168  I.first->second.push_back(KV.first);
2169  Context = DebugLoc(Context).getInlinedAt();
2170  }
2171  } else {
2172  auto I = Subprog2Exprs.insert({nullptr, {}});
2173  I.first->second.push_back(KV.first);
2174  }
2175  }
2176  for (auto &KV : Subprog2Exprs) {
2177  SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2178  KV.second.end());
2179  auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2180 
2181  DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2182  for (Value *Leaf : Leaves)
2183  collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2184 
2185  // Generate remarks for each leaf.
2186  for (auto *L : Leaves) {
2187 
2188  DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2189  DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2190  while (Context) {
2191  if (getSubprogram(Context->getScope()) == KV.first) {
2192  Loc = Context;
2193  break;
2194  }
2195  Context = DebugLoc(Context).getInlinedAt();
2196  }
2197 
2198  SmallPtrSet<Value *, 8> ReusedExprs;
2199  OpInfoTy Counts, SharedCounts;
2200  std::tie(Counts, SharedCounts) =
2201  sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2202 
2203  OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2204  cast<Instruction>(L)->getParent());
2205 
2206  Rem << "Lowered with ";
2207  Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2208  << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2209  << ore::NV("NumComputeOps", Counts.NumComputeOps)
2210  << " compute ops, "
2211  << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2212  << " exposed transposes";
2213 
2214  if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2215  SharedCounts.NumComputeOps > 0) {
2216  Rem << ",\nadditionally "
2217  << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2218  << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2219  << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2220  << " compute ops"
2221  << " are shared with other expressions";
2222  }
2223 
2224  Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2225  ORE.emit(Rem);
2226  }
2227  }
2228  }
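  // An illustrative remark (assumed example output; shapes, counts and
  // locations are made up for the sketch):
  //
  //   remark: test.cpp:35:3: Lowered with 2 stores, 4 loads, 16 compute ops,
  //   0 exposed transposes
  //   store(
  //    multiply.2x2.2x2.float(
  //     load(addr %A),
  //     load(addr %B)),
  //    addr %C)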
2229 
2230  std::string
2231  linearize(Value *L,
2232  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2233  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2234  const DataLayout &DL) {
2235  ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2236  Lin.linearizeExpr(L, 0, false, false);
2237  return Lin.getResult();
2238  }
2239  };
2240 };
2241 } // namespace
2242 
2243 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2244  FunctionAnalysisManager &AM) {
2245  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2246  OptimizationRemarkEmitter *ORE = nullptr;
2247  AAResults *AA = nullptr;
2248  DominatorTree *DT = nullptr;
2249  LoopInfo *LI = nullptr;
2250 
2251  if (!Minimal) {
2252  ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
2253  AA = &AM.getResult<AAManager>(F);
2254  DT = &AM.getResult<DominatorTreeAnalysis>(F);
2255  LI = &AM.getResult<LoopAnalysis>(F);
2256  }
2257 
2258  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2259  if (LMT.Visit()) {
2260  PreservedAnalyses PA;
2261  if (!Minimal) {
2262  PA.preserve<LoopAnalysis>();
2263  PA.preserve<DominatorTreeAnalysis>();
2264  }
2265  return PA;
2266  }
2267  return PreservedAnalyses::all();
2268 }
2269 
2270 void LowerMatrixIntrinsicsPass::printPipeline(
2271  raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2272  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2273  OS, MapClassName2PassName);
2274  OS << "<";
2275  if (Minimal)
2276  OS << "minimal";
2277  OS << ">";
2278 }
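// A usage sketch (assuming the names this pass is registered under in
// PassRegistry.def):
//
//   opt -passes='lower-matrix-intrinsics' input.ll -S
//   opt -passes='lower-matrix-intrinsics<minimal>' input.ll -S
//
// The second form corresponds to the Minimal variant printed above and skips
// the fusion/tiling features that need AA, DT, LI and ORE.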
2279 
2280 namespace {
2281 
2282 class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
2283 public:
2284  static char ID;
2285 
2286  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
2287  initializeLowerMatrixIntrinsicsLegacyPassPass(
2288  *PassRegistry::getPassRegistry());
2289  }
2290 
2291  bool runOnFunction(Function &F) override {
2292  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2293  auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
2294  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
2295  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2296  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
2297  LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
2298  bool C = LMT.Visit();
2299  return C;
2300  }
2301 
2302  void getAnalysisUsage(AnalysisUsage &AU) const override {
2303  AU.addRequired<TargetTransformInfoWrapperPass>();
2304  AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
2305  AU.addRequired<AAResultsWrapperPass>();
2306  AU.addRequired<DominatorTreeWrapperPass>();
2307  AU.addPreserved<DominatorTreeWrapperPass>();
2308  AU.addRequired<LoopInfoWrapperPass>();
2309  AU.addPreserved<LoopInfoWrapperPass>();
2310  }
2311 };
2312 } // namespace
2313 
2314 static const char pass_name[] = "Lower the matrix intrinsics";
2315 char LowerMatrixIntrinsicsLegacyPass::ID = 0;
2316 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2317  false, false)
2318 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
2319 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
2320 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
2321 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
2322 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2323  false, false)
2324 
2325 Pass *llvm::createLowerMatrixIntrinsicsPass() {
2326  return new LowerMatrixIntrinsicsLegacyPass();
2327 }
2328 
2329 namespace {
2330 
2331 /// A lightweight version of the matrix lowering pass that only requires TTI.
2332 /// Advanced features that require DT, AA or ORE like tiling are disabled. This
2333 /// is used to lower matrix intrinsics if the main lowering pass is not run, for
2334 /// example with -O0.
2335 class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
2336 public:
2337  static char ID;
2338 
2339  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
2340  initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
2341  *PassRegistry::getPassRegistry());
2342  }
2343 
2344  bool runOnFunction(Function &F) override {
2345  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2346  LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
2347  bool C = LMT.Visit();
2348  return C;
2349  }
2350 
2351  void getAnalysisUsage(AnalysisUsage &AU) const override {
2352  AU.addRequired<TargetTransformInfoWrapperPass>();
2353  AU.setPreservesCFG();
2354  }
2355 };
2356 } // namespace
2357 
2358 static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
2359 char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
2360 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
2361  "lower-matrix-intrinsics-minimal", pass_name_minimal,
2362  false, false)
2363 INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
2364  "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
2365  false)
2366 
2367 Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
2368  return new LowerMatrixIntrinsicsMinimalLegacyPass();
2369 }
llvm::PHINode
Definition: Instructions.h:2664
llvm::BasicBlock::getTerminator
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
Definition: BasicBlock.h:119
DEBUG_TYPE
#define DEBUG_TYPE
Definition: LowerMatrixIntrinsics.cpp:52
minimal
lower matrix intrinsics minimal
Definition: LowerMatrixIntrinsics.cpp:2364
llvm::DISubprogram
Subprogram description.
Definition: DebugInfoMetadata.h:1797
llvm::SmallSet::empty
LLVM_NODISCARD bool empty() const
Definition: SmallSet.h:157
llvm::SmallVectorImpl< Instruction * >
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, false, false) INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass
llvm::SmallPtrSetImpl< Instruction * >
llvm::SmallSetVector
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:307
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition: InstructionSimplify.h:42
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:308
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1474
GEP
Hexagon Common GEP
Definition: HexagonCommonGEP.cpp:172
llvm::AnalysisUsage::addRequired
AnalysisUsage & addRequired()
Definition: PassAnalysisSupport.h:75
TileUseLoops
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
llvm::DebugLoc
A debug info location.
Definition: DebugLoc.h:33
llvm::AllocaInst
An instruction to allocate memory on the stack.
Definition: Instructions.h:58
llvm::FastMathFlags::allowContract
bool allowContract() const
Definition: FMF.h:71
llvm::User::getOperand
Value * getOperand(unsigned i) const
Definition: User.h:169
llvm::cl::desc
Definition: CommandLine.h:405
llvm::createLowerMatrixIntrinsicsMinimalPass
Pass * createLowerMatrixIntrinsicsMinimalPass()
Definition: LowerMatrixIntrinsics.cpp:2367
llvm::createSequentialMask
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
Definition: VectorUtils.cpp:935
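A hedged sketch of the kind of column extraction such a mask enables for a flat, column-major matrix vector; extractColumn is an illustrative helper, not the pass's own code:

    #include "llvm/Analysis/VectorUtils.h"
    #include "llvm/IR/IRBuilder.h"

    static llvm::Value *extractColumn(llvm::IRBuilder<> &Builder,
                                      llvm::Value *FlatMatrix,
                                      unsigned NumRows, unsigned Col) {
      // Mask {Col*NumRows, ..., Col*NumRows + NumRows - 1} selects one column
      // of a column-major matrix stored as a single flat vector.
      llvm::SmallVector<int, 16> Mask =
          llvm::createSequentialMask(Col * NumRows, NumRows, /*NumUndefs=*/0);
      return Builder.CreateShuffleVector(FlatMatrix, Mask);
    }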
llvm::SetVector< Value * >
llvm::ConstantAggregateZero::get
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1648
BasicBlockUtils.h
llvm::SplitBlock
BasicBlock * SplitBlock(BasicBlock *Old, Instruction *SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Definition: BasicBlockUtils.cpp:837
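A minimal sketch, assuming a dominator tree is available, of splitting a block at a matrix-multiply call so that the instructions from the call onwards land in a fresh successor block; splitAroundCall is hypothetical:

    #include "llvm/Transforms/Utils/BasicBlockUtils.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/Dominators.h"

    static void splitAroundCall(llvm::CallInst *MatMul, llvm::DominatorTree &DT) {
      llvm::BasicBlock *Start = MatMul->getParent();
      // Everything from MatMul onwards now lives in the newly created block.
      llvm::BasicBlock *Rest = llvm::SplitBlock(Start, MatMul, &DT);
      (void)Rest;
    }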
llvm::pdb::PDB_SymType::Block
@ Block
InitializePasses.h
llvm::OptimizationRemarkEmitterAnalysis
Definition: OptimizationRemarkEmitter.h:164
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
Debug.h
llvm::Value::users
iterator_range< user_iterator > users()
Definition: Value.h:421
llvm::MemoryLocation
Representation for a specific memory location.
Definition: MemoryLocation.h:210
llvm::LoopAnalysis
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:1246
llvm::DominatorTreeBase::Delete
static constexpr UpdateKind Delete
Definition: GenericDomTree.h:243
llvm::Use
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
llvm::Type::getPrimitiveSizeInBits
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition: Type.cpp:164
llvm::SmallPtrSetImpl::insert
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
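A small sketch of the visited-set idiom the returned pair enables, of the sort used when walking chains of matrix instructions; addToWorkList is an illustrative helper:

    #include "llvm/ADT/SmallPtrSet.h"
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/Instruction.h"

    static void addToWorkList(llvm::Instruction *I,
                              llvm::SmallPtrSetImpl<llvm::Instruction *> &Visited,
                              llvm::SmallVectorImpl<llvm::Instruction *> &WorkList) {
      // insert() reports whether the element was newly added, so each
      // instruction is queued at most once.
      if (Visited.insert(I).second)
        WorkList.push_back(I);
    }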
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:37