#include "llvm/IR/IntrinsicsX86.h"

using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"
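// isAMXCast(): an instruction is an AMX cast if it is one of the two helper
// intrinsics that convert between <256 x i32> vectors and x86_amx tiles.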
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }
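// Tile values are spilled and reloaded through memory. The backing alloca is
// created at the front of the entry block so it becomes a static stack slot,
// and the first non-alloca instruction of the entry block serves as an
// insertion point for code that must follow the allocas.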
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());

  if (!isa<AllocaInst>(&I))
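// getShape(): recover the (Row, Col) shape operands of the tile produced or
// consumed by an AMX intrinsic. Tile load/store intrinsics carry the shape in
// their first two operands; for the tdp* dot-product intrinsics the shape of a
// given tile operand depends on which operand number is being queried.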
  Value *Row = nullptr, *Col = nullptr;

  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {

  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));

  return std::make_pair(Row, Col);
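// For a PHI node the shape is found by walking forward through AMX casts and
// other PHIs until a real AMX intrinsic is reached, then asking that intrinsic
// for the shape; (nullptr, nullptr) is returned if no such user exists.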
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      // Keep walking through the cast's user.
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      // Keep walking through the PHI's user.
    }

  return std::make_pair(nullptr, nullptr);
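// X86LowerAMXType rewrites the bitcasts between <256 x i32> and x86_amx that
// the front end emits into llvm.x86.tileloadd64.internal /
// llvm.x86.tilestored64.internal calls, so no plain load or store of a tile
// value survives. Col2Row appears to cache column-to-row conversions for
// shape values.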
class X86LowerAMXType {
  std::map<Value *, Value *> Col2Row;
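  // Fold a load of <256 x i32> followed by a bitcast to x86_amx into a single
  // tileloadd64 from the load's pointer; the shape is taken from the AMX
  // intrinsic that consumes the bitcast.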
  Value *Row = nullptr, *Col = nullptr;
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
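  // Fold a bitcast from x86_amx to <256 x i32> followed by a store into a
  // single tilestored64 that uses the defining intrinsic's Row/Col.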
  auto *II = cast<IntrinsicInst>(Tile);
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);

  Bitcast->replaceAllUsesWith(Vec);
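  // Fallback for bitcasts that cannot be folded into a tile load/store: go
  // through a stack slot by storing the source value to an alloca and
  // reloading it in the other representation.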
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
    auto *II = dyn_cast<IntrinsicInst>(Src);
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
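// visit() walks every bitcast in the function, folds the load/bitcast and
// bitcast/store pairs above when possible, falls back to the stack-slot path
// otherwise, and finally erases the instructions collected in DeadInsts.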
bool X86LowerAMXType::visit() {
  auto *Bitcast = dyn_cast<BitCastInst>(&Inst);

  if (Bitcast->getType()->isX86_AMXTy()) {
    DeadInsts.push_back(LD);
  } else if (Src->getType()->isX86_AMXTy()) {
    DeadInsts.push_back(ST);
  }

  bool C = !DeadInsts.empty();
  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();
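// Volatile tile data: every AMX-typed definition is stored to a stack slot
// right after it is produced and reloaded immediately before each use, so a
// tile's def and use always sit in the same basic block. The guarded call
// site at the end of the file suggests this is meant for the unoptimized
// (fast register allocation) path.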
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());

  Builder.SetInsertPoint(&*Iter);
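// Spill an x86_amx definition to the stack slot with a tilestored64 whose
// shape comes from the defining intrinsic.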
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
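// Reload a spilled tile in front of a use with a tileloadd64. For a PHI the
// shape is taken from the intrinsic feeding its first incoming value;
// otherwise from the intrinsic that defines the value itself.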
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }

  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
  Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
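// Helper: returns true when any user of the instruction is a PHI node.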
  for (Use &U : I->uses()) {
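// X86VolatileTileData implements the store/reload scheme above: non-PHI tile
// defs are spilled right after the definition and reloaded before each use;
// PHI-defined tiles are handled by routing the PHI's incoming values through
// a common stack slot and replacing the PHI with a load from that slot.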
class X86VolatileTileData {
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  for (auto *I : Incomings) {
    for (Use &U : I->uses()) {
      if (isa<PHINode>(V) || V == Store)
void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {

void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
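// For a tile def that does not feed a PHI: store it to the stack slot right
// after the definition and replace each use with a fresh tile load.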
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;

      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);

    volatileTileNonPHI(I);

    volatileTilePHI(dyn_cast<PHINode>(I));
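// X86LowerAMXCast handles the amx cast intrinsics
// (llvm.x86.cast.vector.to.tile / llvm.x86.cast.tile.to.vector): it cancels
// matching cast pairs, folds casts into adjacent loads and stores, and lowers
// whatever remains through a stack slot.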
class X86LowerAMXCast {
  bool transformAllAMXCast();
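  // Local DCE helper: drop the operands of a dead instruction, revisit any
  // operand instruction that becomes trivially dead, and erase the
  // instruction itself.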
  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
    I->setOperand(i, nullptr);
    if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {

  I->eraseFromParent();
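// optimizeAMXCastFromPhi(): when an AMX cast takes its operand from a web of
// PHI nodes whose incoming values are the opposite casts (or zero/undef
// constants), rebuild the PHIs in the tile domain so the cast pairs cancel.
// Constant incomings are materialized as tilezero followed by a
// tile-to-vector cast so that every incoming has a shape-carrying definition.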
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  Type *SrcTy = Src->getType();

  PhiWorklist.push_back(PN);
  while (!PhiWorklist.empty()) {
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;

        auto *Block = OldPN->getIncomingBlock(I);
        Value *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, None, {Row, Col});
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        OldPN->setIncomingValue(I, NewInst);
      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);

      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (TyA != DestTy || TyB != SrcTy)

  for (auto *OldPN : OldPhiNodes) {
      if (TyA != DestTy || TyB != SrcTy)
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        if (OldPhiNodes.count(PHI) == 0)
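  // For each old (vector-typed) PHI, create a replacement PHI of the tile
  // type, then fill in its incomings either from the source operand of the
  // matching cast or from the corresponding new PHI.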
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *NewV = nullptr;
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
      assert(TyA == DestTy && TyB == SrcTy);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        assert(OldPhiNodes.contains(PHI));
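  // combineCastStore:
  //   %res = call x86_amx ...
  //   %vec = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %res)
  //   store <256 x i32> %vec, <256 x i32>* %p
  // becomes a tilestored64 of %res straight to %p.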
  auto *II = cast<IntrinsicInst>(Tile);
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
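  // combineLoadCast is the mirror image: a load of <256 x i32> feeding a
  // vector-to-tile cast becomes a tileloadd64 straight from the load's
  // pointer, using the shape of the AMX intrinsic that consumes the cast.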
  Value *Row = nullptr, *Col = nullptr;
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);

  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
  Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
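  // combineLdSt(): run the two foldings above over every remaining live cast,
  // erasing the stores and loads that become dead in the process.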
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
        combineCastStore(cast<IntrinsicInst>(Cast), Store);
        DeadStores.push_back(Store);

      for (auto *Store : DeadStores)
        Store->eraseFromParent();

        combineLoadCast(cast<IntrinsicInst>(Cast), Load);
        Load->eraseFromParent();
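  // combineAMXcast(): collect all vector-to-tile and tile-to-vector casts,
  // cancel cast pairs that feed each other, erase the casts that became dead,
  // fold the survivors into adjacent loads/stores, and finally optimize casts
  // whose operand is a PHI node.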
  bool Change = false;

    if (match(&I,
              m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
      Vec2TileInsts.push_back(&I);
    else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                           m_Value(Vec))))
      Tile2VecInsts.push_back(&I);
    for (auto *Inst : Insts) {

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
      } else {
        LiveCasts.push_back(Inst);
      }

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
      if (isa<PHINode>(I.getOperand(0)))
        PhiCastWorkList.push_back(&I);

  for (auto *I : PhiCastWorkList) {
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {

  while (!DeadInst.empty()) {
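// transformAMXCast(): lower a cast that could not be combined away by going
// through a stack slot, just like the bitcast fallback above: a
// vector-to-tile cast stores the vector and tile-loads it back, while a
// tile-to-vector cast tile-stores the tile and reloads it as a vector.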
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  Value *I8Ptr, *Stride;

  auto Prepare = [&](Type *MemTy) {
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    Builder.CreateStore(Src, AllocaAddr);
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    auto *II = dyn_cast<IntrinsicInst>(Src);
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
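// transformAllAMXCast(): collect every remaining AMX cast in the function and
// lower each one through transformAMXCast().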
bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;

      WorkLists.push_back(&I);

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
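// Legacy pass wrapper: run the cast combining first, then the remaining cast
// lowering and the bitcast lowering, and finally (guarded by the
// OptimizeNone check visible below) the volatile-tile rewriting.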
class X86LowerAMXTypeLegacyPass : public FunctionPass {

    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);

    if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
      X86VolatileTileData VTD(F);
      C = VTD.volatileTileData() || C;
    }
static const char PassName[] = "Lower AMX type for load/store";

  return new X86LowerAMXTypeLegacyPass();