using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool> VerifyShapeInfo(
    "verify-matrix-shapes", cl::Hidden,
    cl::desc("Enable/disable matrix shape verification."), cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));
/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}
/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}
/// Compute the start address of vector \p VecIdx with type (\p EltType x
/// \p NumElements), assuming \p Stride elements between the start of two
/// consecutive vectors.
static Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                                unsigned NumElements, Type *EltType,
                                IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  return VecStart;
}
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Contains estimates how many stores, loads and compute operations are
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 16>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }
    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }
    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };
  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }

    /// Returns the transposed shape.
    ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
  };
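  // For example, a column-major 3x4 matrix has ShapeInfo{3, 4}: it is lowered
  // to getNumVectors() == 4 column vectors of 3 elements each, getStride() is
  // 3, and ShapeInfo{3, 4}.t() yields the transposed shape ShapeInfo{4, 3}.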
  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove once lowering has finished.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

private:
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;
    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();
    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
    return FMF;
  }

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                        OptimizationRemarkEmitter *ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation on
  /// \p ST * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                    TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cache result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
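  // For example, requesting a 2x2 shape for a flat <4 x double> value without
  // a cached lowering splits it into two 2-element column vectors (sketch of
  // the resulting IR):
  //   %split  = shufflevector <4 x double> %m, <4 x double> poison,
  //                           <2 x i32> <i32 0, i32 1>
  //   %split1 = shufflevector <4 x double> %m, <4 x double> poison,
  //                           <2 x i32> <i32 2, i32 3>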
  /// Set elements I..I+NumElts-1 of the shape information for \p V, if
  /// supported.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }

      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;
      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                 m_Value(), m_Value(), m_Value(), m_Value(M),
                                 m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0 && isa<Instruction>(User))
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
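  // For example (illustrative IR), forward propagation seeds the shape of %c
  // from the multiply intrinsic and then flows it into the uniform-shape fadd:
  //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
  //            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
  //   %d = fadd <4 x double> %c, %c    ; inherits the 2 x 2 shape of %c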
  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to
    // the worklist.
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  /// \p Operation to construct the operation on the transposed operands.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop, add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }
  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New it it supportsShapeInfo so we insert
    // it conditionally instead.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      ShapeInfo Shape = S->second;
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, Shape});
    }
    Old.replaceAllUsesWith(New);
  }
  /// Sink a top-level transpose inside matmuls and adds.
  /// This creates and erases instructions as needed, and returns the newly
  /// created instruction while updating the iterator to avoid invalidation. If
  /// this returns nullptr, no new instruction was created.
  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK   KxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // Same as above, but for an elementwise multiply with a splat operand.
    // (A * k)^t -> A^t * k
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    // RxC RxC      CxR   CxR
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");
            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }
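  // The rewrites applied by sinkTranspose, in matrix notation:
  //   (A^T)^T     -> A
  //   k^T         -> k             (splats are unaffected by transpose)
  //   (A * B)^T   -> B^T * A^T     (matrix multiply)
  //   (A .* k)^T  -> A^T .* k      (elementwise multiply by a splat)
  //   (A + B)^T   -> A^T + B^T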
  void liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        T.eraseFromParent();
      if (A->use_empty())
        cast<Instruction>(A)->eraseFromParent();
      if (A != B && B->use_empty())
        cast<Instruction>(B)->eraseFromParent();
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst =
          Builder.CreateMatrixTranspose(M, C->getZExtValue(),
                                        R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
    // A^t + B^t -> (A + B)^t
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
      IRBuilder<> Builder(&I);
      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
      MatrixBuilder MBuilder(Builder);
      setShapeInfo(Add, {C, R});
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
  }
  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    // First sink all transposes inside matmuls and adds, hoping that we end
    // up with NN, NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove II. By default continue on the next/prev instruction.
        ++II;
        if (Instruction *NewInst = sinkTranspose(I, II))
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    }

    // If we have a TT matmul or a TT add, lift the transpose. We may be able
    // to fuse the transpose with the store.
    for (BasicBlock &BB : Func)
      for (Instruction &I : llvm::make_early_inc_range(BB))
        liftTranspose(I);
  }
  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to lower any dot products.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

    // Third, try to fuse candidates.
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);

    Changed = !FusedInsts.empty();

    // Fourth, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }

    // Delete the instructions backwards, as it has a reduced likelihood of
    // having to update as many def-use and use-def chains. Uses of removed
    // instructions are temporarily changed to poison; for verification, we
    // track them in PoisonedInsts and check that they are in fact removed.
    SmallSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }

    return Changed;
  }
  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() ||
        !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
  /// Load a matrix with \p Shape starting at \p Ptr, using \p Stride between
  /// vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    auto *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = Ptr;
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), IsVolatile,
          "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Loads a sub-matrix with shape \p ResultShape from the \p MatrixShape
  /// matrix starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);

    return loadMatrix(TileTy, TileStart, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }
  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
                 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride,
                                IsVolatile, Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
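  // Example (illustrative IR): loading a 2x2 double matrix whose columns are
  // 8 elements apart in memory:
  //   %m = call <4 x double> @llvm.matrix.column.major.load.v4f64.i64(
  //            ptr align 8 %p, i64 8, i1 false, i32 2, i32 2)
  // This lowers to two <2 x double> loads, one from %p and one 8 elements
  // past it.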
  /// Stores a sub-matrix \p StoreVal into the \p MatrixShape matrix starting
  /// at \p MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());

    storeMatrix(TileTy, StoreVal, TileStart, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = Ptr;
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }
  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
  /// Add a multiply of \p A and \p B to \p Sum, or return the product if
  /// \p Sum is nullptr. If \p AllowContraction is set and \p UseFPOp is true,
  /// the multiply and add are folded into a single llvm.fmuladd call.
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend
        // decide on which hardware the operation should be lowered to.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
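  // With contraction allowed, each multiply-accumulate step becomes a single
  // intrinsic call instead of an fmul/fadd pair (sketch of the emitted IR):
  //   %sum.next = call <4 x float> @llvm.fmuladd.v4f32(
  //                   <4 x float> %col, <4 x float> %splat, <4 x float> %sum)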
  /// Cache \p Matrix as result of \p Inst and mark \p Inst for removal. Users
  /// of \p Inst without shape information are updated to use a flattened
  /// vector instead.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
  /// Special case for MatMul if the matrix is a dot product, i.e. a
  /// 1-row * 1-column multiply: lower it to an elementwise multiply followed
  /// by a vector reduction, if that is estimated to be cheaper.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    if (FusedInsts.contains(MatMul) ||
        MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
      return;

    Value *LHS = MatMul->getArgOperand(0);
    Value *RHS = MatMul->getArgOperand(1);

    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
    bool IsIntVec = ElementType->isIntegerTy();

    // Floating point reductions require reassocation.
    if (!IsIntVec && !FMF.allowReassoc())
      return;

    auto CanBeFlattened = [this](Value *Op) {
      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
        return true;
      return match(
          Op, m_OneUse(m_CombineOr(
                  m_Load(m_Value()),
                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                  m_Value(), m_SpecificInt(1))))));
    };
    // Returns the cost benefit of using \p Op with the dot product lowering.
    // If the returned cost is < 0, the argument is cheaper to use in the
    // dot product lowering.
    auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (!isa<Instruction>(Op))
        return InstructionCost(0);

      FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
      Type *EltTy = VecTy->getElementType();

      if (!CanBeFlattened(Op)) {
        InstructionCost EmbedCost(0);
        // Roughly estimate the cost for embedding the columns into a vector.
        for (unsigned I = 1; I < N; ++I)
          EmbedCost += TTI.getShuffleCost(
              TargetTransformInfo::SK_Splice, FixedVectorType::get(EltTy, 1),
              std::nullopt, TargetTransformInfo::TCK_RecipThroughput);
        return EmbedCost;
      }

      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
        // The binary operator can be computed on the row vector directly; the
        // original cost computed the operation on the column vectors.
        InstructionCost OriginalCost =
            TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
                                       EltTy) *
            N;
        InstructionCost NewCost = TTI.getArithmeticInstrCost(
            cast<Instruction>(Op)->getOpcode(), VecTy);
        return NewCost - OriginalCost;
      }

      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
        // The transpose can be skipped for the dot product lowering, roughly
        // estimate the savings as the cost of embedding the columns in a
        // vector.
        InstructionCost EmbedCost(0);
        for (unsigned I = 1; I < N; ++I)
          EmbedCost -= TTI.getShuffleCost(
              TargetTransformInfo::SK_Splice, FixedVectorType::get(EltTy, 1),
              std::nullopt, TargetTransformInfo::TCK_RecipThroughput);
        return EmbedCost;
      }

      // Costs for loads.
      if (N == 1)
        return InstructionCost(0);

      return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
             N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
    };
    auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);

    // We compare the costs of a vector.reduce.add to sequential add.
    int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
    int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
    InstructionCost ReductionCost =
        TTI.getArithmeticReductionCost(
            AddOpCode, cast<VectorType>(LHS->getType()),
            IsIntVec ? std::nullopt : std::optional(FMF)) +
        TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
    InstructionCost SequentialAddCost =
        TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
            (LShape.NumColumns - 1) +
        TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
            (LShape.NumColumns);
    if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
      return;

    FusedInsts.insert(MatMul);
    IRBuilder<> Builder(MatMul);
    auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
                       this](Value *Op) -> Value * {
      // Matmul must be the only user of loads because we don't use LowerLoad
      // for row vectors (LowerLoad results in scalar loads and shufflevectors
      // instead of single vector load).
      if (!CanBeFlattened(Op))
        return Op;

      if (match(Op, m_BinOp())) {
        ShapeMap[Op] = ShapeMap[Op].t();
        return Op;
      }

      FusedInsts.insert(cast<Instruction>(Op));
      // If vector uses the builtin load, lower to a regular vector load.
      Value *Arg;
      if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                        m_Value(Arg)))) {
        auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
        Op->replaceAllUsesWith(NewLoad);
        cast<Instruction>(Op)->eraseFromParent();
        return NewLoad;
      } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(Arg)))) {
        ToRemove.push_back(cast<Instruction>(Op));
        return Arg;
      }

      return Op;
    };
    LHS = FlattenArg(LHS);

    // Insert mul/fmul and llvm.vector.reduce.fadd.
    Value *Mul =
        IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);

    Value *Result;
    if (IsIntVec)
      Result = Builder.CreateAddReduce(Mul);
    else {
      Result = Builder.CreateFAddReduce(
          ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
                          0.0),
          Mul);
      cast<Instruction>(Result)->setFastMathFlags(FMF);
    }

    // Pack the scalar back into a matrix and then replace the matmul inst.
    Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
                                         Result, uint64_t(0));
    MatMul->replaceAllUsesWith(Result);
    FusedInsts.insert(MatMul);
    ToRemove.push_back(MatMul);
  }
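  // Net effect (illustrative IR, integer case): a 1x4 * 4x1 multiply becomes
  // an elementwise multiply followed by a reduction:
  //   %prod = mul <4 x i32> %lhs, %rhs
  //   %dot  = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %prod)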
  /// Compute \p Result += \p A * \p B for input matrices with
  /// left-associating addition.
  ///
  /// We can fold a transpose into the operand that is used to extract
  /// scalars: the first operand for row-major, the second for column-major.
  /// If \p IsScalarMatrixTransposed we assume the appropriate operand is
  /// transposed.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder,
                          bool IsTiled, bool IsScalarMatrixTransposed,
                          FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0
        // iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                               IsFP, Builder, FMF.allowContract(),
                               NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axes and accumulate the rows. With
      // this the adds can be vectorized without reassociation.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                               IsFP, Builder, FMF.allowContract(),
                               NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
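  // Worked example for the column-major path: for a 2x2 * 2x2 multiply,
  // column J of the result is computed as
  //   Result[:, J] = A[:, 0] * B[0, J] + A[:, 1] * B[1, J]
  // i.e. whole columns of A scaled by broadcast scalars of B, which
  // vectorizes without needing to reassociate floating-point adds.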
  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. This new or otherwise the original location
  /// is returned.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the two parts of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT->Delete, Check0, Succ});

    BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, LI,
                                    nullptr, "alias_cont");
    BasicBlock *Copy = SplitBlock(MatMul->getParent(), MatMul, nullptr, LI,
                                  nullptr, "copy");
    BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, LI,
                                    nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they
    // are guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to new alloca.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    // Use an array type for the alloca, to avoid potentially huge alignment
    // requirements for large vector types.
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());

    Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
                         Load->getAlign(), LoadLoc.Size.getValue());
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(Alloca, Copy);

    // Update Dominator Tree.
    DTUpdates.push_back({DT->Insert, Check0, Check1});
    DTUpdates.push_back({DT->Insert, Check0, Fusion});
    DTUpdates.push_back({DT->Insert, Check1, Copy});
    DTUpdates.push_back({DT->Insert, Check1, Fusion});
    DT->applyUpdates(DTUpdates);
    return PHI;
  }
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            EltType->getPrimitiveSizeInBits().getFixedValue(),
        1U);

    // Cost model for tiling:
    //
    // For tiling to be beneficial, we need reuse either along the R or the C
    // axis. We vectorize along the R axis, so a single column is not enough.
    if (R <= VF && C == 1)
      return false;
    // Then we need enough elements to exceed the number of vector registers
    // we have. Note that this is an oversimplification since fusing also
    // takes some extra loads which may exceed the number of reloads
    // necessary.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }
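  // Worked example: with VF = 4 (e.g. 256-bit registers and double elements)
  // and a 6x6 * 6x6 multiply, Op0Regs = ceil(6/4) * 6 = 12 and Op1Regs = 12,
  // so tiling is considered profitable when 24 exceeds the number of
  // available vector registers.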
  /// Return a zero matrix of \p R x \p C with \p EltType elements.
  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns,
                TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
    // Create PHI nodes for the result columns to accumulate across the main
    // K-loop.
    SmallVector<PHINode *, 8> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoop.Header->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index,
                   TI.ColumnLoop.Index, {TileSize, TileSize}, EltType,
                   Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    // FIXME: The unroller should make this decision.
    unsigned InnerLoopUnrollCount =
        std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }
1751 "Tiling only supported for column-major matrixes at the moment!");
1752 if (!isFusionProfitable(MatMul))
1758 const unsigned R = LShape.NumRows;
1759 const unsigned C = RShape.NumColumns;
1760 const unsigned M = LShape.NumColumns;
1761 auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1763 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1764 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1768 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1771 for (
unsigned J = 0; J <
C; J +=
TileSize)
1773 const unsigned TileR = std::min(R -
I,
unsigned(
TileSize));
1774 const unsigned TileC = std::min(
C - J,
unsigned(
TileSize));
1775 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1778 const unsigned TileM = std::min(M - K,
unsigned(
TileSize));
1782 {TileR, TileM}, EltType, Builder);
1786 {TileM, TileC}, EltType, Builder);
1787 emitMatrixMultiply(Res,
A,
B, Builder,
true,
false,
1788 getFastMathFlags(MatMul));
1790 storeMatrix(Res, CPtr,
Store->getAlign(),
Store->isVolatile(), {R, M},
1797 FusedInsts.
insert(Store);
1798 FusedInsts.
insert(MatMul);
1799 Store->eraseFromParent();
1802 FusedInsts.
insert(LoadOp0);
1805 if (LoadOp1 != LoadOp0 && LoadOp1->
hasNUses(0)) {
1806 FusedInsts.
insert(LoadOp1);
  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Call finalizeLowering on lowered instructions. Instructions that are
  /// completely eliminated by fusion are added to \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to extract
    // scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output.
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }
  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }
  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting
      // vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently
    // we just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }
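  // For a column-major 2x3 input, the three 2-element input vectors are
  // rebuilt into two 3-element result vectors using extractelement /
  // insertelement pairs with row and column indices swapped.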
  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while
    /// linearizing the expression. Re-used sub-expressions are marked as
    /// (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print if they refer to
    /// an (function) external address or a stack address, for other values we
    /// either print the constant or "scalar"/"matrix" for other values.
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };
  /// Generate remarks for matrix operations in a function. To generate
  /// remarks for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to
  ///    the DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram.
  /// 3. For each leaf, create a remark containing a linearizied version of
  ///    the matrix expression.
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those
    /// are instructions in Inst2Matrix returning void or without any users in
    /// Inst2Matrix. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add
    /// \p Leaf to the set of expressions \p V is used in.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for the
    /// expression starting at \p Root. Expressions used multiple times are
    /// counted once.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
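// A remark produced by emitRemarks might read as follows (illustrative
// output, assembled from the format strings above):
//   remark: test.c:35:7: Lowered with 6 stores, 6 loads, 24 compute ops,
//   0 exposed transposes
//   store(
//    multiply.2x6.6x2.double(load(addr %A), load(addr %B)),
//    addr %C)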
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << '<';
  if (Minimal)
    OS << "minimal";
  OS << '>';
}