using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));
/// Return the scope's subprogram, looking through local scopes.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}
/// Compute the address of the \p VecIdx-th vector of a matrix starting at
/// \p BasePtr, where consecutive vectors are \p Stride elements apart.
static Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                                unsigned NumElements, Type *EltType,
                                IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get a pointer to the start of the selected vector. Skip GEP creation if
  // we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast the element-wise start pointer to a pointer to a vector of
  // NumElements x EltType.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}
/// Lower matrix intrinsics and other instructions with shape information to
/// regular vector instructions.
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Counters for the operations emitted while lowering a matrix expression.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Number of transposes that could not be folded away or fused into a
    /// multiply.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
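  // These counters are purely bookkeeping: they are accumulated per lowered
  // expression and summed up by the RemarkGenerator below, to emit remarks
  // such as "Lowered with N stores, M loads, ...".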
  /// Wrapper to manage a matrix lowered as a set of column or row vectors.
  struct MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(UndefValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }
    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
      return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
      return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }

    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 16>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 16>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }
    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column,
    /// otherwise from a row.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };
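  /// A shape descriptor (NumRows x NumColumns) for the values tracked in
  /// ShapeMap while propagating shape information.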
  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }
  };
  /// Maps instructions to their shape information.
  DenseMap<Value *, ShapeInfo> ShapeMap;
  /// Instructions to remove once lowering has finished.
  SmallVector<Instruction *, 16> ToRemove;
  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

private:
  /// If FMF is set for \p Inst, return it; additionally allow contraction if
  /// AllowContractEnabled is set.
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;
    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();
    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
    return FMF;
  }

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                        OptimizationRemarkEmitter *ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines?
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation on
  /// \p N elements of type \p ST.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(
                                    TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedSize()));
  }
  /// Return the set of vectors that a matrix value is lowered to. If we
  /// lowered \p MatrixVal, return the cached result; otherwise split the flat
  /// vector \p MatrixVal containing a matrix with shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // If we lowered MatrixVal using shape information, return the existing
    // matrix if it matches the requested shape. On a mismatch, embed the
    // result in a flat vector and split it below.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
  /// If \p V already has a known shape, keep it; otherwise record \p Shape
  /// for \p V. Returns true if a new shape was set.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }
  /// Returns true if the shape of \p V is preserved element-wise by lowering
  /// (binary and unary arithmetic).
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    if (auto *II = dyn_cast<IntrinsicInst>(Inst))
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
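  // Illustrative example (hand-written IR, not taken from this file): given
  //   %a = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
  //            <4 x double> %x, <4 x double> %y, i32 2, i32 2, i32 2)
  //   %s = fadd <4 x double> %a, %a
  // forward propagation records shape 2 x 2 for %a from the intrinsic's
  // dimension arguments, and then for %s, because fadd has uniform shape.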
  /// Propagate the shape information of instructions to their users. The work
  /// list contains instructions for which at least one operand shape is
  /// known.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element, set its shape, then add its users to the work list.
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA, *MatrixB, *M, *N, *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                           m_Value(), m_Value(), m_Value(), m_Value(M),
                           m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };

    // Pop an element with a known shape. If an operand's shape derives from
    // the result shape and is unknown, record it and add it to the work list.
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA, *MatrixB, *M, *N, *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this, so we would just
        // propagate backward to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // Use the users of newly discovered instructions as seeds for the next
      // round of forward propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
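  // The rewrites below rely on two algebraic identities:
  // transpose(transpose(A)) == A, and
  // transpose(A * B) == transpose(B) * transpose(A), with the dimension
  // arguments swapped to match.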
  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
      // We need to remove Old from the ShapeMap, otherwise RAUW would replace
      // it with New. Only add New if it supports shape info, so insert it
      // conditionally instead.
      auto S = ShapeMap.find(&Old);
      if (S != ShapeMap.end()) {
        ShapeMap.erase(S);
        if (supportsShapeInfo(New))
          ShapeMap.insert({New, S->second});
      }
      Old.replaceAllUsesWith(New);
    };

    // First sink all transposes inside matmuls, hoping that we end up with
    // NN, NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove I; by default continue on the next/prev instruction.
        ++II;
        // If we were to erase the instruction II points at, move again.
        auto EraseFromParent = [&II](Value *V) {
          auto *Inst = cast<Instruction>(V);
          if (Inst->use_empty()) {
            if (Inst == &*II)
              ++II;
            Inst->eraseFromParent();
          }
        };

        Instruction *NewInst = nullptr;
        IRBuilder<> IB(&I);
        MatrixBuilder<IRBuilder<>> Builder(IB);

        Value *TA, *TAMA, *TAMB;
        ConstantInt *R, *K, *C;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
          // Transpose of a transpose is a nop.
          Value *TATA;
          if (match(TA,
                    m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
            ReplaceAllUsesWith(I, TATA);
            EraseFromParent(&I);
            EraseFromParent(TA);
          } else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                               m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                               m_ConstantInt(K), m_ConstantInt(C)))) {
            // (A * B)^t -> B^t * A^t: transpose the two operands and swap them.
            auto T0 = Builder.CreateMatrixTranspose(
                TAMB, K->getZExtValue(), C->getZExtValue(), "t0");
            setShapeInfo(T0, {C, K});
            auto T1 = Builder.CreateMatrixTranspose(
                TAMA, R->getZExtValue(), K->getZExtValue(), "t1");
            setShapeInfo(T1, {K, R});
            NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
                                                   K->getZExtValue(),
                                                   R->getZExtValue(), "mmul");
            ReplaceAllUsesWith(I, NewInst);
            EraseFromParent(&I);
            EraseFromParent(TA);
          }
        }
      }
    }

    // If we have a TT matmul, lift the transposes: A^t * B^t -> (B * A)^t.
    for (BasicBlock &BB : Func) {
      for (auto II = BB.begin(); II != BB.end();) {
        Instruction *I = &*II;
        ++II;

        Value *A, *B, *AT, *BT;
        ConstantInt *R, *K, *C;
        if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                         m_Value(A), m_Value(B), m_ConstantInt(R),
                         m_ConstantInt(K), m_ConstantInt(C))) &&
            match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
            match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
          IRBuilder<> IB(I);
          MatrixBuilder<IRBuilder<>> Builder(IB);
          Value *M = Builder.CreateMatrixMultiply(
              BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
          setShapeInfo(M, {C, R});
          Instruction *NewInst = Builder.CreateMatrixTranspose(
              M, C->getZExtValue(), R->getZExtValue());
          ReplaceAllUsesWith(*I, NewInst);
          I->eraseFromParent();
          if (A->use_empty())
            cast<Instruction>(A)->eraseFromParent();
          if (A != B && B->use_empty())
            cast<Instruction>(B)->eraseFromParent();
        }
      }
    }
  }
  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known. Initialize the
    // work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      LLVM_DEBUG({
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.dump();
      });
    }
    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to fuse candidates.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);
    Changed = !FusedInsts.empty();

    // Third, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }
    // Delete instructions backwards, which reduces the number of def-use and
    // use-def chain updates. Because we add to ToRemove during fusion we can't
    // guarantee that defs come before uses, so temporarily change remaining
    // uses to undef; track them in UndefedInsts and verify all were removed.
    SmallSet<Instruction *, 16> UndefedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
          UndefedInsts.insert(Undefed);
        U.set(UndefValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      UndefedInsts.erase(Inst);
    }
    if (!UndefedInsts.empty()) {
      // If we didn't remove all undefed instructions, it's a hard error.
      dbgs() << "Undefed but present instructions:\n";
      for (auto *I : UndefedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Undefed but instruction not removed");
    }

    return Changed;
  }
  /// Cast \p BasePtr to a pointer to \p EltType in the same address space.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }
  /// Lower a call to a matrix intrinsic. Returns true if the call was lowered.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// constant, reduce the initial alignment based on the byte offset;
  /// otherwise return the common alignment of the initial alignment and the
  /// element size in bytes.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride
  /// between vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), IsVolatile,
          "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }
  /// Load a sub-matrix with shape \p ResultShape from the matrix with shape
  /// \p MatrixShape, starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    return loadMatrix(TileTy, TilePtr, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
                 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride,
                                IsVolatile, Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load, which loads a matrix from memory
  /// using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
  /// Store the sub-matrix \p StoreVal into the matrix with shape
  /// \p MatrixShape, starting at \p MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }
  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  /// Set elements I..I+NumElts-1 of \p Col to \p Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and Block is 2 long:
    //
    // Col = Block =
    // 0 1 2 3 4 5 6
    // Mask: 0 1 7 8 4 5 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend
        // decide if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
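  // When contraction is allowed, a multiply-accumulate step becomes a single
  // llvm.fmuladd call, which the backend may lower to a fused FMA. A fused
  // FMA rounds once instead of twice, which is the source of the possible
  // result differences mentioned in the matrix-allow-contract description.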
  /// Cache \p Matrix as the result of \p Inst and update its uses. Users with
  /// shape information will use the cached value when they are lowered; for
  /// other users, \p Matrix is flattened and the use is updated to the flat
  /// vector. Also marks \p Inst for deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
  /// Compute \p Result += \p A * \p B with left-associating addition. A
  /// transpose can be folded into the operand used to extract scalars; if
  /// \p IsScalarMatrixTransposed the corresponding operand is assumed
  /// transposed.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedSize() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand, moving along the K axis and accumulating columns. This way
      // the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0 step.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand, moving along the K axis and accumulating rows.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
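  // Example: a column-major 4x4 * 4x4 multiply with VF = 4 computes each of
  // the 4 result columns as a single <4 x elt> vector; the K loop accumulates
  // Result(:, J) += A(:, K) * B(K, J) for K = 0..3, so the adds vectorize
  // without reassociation.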
  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. Returns either the original or the new
  /// location.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of Load and Store overlap
    // and, if they do, copy Load's operand to a new buffer.
    BasicBlock *Check0 = MatMul->getParent();
    // Manually collect dominator tree updates, to avoid unnecessary work, as
    // we need to split the block ourselves.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT->Delete, Check0, Succ});

    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If so they might overlap, otherwise they are guaranteed not
    // to.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If so
    // they alias, otherwise they are guaranteed not to overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy the load operand to a new alloca. Use an array type to avoid
    // potentially huge alignment requirements for large vector types.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
    Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
    Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
                         Load->getAlign(), LoadLoc.Size.getValue());

    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(BC, Copy);

    // Adjust the dominator tree.
    DTUpdates.push_back({DT->Insert, Check0, Check1});
    DTUpdates.push_back({DT->Insert, Check0, Fusion});
    DTUpdates.push_back({DT->Insert, Check1, Copy});
    DTUpdates.push_back({DT->Insert, Check1, Fusion});
    DT->applyUpdates(DTUpdates);
    return PHI;
  }
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedSize() /
            EltType->getPrimitiveSizeInBits().getFixedSize(),
        1U);

    // A matrix-vector product where the result fits in a single vector does
    // not benefit from tiling.
    if (R <= VF && C == 1)
      return false;
    // Tiling pays off when the operands do not fit into the available vector
    // registers: each operand needs ceil(Rows / VF) vectors per column.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }
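  // Example: for R = C = M = 8 and VF = 4 (e.g. <4 x double> with 256-bit
  // vectors), Op0Regs = (8 + 3) / 4 * 8 = 16 and Op1Regs = 16, so tiling is
  // considered profitable whenever the target has fewer than 32 suitable
  // vector registers.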
  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoopHeader->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
                            {TileSize, TileSize}, EltType, Builder);
    MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
                            {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.CurrentRow, TI.CurrentCol, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }
1472 "Tiling only supported for column-major matrixes at the moment!");
1473 if (!isFusionProfitable(MatMul))
1479 const unsigned R = LShape.NumRows;
1480 const unsigned C = RShape.NumColumns;
1481 const unsigned M = LShape.NumColumns;
1482 auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1484 Value *APtr = getNonAliasingPointer(LoadOp0,
Store, MatMul);
1485 Value *BPtr = getNonAliasingPointer(LoadOp1,
Store, MatMul);
1489 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape,
Store);
1492 for (
unsigned J = 0; J <
C; J +=
TileSize)
1496 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1498 for (
unsigned K = 0; K <
M; K +=
TileSize) {
1503 {TileR, TileM}, EltType,
Builder);
1507 {TileM, TileC}, EltType,
Builder);
1508 emitMatrixMultiply(Res, A,
B,
Builder,
true,
false,
1509 getFastMathFlags(MatMul));
1511 storeMatrix(Res, CPtr,
Store->getAlign(),
Store->isVolatile(), {R, M},
1519 FusedInsts.
insert(MatMul);
1520 Store->eraseFromParent();
1523 FusedInsts.
insert(LoadOp0);
1526 if (LoadOp1 != LoadOp0 && LoadOp1->
hasNUses(0)) {
1527 FusedInsts.
insert(LoadOp1);
  /// Lowers llvm.matrix.multiply, trying to fuse it with surrounding
  /// instructions: a feeding transpose, or load/store neighbors for tiling.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to fetch
    // scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output.
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // Add a fake entry for the folded instruction, so it is included in
        // the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }
    if (MatrixLayout != MatrixLayoutTy::ColumnMajor || !MatMul->hasOneUse())
      return;

    // Lower {ld, ld} -> matmul -> st chains. There is no need to call
    // finalizeLowering, as the single store user will be lowered as part of
    // this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise we
      // create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }
  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }
  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector: row and column indices are transposed.
      Value *ResultVector = UndefValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }
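  // Example: transposing a column-major 2 x 3 matrix (3 vectors of 2
  // elements) produces a 3 x 2 matrix as 2 result vectors of 3 elements,
  // each built by extracting one element from every input vector.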
  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  /// Lower store instructions, if shape information is available.
  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }
  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform the binary op on the vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform the unary op on the vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Helper to linearize a matrix expression tree into a string. Matrix
  /// expressions are linearized by starting at an expression leaf and working
  /// bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;
    Value *Leaf;
    /// Sub-expressions that get reused are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();
      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }
    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }
    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void write(CallInst *CI) {
      if (!CI->getCalledFunction()) {
        write("<no called fn>");
        return;
      }
      StringRef Name = CI->getCalledFunction()->getName();
      if (!Name.startswith("llvm.matrix")) {
        write(Name);
        return;
      }
      auto *II = cast<IntrinsicInst>(CI);
      write(Intrinsic::getBaseName(II->getIntrinsicID())
                .drop_front(StringRef("llvm.matrix.").size()));
      write(".");
      std::string Tmp;
      raw_string_ostream SS(Tmp);

      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
        prettyPrintMatrixType(II->getOperand(0), SS);
        SS << ".";
        prettyPrintMatrixType(II->getOperand(1), SS);
        SS << "." << *II->getType()->getScalarType();
        break;
      case Intrinsic::matrix_transpose:
        prettyPrintMatrixType(II->getOperand(0), SS);
        SS << "." << *II->getType()->getScalarType();
        break;
      case Intrinsic::matrix_column_major_load:
        prettyPrintMatrixType(II, SS);
        SS << "." << *II->getType()->getScalarType();
        break;
      case Intrinsic::matrix_column_major_store:
        prettyPrintMatrixType(II->getOperand(0), SS);
        SS << "." << *II->getOperand(0)->getType()->getScalarType();
        break;
      default:
        llvm_unreachable("Unhandled case");
      }
      SS.flush();
      write(Tmp);
    }

    /// Returns the number of shape arguments of the given matrix intrinsic.
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }
    /// Write \p V to the output. Allocas are printed as "stack addr",
    /// constants by value and other values by their summarized kind.
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (isa<AllocaInst>(V)) {
        write("stack addr");
        LineLength += StringRef("stack addr").size();
      } else if (isa<GlobalVariable>(V)) {
        write("global addr");
        LineLength += StringRef("global addr").size();
      } else if (V->hasName()) {
        write("%" + V->getName().str());
        LineLength += V->getName().size() + 1;
      } else {
        std::string Tmp;
        raw_string_ostream TmpStream(Tmp);

        if (auto *CI = dyn_cast<ConstantInt>(V))
          TmpStream << CI->getValue();
        else if (isa<Constant>(V))
          TmpStream << "constant";
        else if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
        TmpStream.flush();
        Tmp = std::string(StringRef(Tmp).trim());
        LineLength += Tmp.size();
        Stream << Tmp;
      }
    }
    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// \p ParentReused and \p ParentShared indicate whether a parent
    /// expression was already marked as reused or shared.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        write(CI);
        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };
  /// Generate remarks for matrix operations in a function. Matrix expressions
  /// are linearized per leaf (e.g. a store), and op counts are summed per
  /// expression tree, so each remark describes one full expression.
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram: those
    /// are instructions in Inst2Matrix returning void or without any users in
    /// Inst2Matrix. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }
    /// Recursively traverse expression \p V starting at \p Leaf and add
    /// \p Leaf to the set of expressions \p V is part of.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }
    /// Calculate the exclusive and shared op counts for the expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }
    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
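    // An emitted remark reads like (illustrative):
    //   Lowered with 4 stores, 8 loads, 16 compute ops, 0 exposed transposes
    // followed by the linearized expression tree produced by ExprLinearizer.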
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << "<";
  if (Minimal)
    OS << "minimal";
  OS << ">";
}
namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    bool C = LMT.Visit();
    return C;
  }
};
} // namespace
static const char pass_name[] = "Lower the matrix intrinsics";

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
namespace {

/// A lightweight version of the matrix lowering pass that only lowers the
/// intrinsics and forgoes shape propagation, fusion and remark generation,
/// which require extra analyses.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    bool C = LMT.Visit();
    return C;
  }
};
} // namespace

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}