33#include "llvm/IR/IntrinsicsX86.h"
43using namespace PatternMatch;
45#define DEBUG_TYPE "lower-amx-intrinsics"
49 if (
auto *FVT = dyn_cast<FixedVectorType>(Ty))
50 return FVT->getNumElements() == 256 &&
51 FVT->getElementType()->isIntegerTy(32);
58 cl::desc(
"X86: enable AMX scalarizition."));
61class X86LowerAMXIntrinsics {
66 : Func(
F), DTU(DomTU), LI(LoopI) {}
75 template <
bool IsTileLoad>
79 template <Intrinsic::ID IntrID>
80 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
81 IntrID == Intrinsic::x86_tdpbsud_internal ||
82 IntrID == Intrinsic::x86_tdpbusd_internal ||
83 IntrID == Intrinsic::x86_tdpbuud_internal ||
84 IntrID == Intrinsic::x86_tdpbf16ps_internal,
89 template <
bool IsTileLoad>
90 bool lowerTileLoadStore(
Instruction *TileLoadStore);
91 template <Intrinsic::ID IntrID>
92 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
93 IntrID == Intrinsic::x86_tdpbsud_internal ||
94 IntrID == Intrinsic::x86_tdpbusd_internal ||
95 IntrID == Intrinsic::x86_tdpbuud_internal ||
96 IntrID == Intrinsic::x86_tdpbf16ps_internal,
120 IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
122 B.SetInsertPoint(Latch);
126 IV->addIncoming(Inc, Latch);
131 DTU.applyUpdatesPermissive({
132 {DominatorTree::Delete, Preheader, Tmp},
133 {DominatorTree::Insert, Header, Body},
134 {DominatorTree::Insert, Body, Latch},
135 {DominatorTree::Insert, Latch, Header},
136 {DominatorTree::Insert, Latch,
Exit},
137 {DominatorTree::Insert, Preheader, Header},
140 L->addBasicBlockToLoop(Header, *LI);
141 L->addBasicBlockToLoop(Body, *LI);
142 L->addBasicBlockToLoop(Latch, *LI);
147template <
bool IsTileLoad>
148Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
151 std::string IntrinName = IsTileLoad ?
"tileload" :
"tilestore";
152 Loop *RowLoop =
nullptr;
153 Loop *ColLoop =
nullptr;
155 RowLoop = LI->AllocateLoop();
156 ColLoop = LI->AllocateLoop();
158 if (
Loop *ParentL = LI->getLoopFor(Start))
159 ParentL->addChildLoop(RowLoop);
161 LI->addTopLevelLoop(RowLoop);
165 IntrinName +
".scalarize.rows",
B, RowLoop);
168 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col,
B.getInt16(1),
169 IntrinName +
".scalarize.cols",
B, ColLoop);
176 Type *EltTy =
B.getInt32Ty();
183 Value *CurrentRowZExt =
B.CreateZExt(CurrentRow, Stride->
getType());
184 Value *CurrentColZExt =
B.CreateZExt(CurrentCol, Stride->
getType());
186 B.CreateAdd(
B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
188 Value *
Idx =
B.CreateAdd(
B.CreateMul(CurrentRow,
B.getInt16(16)), CurrentCol);
195 PHINode *VecCPhiRowLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.phi.row");
202 PHINode *VecPhi =
B.CreatePHI(V256I32Ty, 2,
"vec.phi");
211 Value *Elt =
B.CreateLoad(EltTy, EltPtr);
212 Value *ResVec =
B.CreateInsertElement(VecPhi, Elt,
Idx);
218 auto *BitCast = cast<BitCastInst>(Tile);
219 Value *Vec = BitCast->getOperand(0);
227 Value *Elt =
B.CreateExtractElement(Vec,
Idx);
229 B.CreateStore(Elt, EltPtr);
234template <Intrinsic::ID IntrID>
235std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
236 IntrID == Intrinsic::x86_tdpbsud_internal ||
237 IntrID == Intrinsic::x86_tdpbusd_internal ||
238 IntrID == Intrinsic::x86_tdpbuud_internal ||
239 IntrID == Intrinsic::x86_tdpbf16ps_internal,
245 std::string IntrinName;
247 case Intrinsic::x86_tdpbssd_internal:
248 IntrinName =
"tiledpbssd";
250 case Intrinsic::x86_tdpbsud_internal:
251 IntrinName =
"tiledpbsud";
253 case Intrinsic::x86_tdpbusd_internal:
254 IntrinName =
"tiledpbusd";
256 case Intrinsic::x86_tdpbuud_internal:
257 IntrinName =
"tiledpbuud";
259 case Intrinsic::x86_tdpbf16ps_internal:
260 IntrinName =
"tiledpbf16ps";
263 Loop *RowLoop =
nullptr;
264 Loop *ColLoop =
nullptr;
265 Loop *InnerLoop =
nullptr;
267 RowLoop = LI->AllocateLoop();
268 ColLoop = LI->AllocateLoop();
269 InnerLoop = LI->AllocateLoop();
272 if (
Loop *ParentL = LI->getLoopFor(Start))
273 ParentL->addChildLoop(RowLoop);
275 LI->addTopLevelLoop(RowLoop);
279 IntrinName +
".scalarize.rows",
B, RowLoop);
282 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col,
B.getInt16(1),
283 IntrinName +
".scalarize.cols",
B, ColLoop);
289 createLoop(ColBody, ColLoopLatch, K,
B.getInt16(1),
290 IntrinName +
".scalarize.inner",
B, InnerLoop);
298 Value *CurrentInner = &*InnerLoopHeader->
begin();
301 auto *BitCastAcc = cast<BitCastInst>(Acc);
302 Value *VecC = BitCastAcc->getOperand(0);
307 auto *BitCastLHS = cast<BitCastInst>(LHS);
308 Value *VecA = BitCastLHS->getOperand(0);
310 auto *BitCastRHS = cast<BitCastInst>(RHS);
311 Value *VecB = BitCastRHS->getOperand(0);
321 PHINode *VecCPhiRowLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.c.phi.row");
324 PHINode *VecDPhiRowLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.d.phi.row");
338 PHINode *VecCPhiColLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.c.phi.col");
339 VecCPhiColLoop->
addIncoming(VecCPhiRowLoop, RowBody);
340 PHINode *VecDPhiColLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.d.phi.col");
341 VecDPhiColLoop->
addIncoming(VecDPhiRowLoop, RowBody);
343 B.CreateAdd(
B.CreateMul(CurrentRow,
B.getInt16(16)), CurrentCol);
351 PHINode *VecCPhi =
B.CreatePHI(V256I32Ty, 2,
"vec.c.inner.phi");
356 B.CreateAdd(
B.CreateMul(CurrentRow,
B.getInt16(16)), CurrentInner);
358 B.CreateAdd(
B.CreateMul(CurrentInner,
B.getInt16(16)), CurrentCol);
359 Value *NewVecC =
nullptr;
361 if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
378 Value *EltC =
B.CreateExtractElement(VecCPhi, IdxC);
379 Value *EltA =
B.CreateExtractElement(VecA, IdxA);
380 Value *SubVecA =
B.CreateBitCast(EltA, V4I8Ty);
381 Value *EltB =
B.CreateExtractElement(VecB, IdxB);
382 Value *SubVecB =
B.CreateBitCast(EltB, V4I8Ty);
383 Value *SEXTSubVecB =
nullptr;
384 Value *SEXTSubVecA =
nullptr;
386 case Intrinsic::x86_tdpbssd_internal:
387 SEXTSubVecB =
B.CreateSExt(SubVecB, V4I32Ty);
388 SEXTSubVecA =
B.CreateSExt(SubVecA, V4I32Ty);
390 case Intrinsic::x86_tdpbsud_internal:
391 SEXTSubVecB =
B.CreateZExt(SubVecB, V4I32Ty);
392 SEXTSubVecA =
B.CreateSExt(SubVecA, V4I32Ty);
394 case Intrinsic::x86_tdpbusd_internal:
395 SEXTSubVecB =
B.CreateSExt(SubVecB, V4I32Ty);
396 SEXTSubVecA =
B.CreateZExt(SubVecA, V4I32Ty);
398 case Intrinsic::x86_tdpbuud_internal:
399 SEXTSubVecB =
B.CreateZExt(SubVecB, V4I32Ty);
400 SEXTSubVecA =
B.CreateZExt(SubVecA, V4I32Ty);
405 Value *SubVecR =
B.CreateAddReduce(
B.CreateMul(SEXTSubVecA, SEXTSubVecB));
406 Value *ResElt =
B.CreateAdd(EltC, SubVecR);
407 NewVecC =
B.CreateInsertElement(VecCPhi, ResElt, IdxC);
433 Value *EltC =
B.CreateExtractElement(VecCPhi, IdxC);
434 Value *EltCF32 =
B.CreateBitCast(EltC,
B.getFloatTy());
435 Value *EltA =
B.CreateExtractElement(VecA, IdxA);
436 Value *SubVecA =
B.CreateBitCast(EltA, V2I16Ty);
437 Value *EltB =
B.CreateExtractElement(VecB, IdxB);
438 Value *SubVecB =
B.CreateBitCast(EltB, V2I16Ty);
440 int ShuffleMask[4] = {2, 0, 3, 1};
441 auto ShuffleArray =
ArrayRef(ShuffleMask);
442 Value *AV2F32 =
B.CreateBitCast(
443 B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
444 Value *BV2F32 =
B.CreateBitCast(
445 B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
446 Value *SubVecR =
B.CreateFAddReduce(EltCF32,
B.CreateFMul(AV2F32, BV2F32));
447 Value *ResElt =
B.CreateBitCast(SubVecR,
B.getInt32Ty());
448 NewVecC =
B.CreateInsertElement(VecCPhi, ResElt, IdxC);
456 Value *NewEltC =
B.CreateExtractElement(NewVecC, IdxC);
457 Value *NewVecD =
B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
461 VecCPhiColLoop->
addIncoming(NewVecC, ColLoopLatch);
463 VecDPhiColLoop->
addIncoming(NewVecD, ColLoopLatch);
468template <Intrinsic::ID IntrID>
469std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
470 IntrID == Intrinsic::x86_tdpbsud_internal ||
471 IntrID == Intrinsic::x86_tdpbusd_internal ||
472 IntrID == Intrinsic::x86_tdpbuud_internal ||
473 IntrID == Intrinsic::x86_tdpbf16ps_internal,
475X86LowerAMXIntrinsics::lowerTileDP(
Instruction *TileDP) {
481 PreBuilder.SetInsertPoint(TileDP);
485 Value *NDWord = PreBuilder.CreateLShr(
N, PreBuilder.getInt16(2));
486 Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
491 Value *ResVec = createTileDPLoops<IntrID>(Start,
End, Builder, M, NDWord,
495 Builder.SetInsertPoint(
End,
End->getFirstNonPHIIt());
503 I->replaceAllUsesWith(ResVec);
504 I->eraseFromParent();
512template <
bool IsTileLoad>
513bool X86LowerAMXIntrinsics::lowerTileLoadStore(
Instruction *TileLoadStore) {
517 m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
520 match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
526 PreBuilder.SetInsertPoint(TileLoadStore);
527 Value *NDWord = PreBuilder.CreateLShr(
N, PreBuilder.getInt16(2));
528 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
533 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
534 Start,
End, Builder, M, NDWord,
Ptr, StrideDWord,
535 IsTileLoad ?
nullptr : Tile);
539 Builder.SetInsertPoint(
End,
End->getFirstNonPHIIt());
547 I->replaceAllUsesWith(ResVec);
548 I->eraseFromParent();
557bool X86LowerAMXIntrinsics::lowerTileZero(
Instruction *TileZero) {
565 I->replaceAllUsesWith(VecZero);
566 I->eraseFromParent();
573bool X86LowerAMXIntrinsics::visit() {
578 if (
auto *Inst = dyn_cast<IntrinsicInst>(&*
II++)) {
579 switch (Inst->getIntrinsicID()) {
580 case Intrinsic::x86_tdpbssd_internal:
581 case Intrinsic::x86_tdpbsud_internal:
582 case Intrinsic::x86_tdpbusd_internal:
583 case Intrinsic::x86_tdpbuud_internal:
584 case Intrinsic::x86_tileloadd64_internal:
585 case Intrinsic::x86_tilestored64_internal:
586 case Intrinsic::x86_tilezero_internal:
587 case Intrinsic::x86_tdpbf16ps_internal:
597 for (
auto *Inst : WorkList) {
598 switch (Inst->getIntrinsicID()) {
599 case Intrinsic::x86_tdpbssd_internal:
600 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) ||
C;
602 case Intrinsic::x86_tdpbsud_internal:
603 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) ||
C;
605 case Intrinsic::x86_tdpbusd_internal:
606 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) ||
C;
608 case Intrinsic::x86_tdpbuud_internal:
609 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) ||
C;
611 case Intrinsic::x86_tdpbf16ps_internal:
612 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) ||
C;
614 case Intrinsic::x86_tileloadd64_internal:
615 C = lowerTileLoadStore<true>(Inst) ||
C;
617 case Intrinsic::x86_tilestored64_internal:
618 C = lowerTileLoadStore<false>(Inst) ||
C;
620 case Intrinsic::x86_tilezero_internal:
621 C = lowerTileZero(Inst) ||
C;
632class X86LowerAMXIntrinsicsLegacyPass :
public FunctionPass {
645 if (!
F.hasFnAttribute(Attribute::OptimizeNone) &&
649 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
650 auto *DT = DTWP ? &DTWP->getDomTree() :
nullptr;
651 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
652 auto *LI = LIWP ? &LIWP->getLoopInfo() :
nullptr;
655 X86LowerAMXIntrinsics LAT(
F, DTU, LI);
668static const char PassName[] =
"Lower AMX intrinsics";
669char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
677 return new X86LowerAMXIntrinsicsLegacyPass();
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
const SmallVectorImpl< MachineOperand > & Cond
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
Target-Independent Code Generator Pass Configuration Options pass.
static cl::opt< bool > X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden, cl::desc("X86: enable AMX scalarizition."))
static bool isV256I32Ty(Type *Ty)
static const char PassName[]
static const uint32_t IV[8]
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const BasicBlock * getSingleSuccessor() const
Return the successor of this block if it has a single successor.
const Function * getParent() const
Return the enclosing method, or null if none.
InstListType::iterator iterator
Instruction iterators...
LLVMContext & getContext() const
Get the context in which this basic block lives.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Conditional or Unconditional Branch instruction.
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
void setSuccessor(unsigned idx, BasicBlock *NewSucc)
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Legacy analysis pass which computes a DominatorTree.
Class to represent fixed width SIMD vectors.
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
This is an important class for using LLVM in a threaded context.
void addChildLoop(LoopT *NewChild)
Add the specified loop to be a child of this loop.
The legacy pass manager's analysis pass to compute loop information.
Represents a single loop in the control flow graph.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Primary interface to the complete machine description for the target machine.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
static Type * getX86_AMXTy(LLVMContext &C)
static IntegerType * getInt16Ty(LLVMContext &C)
A Use represents the edge between a Value definition and its users.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
iterator_range< use_iterator > uses()
const ParentTy * getParent() const
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
bool match(Val *V, const Pattern &P)
CastOperator_match< OpTy, Instruction::BitCast > m_BitCast(const OpTy &Op)
Matches BitCast.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
FunctionPass * createX86LowerAMXIntrinsicsPass()
The pass transforms amx intrinsics to scalar operation if the function has optnone attribute or it is...
void initializeX86LowerAMXIntrinsicsLegacyPassPass(PassRegistry &)
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
iterator_range< df_iterator< T > > depth_first(const T &G)