LLVM  16.0.0git
MatrixBuilder.h
Go to the documentation of this file.
1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MatrixBuilder class, which is used as a convenient way
10 // to lower matrix operations to LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_IR_MATRIXBUILDER_H
15 #define LLVM_IR_MATRIXBUILDER_H
16 
17 #include "llvm/IR/Constant.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstrTypes.h"
21 #include "llvm/IR/Instruction.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Type.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Alignment.h"
26 
27 namespace llvm {
28 
29 class Function;
30 class Twine;
31 class Module;
32 
34  IRBuilderBase &B;
35  Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36 
37  std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38  Value *RHS) {
39  assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40  "One of the operands must be a matrix (embedded in a vector)");
41  if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42  assert(!isa<ScalableVectorType>(LHS->getType()) &&
43  "LHS Assumed to be fixed width");
44  RHS = B.CreateVectorSplat(
45  cast<VectorType>(LHS->getType())->getElementCount(), RHS,
46  "scalar.splat");
47  } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48  assert(!isa<ScalableVectorType>(RHS->getType()) &&
49  "RHS Assumed to be fixed width");
50  LHS = B.CreateVectorSplat(
51  cast<VectorType>(RHS->getType())->getElementCount(), LHS,
52  "scalar.splat");
53  }
54  return {LHS, RHS};
55  }
56 
57 public:
59 
60  /// Create a column major, strided matrix load.
61  /// \p EltTy - Matrix element type
62  /// \p DataPtr - Start address of the matrix read
63  /// \p Rows - Number of rows in matrix (must be a constant)
64  /// \p Columns - Number of columns in matrix (must be a constant)
65  /// \p Stride - Space between columns
66  CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
67  Value *Stride, bool IsVolatile, unsigned Rows,
68  unsigned Columns, const Twine &Name = "") {
69  auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
70 
71  Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
72  B.getInt32(Columns)};
73  Type *OverloadedTypes[] = {RetType, Stride->getType()};
74 
76  getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
77 
78  CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
79  Attribute AlignAttr =
80  Attribute::getWithAlignment(Call->getContext(), Alignment);
81  Call->addParamAttr(0, AlignAttr);
82  return Call;
83  }
84 
85  /// Create a column major, strided matrix store.
86  /// \p Matrix - Matrix to store
87  /// \p Ptr - Pointer to write back to
88  /// \p Stride - Space between columns
90  Value *Stride, bool IsVolatile,
91  unsigned Rows, unsigned Columns,
92  const Twine &Name = "") {
93  Value *Ops[] = {Matrix, Ptr,
94  Stride, B.getInt1(IsVolatile),
95  B.getInt32(Rows), B.getInt32(Columns)};
96  Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
97 
99  getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
100 
101  CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
102  Attribute AlignAttr =
103  Attribute::getWithAlignment(Call->getContext(), Alignment);
104  Call->addParamAttr(1, AlignAttr);
105  return Call;
106  }
107 
108  /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
109  /// rows and \p Columns columns.
110  CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
111  unsigned Columns, const Twine &Name = "") {
112  auto *OpType = cast<VectorType>(Matrix->getType());
113  auto *ReturnType =
114  FixedVectorType::get(OpType->getElementType(), Rows * Columns);
115 
116  Type *OverloadedTypes[] = {ReturnType};
117  Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
119  getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
120 
121  return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
122  }
123 
124  /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
125  /// RHS.
127  unsigned LHSColumns, unsigned RHSColumns,
128  const Twine &Name = "") {
129  auto *LHSType = cast<VectorType>(LHS->getType());
130  auto *RHSType = cast<VectorType>(RHS->getType());
131 
132  auto *ReturnType =
133  FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
134 
135  Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
136  B.getInt32(RHSColumns)};
137  Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
138 
140  getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
141  return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
142  }
143 
144  /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
145  /// ColumnIdx).
146  Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
147  Value *ColumnIdx, unsigned NumRows) {
148  return B.CreateInsertElement(
149  Matrix, NewVal,
150  B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
151  ColumnIdx->getType(), NumRows)),
152  RowIdx));
153  }
154 
155  /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
156  /// matrixes.
158  assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
159  if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
160  assert(!isa<ScalableVectorType>(LHS->getType()) &&
161  "LHS Assumed to be fixed width");
162  RHS = B.CreateVectorSplat(
163  cast<VectorType>(LHS->getType())->getElementCount(), RHS,
164  "scalar.splat");
165  } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
166  assert(!isa<ScalableVectorType>(RHS->getType()) &&
167  "RHS Assumed to be fixed width");
168  LHS = B.CreateVectorSplat(
169  cast<VectorType>(RHS->getType())->getElementCount(), LHS,
170  "scalar.splat");
171  }
172 
173  return cast<VectorType>(LHS->getType())
174  ->getElementType()
175  ->isFloatingPointTy()
176  ? B.CreateFAdd(LHS, RHS)
177  : B.CreateAdd(LHS, RHS);
178  }
179 
180  /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
181  /// point matrixes.
183  assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
184  if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
185  assert(!isa<ScalableVectorType>(LHS->getType()) &&
186  "LHS Assumed to be fixed width");
187  RHS = B.CreateVectorSplat(
188  cast<VectorType>(LHS->getType())->getElementCount(), RHS,
189  "scalar.splat");
190  } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
191  assert(!isa<ScalableVectorType>(RHS->getType()) &&
192  "RHS Assumed to be fixed width");
193  LHS = B.CreateVectorSplat(
194  cast<VectorType>(RHS->getType())->getElementCount(), LHS,
195  "scalar.splat");
196  }
197 
198  return cast<VectorType>(LHS->getType())
199  ->getElementType()
200  ->isFloatingPointTy()
201  ? B.CreateFSub(LHS, RHS)
202  : B.CreateSub(LHS, RHS);
203  }
204 
205  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
206  /// RHS.
208  std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
209  if (LHS->getType()->getScalarType()->isFloatingPointTy())
210  return B.CreateFMul(LHS, RHS);
211  return B.CreateMul(LHS, RHS);
212  }
213 
214  /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
215  /// IsUnsigned indicates whether UDiv or SDiv should be used.
216  Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
217  assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
218  assert(!isa<ScalableVectorType>(LHS->getType()) &&
219  "LHS Assumed to be fixed width");
220  RHS =
221  B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
222  RHS, "scalar.splat");
223  return cast<VectorType>(LHS->getType())
224  ->getElementType()
225  ->isFloatingPointTy()
226  ? B.CreateFDiv(LHS, RHS)
227  : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
228  }
229 
230  /// Create an assumption that \p Idx is less than \p NumElements.
231  void CreateIndexAssumption(Value *Idx, unsigned NumElements,
232  Twine const &Name = "") {
233  Value *NumElts =
234  B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
235  auto *Cmp = B.CreateICmpULT(Idx, NumElts);
236  if (isa<ConstantInt>(Cmp))
237  assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
238  else
239  B.CreateAssumption(Cmp);
240  }
241 
242  /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
243  /// a matrix with \p NumRows embedded in a vector.
244  Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
245  Twine const &Name = "") {
246  unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
247  ColumnIdx->getType()->getScalarSizeInBits());
248  Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
249  RowIdx = B.CreateZExt(RowIdx, IntTy);
250  ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
251  Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
252  return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
253  }
254 };
255 
256 } // end namespace llvm
257 
258 #endif // LLVM_IR_MATRIXBUILDER_H
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
llvm::MatrixBuilder::CreateMatrixMultiply
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
Definition: MatrixBuilder.h:126
IntrinsicInst.h
llvm::Function
Definition: Function.h:60
llvm::Attribute
Definition: Attributes.h:67
llvm::MatrixBuilder::CreateIndex
Value * CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, Twine const &Name="")
Compute the index to access the element at (RowIdx, ColumnIdx) from a matrix with NumRows embedded in a vector.
Definition: MatrixBuilder.h:244
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
llvm::max
Expected< ExpressionValue > max(const ExpressionValue &Lhs, const ExpressionValue &Rhs)
Definition: FileCheck.cpp:337
RHS
Value * RHS
Definition: X86PartialReduction.cpp:76
llvm::MatrixBuilder::CreateColumnMajorStore
CallInst * CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix store.
Definition: MatrixBuilder.h:89
llvm::MatrixBuilder::CreateAdd
Value * CreateAdd(Value *LHS, Value *RHS)
Add matrixes LHS and RHS.
Definition: MatrixBuilder.h:157
llvm::MatrixBuilder::CreateColumnMajorLoad
CallInst * CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix load.
Definition: MatrixBuilder.h:66
Instruction.h
LHS
Value * LHS
Definition: X86PartialReduction.cpp:75
llvm::MatrixBuilder::CreateScalarMultiply
Value * CreateScalarMultiply(Value *LHS, Value *RHS)
Multiply matrix LHS with scalar RHS or scalar LHS with matrix RHS.
Definition: MatrixBuilder.h:207
Constants.h
InstrTypes.h
llvm::Type::getScalarSizeInBits
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition: Type.cpp:189
llvm::ConstantInt::get
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:879
llvm::FixedVectorType::get
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:684
llvm::Align
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
llvm::MatrixBuilder::CreateIndexAssumption
void CreateIndexAssumption(Value *Idx, unsigned NumElements, Twine const &Name="")
Create an assumption that Idx is less than NumElements.
Definition: MatrixBuilder.h:231
Type.h
llvm::MatrixBuilder::CreateSub
Value * CreateSub(Value *LHS, Value *RHS)
Subtract matrixes LHS and RHS.
Definition: MatrixBuilder.h:182
llvm::MatrixBuilder::MatrixBuilder
MatrixBuilder(IRBuilderBase &Builder)
Definition: MatrixBuilder.h:58
IRBuilder.h
llvm::MatrixBuilder::CreateMatrixInsert
Value * CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, Value *ColumnIdx, unsigned NumRows)
Insert a single element NewVal into Matrix at indices (RowIdx, ColumnIdx).
Definition: MatrixBuilder.h:146
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
Ptr
@ Ptr
Definition: TargetLibraryInfo.cpp:60
B
compiles ldr LCPI1_0 ldr ldr mov lsr tst moveq r1 ldr LCPI1_1 and r0 bx lr It would be better to do something like to fold the shift into the conditional ldr LCPI1_0 ldr ldr tst movne lsr ldr LCPI1_1 and r0 bx lr it saves an instruction and a register It might be profitable to cse MOVi16 if there are lots of bit immediates with the same bottom half Robert Muth started working on an alternate jump table implementation that does not put the tables in line in the text This is more like the llvm default jump table implementation This might be useful sometime Several revisions of patches are on the mailing beginning while CMP sets them like a subtract Therefore to be able to use CMN for comparisons other than the Z we ll need additional logic to reverse the conditionals associated with the comparison Perhaps a pseudo instruction for the with a post codegen pass to clean up and handle the condition codes See PR5694 for testcase Given the following on int B
Definition: README.txt:592
llvm::Module
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
llvm::ARM::WinEH::ReturnType
ReturnType
Definition: ARMWinEH.h:25
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:651
Matrix
Live Register Matrix
Definition: LiveRegMatrix.cpp:44
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::IRBuilderBase
Common base class shared among various IRBuilders.
Definition: IRBuilder.h:94
llvm::MatrixBuilder
Definition: MatrixBuilder.h:33
Module
Machine Check Debug Module
Definition: MachineCheckDebugify.cpp:122
llvm::Type::getContext
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition: Type.h:128
llvm::Intrinsic::getDeclaration
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1481
llvm::Attribute::getWithAlignment
static Attribute getWithAlignment(LLVMContext &Context, Align Alignment)
Return a uniquified Attribute object that has the specific alignment set.
Definition: Attributes.cpp:167
Constant.h
llvm::Twine
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
Definition: Twine.h:81
Alignment.h
llvm::Function::getFunctionType
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:175
llvm::MatrixBuilder::CreateMatrixTranspose
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
Definition: MatrixBuilder.h:110
llvm::IntegerType::get
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:311
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1474
llvm::MatrixBuilder::CreateScalarDiv
Value * CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned)
Divide matrix LHS by scalar RHS.
Definition: MatrixBuilder.h:216
Value.h
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
llvm::codeview::PublicSymFlags::Function
@ Function