using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));
/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}
// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the address of the vector at index \p VecIdx, with \p Stride elements
// between consecutive vectors and \p NumElements elements per vector.
static Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                                unsigned NumElements, Type *EltType,
                                IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast elementwise vector start pointer to a pointer to a vector
  // (EltType x NumElements)*.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}
namespace {

/// Lowers matrix intrinsics and instructions with shape information to
/// vector operations.
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Contains estimates of how many stores, loads and compute operations are
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
  /// Wrapper class representing a matrix as a set of vectors, either in row
  /// or column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }
    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
      return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
      return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }
    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };
  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }

    /// Returns the transposed shape.
    ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
  };
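  // Illustrative note: with the default column-major layout, a 4x2 matrix has
  // getNumVectors() == 2 column vectors of 4 elements and getStride() == 4;
  // ShapeInfo(4, 2).t() yields the transposed 2x4 shape used when rewriting
  // transposes.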
  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering, matching the shape of the
  /// result value of the instruction (stores and the
  /// matrix.column.major.store intrinsic being the exceptions).
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we are not replacing all
  /// users of a lowered instruction, if shape information is available, and
  /// those need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

  /// Return the fast-math flags to use when lowering \p Inst, honoring the
  /// -matrix-allow-contract option.
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;

    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();

    FMF.setAllowContract(FMF.allowContract() || AllowContractEnabled);

    return FMF;
  }

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                        OptimizationRemarkEmitter *ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation of
  /// type \p ST operating on \p N elements.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                    TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
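  // Worked example (illustrative): 8 x double is 512 bits, so on a target
  // with 256-bit fixed-width vector registers getNumOps returns
  // ceil(512 / 256) == 2 vector operations.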
  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result. Otherwise,
  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
  /// into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mismatch, embed the result in a flat vector
    // and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  /// Returns true if shape information can be propagated through \p V without
  /// changing it, i.e. elementwise operations.
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                           m_Value(), m_Value(), m_Value(), m_Value(M),
                           m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (isa<Instruction>(User) && !isa<PHINode>(User))
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p Worklist contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to
    // the worklist.
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this, so we would just
        // backward-propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
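  // Illustrative example: given
  //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
  //            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
  // forward propagation records the 2x2 shape for %c, and backward
  // propagation then records 2x2 shapes for %a and %b (and transitively for
  // the instructions producing them).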
  /// Create transposes of \p Op0 and \p Op1 and distribute the original
  /// transpose over \p Operation: (Op0 op Op1)^T -> Op0^T op Op1^T.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop, add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }

  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New if it supportsShapeInfo, so we insert
    // it conditionally instead.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      ShapeInfo Shape = S->second;
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, Shape});
    }
    Old.replaceAllUsesWith(New);
  }
  /// Sink a top-level transpose inside matmuls and adds. Returns the new
  /// instruction, so the caller can continue scanning from it, or nullptr.
  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK   KxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K->getZExtValue(), C->getZExtValue()}, TAMA,
          {R->getZExtValue(), K->getZExtValue()}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // Same as above, but with a mul, which occurs when multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R->getZExtValue(), C->getZExtValue()}, TAMB,
          {R->getZExtValue(), C->getZExtValue()}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R->getZExtValue(), C->getZExtValue()}, TAMB,
          {R->getZExtValue(), C->getZExtValue()}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            auto *FAdd =
                cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd"));
            setShapeInfo(FAdd, Shape0);
            return FAdd;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }

  void liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        T.eraseFromParent();
      if (A->use_empty())
        cast<Instruction>(A)->eraseFromParent();
      if (A != B && B->use_empty())
        cast<Instruction>(B)->eraseFromParent();
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(
          M, C->getZExtValue(), R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
    // A^t + B^t -> (A + B)^t
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
      IRBuilder<> Builder(&I);
      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
      MatrixBuilder MBuilder(Builder);
      setShapeInfo(Add, {C, R});
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
  }

  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    // First sink all transposes inside matmuls and adds, hoping that we end
    // up with NN, NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove II. By default continue on the next/prev instruction.
        ++II;
        if (Instruction *NewInst = sinkTranspose(I, II))
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    }

    // If we have a TT matmul or a TT add, lift the transpose. We may be able
    // to fold into a consuming multiply or add.
    for (BasicBlock &BB : Func)
      for (Instruction &I : llvm::make_early_inc_range(BB))
        liftTranspose(I);
  }
  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to fuse candidates.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);
    Changed = !FusedInsts.empty();

    // Third, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }

    // Delete the instructions backwards, as it has a reduced likelihood of
    // having to update as many def-use and use-def chains.
    //
    // Because we add to ToRemove during fusion we can't guarantee that defs
    // are before uses. Change uses to poison temporarily as these should get
    // removed as well.
    //
    // For verification, we keep track of where we changed uses to poison in
    // PoisonedInsts and then check that we in fact remove them.
    SmallSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }

    return Changed;
  }
  /// Cast \p Ptr to a pointer to \p EltType, keeping the original address
  /// space.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
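  // For reference, illustrative LangRef-style forms of the intrinsics handled
  // above (operand and mangling details may differ between LLVM versions):
  //   %t = call <6 x double> @llvm.matrix.transpose.v6f64(
  //            <6 x double> %a, i32 2, i32 3)
  //   %m = call <4 x double> @llvm.matrix.multiply.v4f64.v6f64.v6f64(
  //            <6 x double> %a, <6 x double> %b, i32 2, i32 3, i32 2)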
  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
  /// Load a matrix with \p Shape starting at \p Ptr, using \p Stride between
  /// vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Loads a sub-matrix with shape \p ResultShape from a \p MatrixShape
  /// matrix, starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    return loadMatrix(TileTy, TilePtr, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
                 bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
                                Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
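  // Illustrative result: a column-major load of a 4x4 double matrix with
  // stride 4 lowers to four <4 x double> "col.load" loads whose addresses are
  // 0, 4, 8 and 12 elements past the base pointer.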
  /// Stores a sub-matrix \p StoreVal into the \p MatrixShape matrix starting
  /// at \p MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr, using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }

  /// Add a multiply of \p A and \p B to \p Sum, using an fmuladd intrinsic if
  /// contraction is allowed, and keep track of the emitted compute ops.
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend
        // decide on lowering it to mul + add or fma.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
  /// Compute \p Result += \p A * \p B for input matrices with left-associating
  /// addition. If \p IsScalarMatrixTransposed, the transpose has been folded
  /// into the operand used to extract scalars.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K == 0
        // iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axes and accumulate the rows. With
      // this the adds can be vectorized without reassociation.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
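  // Worked example (illustrative): a column-major 2x2 * 2x2 multiply of
  // doubles with VF >= 2 emits, per result column J, one vector op per K
  // (M == 2), accumulating A.col(K) * splat(B[K][J]); with contraction
  // allowed these become fmuladd calls, for 4 compute ops in total.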
  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. The new or otherwise the original location
  /// is returned.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the two parts of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they
    // are guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to new alloca.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    // Use an array type for the alloca, to avoid potentially huge alignment
    // requirements for large vector types.
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
    Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());

    Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
                         Load->getAlign(), LoadLoc.Size.getValue());
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(BC, Copy);

    return PHI;
  }
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            EltType->getPrimitiveSizeInBits().getFixedValue(),
        1U);

    // Cost model for tiling.
    //
    // For tiling to be beneficial, we need reuse either along the R or the C
    // axis. We vectorize along the R axis, so that means at least 3 elements.
    if (R <= VF && C == 1)
      return false;
    // Then we need enough elements to exceed the number of vector registers
    // we have. Note that this is an oversimplification since fusing also
    // takes some extra loads which may exceed the number of reloads
    // necessary.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }
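  // Worked example (illustrative): R = C = M = 16 with double elements and
  // VF = 4 gives Op0Regs = (16 + 3) / 4 * 16 = 64 and Op1Regs = 64; on a
  // target with 32 vector registers, 128 > 32, so fusion is considered
  // profitable.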
  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoop.Header->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }
1598 "Tiling only supported for column-major matrixes at the moment!");
1599 if (!isFusionProfitable(MatMul))
1605 const unsigned R = LShape.NumRows;
1606 const unsigned C = RShape.NumColumns;
1607 const unsigned M = LShape.NumColumns;
1608 auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1610 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1611 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1615 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1618 for (
unsigned J = 0; J <
C; J +=
TileSize)
1620 const unsigned TileR = std::min(R -
I,
unsigned(
TileSize));
1621 const unsigned TileC = std::min(
C - J,
unsigned(
TileSize));
1622 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1625 const unsigned TileM = std::min(M - K,
unsigned(
TileSize));
1629 {TileR, TileM}, EltType, Builder);
1633 {TileM, TileC}, EltType, Builder);
1634 emitMatrixMultiply(Res,
A,
B, Builder,
true,
false,
1635 getFastMathFlags(MatMul));
1637 storeMatrix(Res, CPtr,
Store->getAlign(),
Store->isVolatile(), {R, M},
1644 FusedInsts.
insert(Store);
1645 FusedInsts.
insert(MatMul);
1646 Store->eraseFromParent();
1649 FusedInsts.
insert(LoadOp0);
1652 if (LoadOp1 != LoadOp0 && LoadOp1->
hasNUses(0)) {
1653 FusedInsts.
insert(LoadOp1);
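  // Illustrative tile count: with R = C = M = 8 and the default TileSize of
  // 4, the non-loop path emits 2 x 2 result tiles, each accumulating over 2
  // K tiles, i.e. 8 tiled 4x4 multiplies.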
  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Call finalizeLowering on lowered instructions. Instructions that are
  /// completely eliminated by fusion are added to \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to extract
    // scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output.
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // still counted as an expression containing a transpose.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }
  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }
  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting
      // vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently
    // we just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }
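  // Illustrative cost: transposing a 4x4 matrix is counted as
  // 2 * 4 * 4 = 32 compute ops (one extractelement and one insertelement per
  // element), plus one exposed transpose.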
  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }
  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while
    /// linearizing the expression. Re-used sub-expressions are marked as
    /// (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print if they refer to
    /// an (function) external address or a stack address, for other values we
    /// either print the constant or "scalar"/"matrix" for other values.
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };
  /// Generate remarks for matrix operations in a function. To generate
  /// remarks for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to
  ///    the DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in getExpressionLeaves)
  ///    for each subprogram - expression mapping. Leaves are lowered matrix
  ///    instructions without other matrix users (like stores) in the current
  ///    subprogram.
  /// 3. For each leaf, create a remark containing a linearizied version of
  ///    the matrix expression.
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those
    /// are instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add
    /// \p Leaf to the set of expressions \p V is part of.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for the
    /// expression starting at \p Root. Expressions used multiple times are
    /// counted once.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace
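// Illustrative remark output (format per emitRemarks above, values made up):
//   remark: test.c:35:7: Lowered with 6 stores, 6 loads, 24 compute ops,
//   0 exposed transposes
//   store(multiply(load(addr %A), load(addr %B)), addr %C)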
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << '<';
  if (Minimal)
    OS << "minimal";
  OS << '>';
}
namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
namespace {

/// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features that require DT, AA or ORE, like tiling and remarks, are
/// disabled. This is used to lower matrix intrinsics inside the backend
/// pipelines.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
                      "lower-matrix-intrinsics-minimal", pass_name_minimal,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
                    "lower-matrix-intrinsics-minimal", pass_name_minimal,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}