LLVM  13.0.0git
MatrixBuilder.h
Go to the documentation of this file.
1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MatrixBuilder class, which is used as a convenient way
10 // to lower matrix operations to LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_IR_MATRIXBUILDER_H
15 #define LLVM_IR_MATRIXBUILDER_H
16 
17 #include "llvm/IR/Constant.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstrTypes.h"
21 #include "llvm/IR/Instruction.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Type.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Alignment.h"
26 
27 namespace llvm {
28 
29 class Function;
30 class Twine;
31 class Module;
32 
33 template <class IRBuilderTy> class MatrixBuilder {
34  IRBuilderTy &B;
35  Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36 
37  std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38  Value *RHS) {
39  assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40  "One of the operands must be a matrix (embedded in a vector)");
41  if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42  assert(!isa<ScalableVectorType>(LHS->getType()) &&
43  "LHS Assumed to be fixed width");
44  RHS = B.CreateVectorSplat(
45  cast<VectorType>(LHS->getType())->getElementCount(), RHS,
46  "scalar.splat");
47  } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48  assert(!isa<ScalableVectorType>(RHS->getType()) &&
49  "RHS Assumed to be fixed width");
50  LHS = B.CreateVectorSplat(
51  cast<VectorType>(RHS->getType())->getElementCount(), LHS,
52  "scalar.splat");
53  }
54  return {LHS, RHS};
55  }
56 
57 public:
58  MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
59 
60  /// Create a column major, strided matrix load.
61  /// \p DataPtr - Start address of the matrix read
62  /// \p Rows - Number of rows in matrix (must be a constant)
63  /// \p Columns - Number of columns in matrix (must be a constant)
64  /// \p Stride - Space between columns
65  CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment,
66  Value *Stride, bool IsVolatile, unsigned Rows,
67  unsigned Columns, const Twine &Name = "") {
68 
69  // Deal with the pointer
70  PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
71  Type *EltTy = PtrTy->getElementType();
72 
73  auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
74 
75  Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
76  B.getInt32(Columns)};
77  Type *OverloadedTypes[] = {RetType};
78 
80  getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
81 
82  CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
83  Attribute AlignAttr =
84  Attribute::getWithAlignment(Call->getContext(), Alignment);
85  Call->addAttribute(1, AlignAttr);
86  return Call;
87  }
88 
89  /// Create a column major, strided matrix store.
90  /// \p Matrix - Matrix to store
91  /// \p Ptr - Pointer to write back to
92  /// \p Stride - Space between columns
94  Value *Stride, bool IsVolatile,
95  unsigned Rows, unsigned Columns,
96  const Twine &Name = "") {
97  Value *Ops[] = {Matrix, Ptr,
98  Stride, B.getInt1(IsVolatile),
99  B.getInt32(Rows), B.getInt32(Columns)};
100  Type *OverloadedTypes[] = {Matrix->getType()};
101 
103  getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
104 
105  CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
106  Attribute AlignAttr =
107  Attribute::getWithAlignment(Call->getContext(), Alignment);
108  Call->addAttribute(2, AlignAttr);
109  return Call;
110  }
111 
112  /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
113  /// rows and \p Columns columns.
115  unsigned Columns, const Twine &Name = "") {
116  auto *OpType = cast<VectorType>(Matrix->getType());
117  auto *ReturnType =
118  FixedVectorType::get(OpType->getElementType(), Rows * Columns);
119 
120  Type *OverloadedTypes[] = {ReturnType};
121  Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
123  getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
124 
125  return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
126  }
127 
128  /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
129  /// RHS.
130  CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
131  unsigned LHSColumns, unsigned RHSColumns,
132  const Twine &Name = "") {
133  auto *LHSType = cast<VectorType>(LHS->getType());
134  auto *RHSType = cast<VectorType>(RHS->getType());
135 
136  auto *ReturnType =
137  FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
138 
139  Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
140  B.getInt32(RHSColumns)};
141  Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
142 
144  getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
145  return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
146  }
147 
148  /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
149  /// ColumnIdx).
151  Value *ColumnIdx, unsigned NumRows) {
152  return B.CreateInsertElement(
153  Matrix, NewVal,
154  B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
155  ColumnIdx->getType(), NumRows)),
156  RowIdx));
157  }
158 
159  /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
160  /// matrixes.
161  Value *CreateAdd(Value *LHS, Value *RHS) {
162  assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
163  if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
164  assert(!isa<ScalableVectorType>(LHS->getType()) &&
165  "LHS Assumed to be fixed width");
166  RHS = B.CreateVectorSplat(
167  cast<VectorType>(LHS->getType())->getElementCount(), RHS,
168  "scalar.splat");
169  } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
170  assert(!isa<ScalableVectorType>(RHS->getType()) &&
171  "RHS Assumed to be fixed width");
172  LHS = B.CreateVectorSplat(
173  cast<VectorType>(RHS->getType())->getElementCount(), LHS,
174  "scalar.splat");
175  }
176 
177  return cast<VectorType>(LHS->getType())
178  ->getElementType()
179  ->isFloatingPointTy()
180  ? B.CreateFAdd(LHS, RHS)
181  : B.CreateAdd(LHS, RHS);
182  }
183 
184  /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
185  /// point matrixes.
186  Value *CreateSub(Value *LHS, Value *RHS) {
187  assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
188  if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
189  assert(!isa<ScalableVectorType>(LHS->getType()) &&
190  "LHS Assumed to be fixed width");
191  RHS = B.CreateVectorSplat(
192  cast<VectorType>(LHS->getType())->getElementCount(), RHS,
193  "scalar.splat");
194  } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
195  assert(!isa<ScalableVectorType>(RHS->getType()) &&
196  "RHS Assumed to be fixed width");
197  LHS = B.CreateVectorSplat(
198  cast<VectorType>(RHS->getType())->getElementCount(), LHS,
199  "scalar.splat");
200  }
201 
202  return cast<VectorType>(LHS->getType())
203  ->getElementType()
204  ->isFloatingPointTy()
205  ? B.CreateFSub(LHS, RHS)
206  : B.CreateSub(LHS, RHS);
207  }
208 
209  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
210  /// RHS.
212  std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
213  if (LHS->getType()->getScalarType()->isFloatingPointTy())
214  return B.CreateFMul(LHS, RHS);
215  return B.CreateMul(LHS, RHS);
216  }
217 
218  /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
219  /// IsUnsigned indicates whether UDiv or SDiv should be used.
220  Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
221  assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
222  assert(!isa<ScalableVectorType>(LHS->getType()) &&
223  "LHS Assumed to be fixed width");
224  RHS =
225  B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
226  RHS, "scalar.splat");
227  return cast<VectorType>(LHS->getType())
228  ->getElementType()
229  ->isFloatingPointTy()
230  ? B.CreateFDiv(LHS, RHS)
231  : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
232  }
233 
234  /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.
236  unsigned NumRows, Twine const &Name = "") {
237 
238  unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
239  ColumnIdx->getType()->getScalarSizeInBits());
240  Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
241  RowIdx = B.CreateZExt(RowIdx, IntTy);
242  ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
243  Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
244  return B.CreateExtractElement(
245  Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx),
246  "matext");
247  }
248 };
249 
250 } // end namespace llvm
251 
252 #endif // LLVM_IR_MATRIXBUILDER_H
llvm::MatrixBuilder::CreateMatrixInsert
Value * CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, Value *ColumnIdx, unsigned NumRows)
Insert a single element NewVal into Matrix at indices (RowIdx, ColumnIdx).
Definition: MatrixBuilder.h:150
llvm
Definition: AllocatorList.h:23
llvm::Intrinsic::getDeclaration
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1295
IntrinsicInst.h
llvm::MatrixBuilder::CreateScalarDiv
Value * CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned)
Divide matrix LHS by scalar RHS.
Definition: MatrixBuilder.h:220
llvm::PointerType::getElementType
Type * getElementType() const
Definition: DerivedTypes.h:653
llvm::Function
Definition: Function.h:61
llvm::Attribute
Definition: Attributes.h:52
llvm::Type::getScalarType
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:317
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:46
llvm::MatrixBuilder::CreateScalarMultiply
Value * CreateScalarMultiply(Value *LHS, Value *RHS)
Multiply matrix LHS with scalar RHS or scalar LHS with matrix RHS.
Definition: MatrixBuilder.h:211
llvm::Type::isFloatingPointTy
bool isFloatingPointTy() const
Return true if this is one of the six floating-point types.
Definition: Type.h:163
llvm::MatrixBuilder::CreateColumnMajorStore
CallInst * CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix store.
Definition: MatrixBuilder.h:93
llvm::MatrixBuilder::CreateSub
Value * CreateSub(Value *LHS, Value *RHS)
Subtract matrixes LHS and RHS.
Definition: MatrixBuilder.h:186
Instruction.h
Constants.h
InstrTypes.h
llvm::Type::isVectorTy
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:235
llvm::Type::getScalarSizeInBits
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition: Type.cpp:154
llvm::ConstantInt::get
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:885
llvm::FixedVectorType::get
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:650
llvm::Align
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
Type.h
llvm::MatrixBuilder::MatrixBuilder
MatrixBuilder(IRBuilderTy &Builder)
Definition: MatrixBuilder.h:58
llvm::PointerType
Class to represent pointers.
Definition: DerivedTypes.h:634
IRBuilder.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
B
compiles ldr LCPI1_0 ldr ldr mov lsr tst moveq r1 ldr LCPI1_1 and r0 bx lr It would be better to do something like to fold the shift into the conditional ldr LCPI1_0 ldr ldr tst movne lsr ldr LCPI1_1 and r0 bx lr it saves an instruction and a register It might be profitable to cse MOVi16 if there are lots of bit immediates with the same bottom half Robert Muth started working on an alternate jump table implementation that does not put the tables in line in the text This is more like the llvm default jump table implementation This might be useful sometime Several revisions of patches are on the mailing beginning while CMP sets them like a subtract Therefore to be able to use CMN for comparisons other than the Z we ll need additional logic to reverse the conditionals associated with the comparison Perhaps a pseudo instruction for the with a post codegen pass to clean up and handle the condition codes See PR5694 for testcase Given the following on int B
Definition: README.txt:592
llvm::Module
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
llvm::ARM::WinEH::ReturnType
ReturnType
Definition: ARMWinEH.h:25
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:649
llvm::MatrixBuilder::CreateMatrixTranspose
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
Definition: MatrixBuilder.h:114
Matrix
Live Register Matrix
Definition: LiveRegMatrix.cpp:44
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
llvm::MatrixBuilder
Definition: MatrixBuilder.h:33
Module
Machine Check Debug Module
Definition: MachineCheckDebugify.cpp:122
llvm::AMDGPU::HSAMD::Kernel::Arg::Key::IsVolatile
constexpr char IsVolatile[]
Key for Kernel::Arg::Metadata::mIsVolatile.
Definition: AMDGPUMetadata.h:194
llvm::Type::getContext
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition: Type.h:128
llvm::Attribute::getWithAlignment
static Attribute getWithAlignment(LLVMContext &Context, Align Alignment)
Return a uniquified Attribute object that has the specific alignment set.
Definition: Attributes.cpp:160
Constant.h
llvm::Twine
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
Definition: Twine.h:80
Alignment.h
llvm::GraphProgram::Name
Name
Definition: GraphWriter.h:52
llvm::MatrixBuilder::CreateAdd
Value * CreateAdd(Value *LHS, Value *RHS)
Add matrixes LHS and RHS.
Definition: MatrixBuilder.h:161
llvm::Function::getFunctionType
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:165
llvm::MatrixBuilder::CreateColumnMajorLoad
CallInst * CreateColumnMajorLoad(Value *DataPtr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix load.
Definition: MatrixBuilder.h:65
llvm::MatrixBuilder::CreateExtractElement
Value * CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx, unsigned NumRows, Twine const &Name="")
Extracts the element at (RowIdx, ColumnIdx) from Matrix.
Definition: MatrixBuilder.h:235
llvm::max
Align max(MaybeAlign Lhs, Align Rhs)
Definition: Alignment.h:350
llvm::MatrixBuilder::CreateMatrixMultiply
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
Definition: MatrixBuilder.h:130
llvm::IntegerType::get
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:276
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1450
Value.h
llvm::Value
LLVM Value Representation.
Definition: Value.h:75
llvm::codeview::PublicSymFlags::Function
@ Function