#include "llvm/IR/IntrinsicsX86.h"

using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"
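// Roughly, this pass rewrites load/store of <256 x i32> values plus bitcasts
// to/from x86_amx, as well as the llvm.x86.cast.vector.to.tile /
// llvm.x86.cast.tile.to.vector intrinsics, into AMX tile load/store
// intrinsics, recovering each tile's shape (row/col) from the AMX intrinsic
// that produces or consumes the value.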
static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}
static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II || isAMXCast(II))
    return false;
  // If the return type or any argument is x86_amx, this is an AMX intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args())
    if (V->getType()->isX86_AMXTy())
      return true;
  return false;
}
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  const DataLayout &DL = F.getParent()->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(Align(64));
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  // Tile loads/stores carry their shape directly in the first two operands.
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // For the dot-product intrinsics (c = a * b + c) the shape depends on which
  // operand is queried.
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      // The row count of the second source tile is K/4: its elements are
      // packed four per i32.
      if (isa<ConstantInt>(II->getOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getOperand(2))) {
        // Create the udiv computing the row right after the definition of K
        // so that it dominates every use of the new shape.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}
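// Illustrative shape mapping for a dot-product intrinsic (operand names are
// placeholders):
//   %c = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                x86_amx %acc,   ; operand 3, shape %m x %n
//                x86_amx %a,     ; operand 4, shape %m x %k
//                x86_amx %b)     ; operand 5, shape %k/4 x %n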
static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // Walk the first-user chain until an AMX intrinsic with a known shape is
  // found; casts and phis are looked through.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}
class X86LowerAMXType {
  Function &Func;

  // Map a column value to the row value derived from it (Col / element size),
  // so repeated shape queries can reuse the computation.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  // ...
};
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column width (64 bytes) as the stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}
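// Net effect of combineLoadBitcast, roughly (%row/%col/%addr are
// placeholders):
//   %src = load <256 x i32>, ptr %addr, align 64
//   %2   = bitcast <256 x i32> %src to x86_amx
// -->
//   %2   = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                      ptr %addr, i64 64)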
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic; its first operand is the row and
  // its second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column width (64 bytes) as the stride; it must match the
  // load stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  if (Bitcast->hasOneUse())
    return;
  // If the bitcast has other users, materialize the stored value again with a
  // plain vector load and rewrite those users to use it.
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}
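// Net effect of combineBitcastStore, roughly:
//   %2 = bitcast x86_amx %tile to <256 x i32>
//   store <256 x i32> %2, ptr %addr, align 64
// -->
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, ptr %addr,
//                                             i64 64, x86_amx %tile)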
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // Vector -> tile: spill the vector to a stack slot and reload it as a
    // tile with tileloadd.
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we could pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // Tile -> vector: store the tile to a stack slot with tilestored and read
    // it back as a vector.
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Src->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }
  return true;
}
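// For the vector -> tile direction the stack-slot fallback produces, roughly:
//   %2 = bitcast <256 x i32> %src to x86_amx
// -->
//   %addr = alloca <256 x i32>, align 64
//   store <256 x i32> %src, ptr %addr, align 64
//   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                    ptr %addr, i64 64)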
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        // <256 x i32> -> x86_amx: fold a feeding load into tileloadd if
        // possible, otherwise lower the cast through a stack slot.
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        // x86_amx -> <256 x i32>: fold a consuming store into tilestored if
        // possible, otherwise lower the cast through a stack slot.
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        combineBitcastStore(Bitcast, ST);
        DeadInsts.push_back(Bitcast);
        DeadInsts.push_back(ST);
      }
    }
  }

  bool C = !DeadInsts.empty();
  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
static Value *getAllocaPos(BasicBlock *BB) {
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = F->getParent()->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  // Reserve a full 1024-byte (<256 x i32>) tile slot in the entry block.
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
  return I8Ptr;
}
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
  return TileStore;
}
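// Note: the stride is a constant 64 because the stack slot reserved by
// getAllocaPos is a full <256 x i32> (1024-byte) tile, i.e. up to 16 rows of
// 64 bytes each stored back to back.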
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get the tile shape from the defining intrinsic (for a phi, from its first
  // incoming value).
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                            std::nullopt, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  bool volatileTileData();
  // ...
};
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // Every use of the incoming tile other than the phi and the store just
    // created must reload it from the stack slot.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All other uses of the tile must reload it from the stack slot.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // Handle the non-phi tile definitions first, then the phi nodes.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}
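// Rationale: this runs when preparing for fast register allocation at -O0
// (see runOnFunction below). Each tile definition is spilled to a stack slot
// right after it is defined and reloaded before every use, so no tile value
// stays live across basic blocks or phis ("volatile" tile data).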
class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAllAMXCast();
  // ...
};
static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand
    // becomes dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand became trivially dead once we nulled it out, delete it
      // in a future loop iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI))
          WorkList.insert(OpI);
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}
// Rewrite a web of phis that carries a tile value in vector form
// (tile -> vector cast, phi, vector -> tile cast) so that the phis operate on
// x86_amx directly and both casts become dead.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes reachable from this phi.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      if (isa<Constant>(IncValue)) {
        // Only zero/undef incoming constants are supported; they are replaced
        // by a tilezero cast back to the vector type.
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create the tilezero at the end of the incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it is a B->A cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that every user of every old PHI node is something we can rewrite,
  // so the whole web can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it is an A->B cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // A user that is another old PHI node is fine: the web has no users
        // outside itself, so it is dead.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else {
        return false;
      }
    }
  }

  // For each old PHI node, create a corresponding new PHI node of type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of the new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Replace every A->B cast of an old phi with the new phi and mark the cast
  // dead; old phis themselves become operands of dead instructions and are
  // removed by the follow-up DCE.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        assert(OldPhiNodes.contains(PHI));
      } else {
        llvm_unreachable("all uses should be handled");
      }
    }
  }
  return true;
}
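// Sketch of the phi rewrite (value names are placeholders):
//   %v.in = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t.in)
//   ...
//   %v = phi <256 x i32> [ %v.in, %entry ], [ %v.next, %loop ]
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
// -->
//   %t = phi x86_amx [ %t.in, %entry ], [ %t.next, %loop ]
// with both casts left for the follow-up DCE to remove.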
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: if it is a cast intrinsic or phi node, we could propagate the shape
  // information through the def-use chain.
  if (!isAMXIntrinsic(Tile))
    return false;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic; its first operand is the row and
  // its second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // The stride equals the column width in bytes.
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  return true;
}
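// Net effect of combineCastStore, roughly:
//   %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t)
//   store <256 x i32> %v, ptr %p, align 64
// -->
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, ptr %p,
//                                             i64 %stride, x86_amx %t)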
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: if it is a cast intrinsic or phi node, we could propagate the shape
  // information through the def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // The stride equals the column width in bytes.
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compile time, the dominator tree is built lazily, only when it is
  // really needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // The shape does not dominate the load: spill the loaded value to a new
    // stack slot and reload it from there right before the cast.
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}
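// Net effect of combineLoadCast, roughly:
//   %v = load <256 x i32>, ptr %p, align 64
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
// -->
//   %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                    ptr %p, i64 %stride)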
bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // A tile-to-vector cast feeding stores becomes tilestored; a load feeding
    // a vector-to-tile cast becomes tileloadd.
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // Intrinsic::x86_cast_vector_to_tile
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Null out the operand so the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
      }
    }
  }
  return Change;
}
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect the tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  // Fold a cast whose operand is the opposite cast (A->B->A) by forwarding
  // the original value to the outer cast's users.
  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      // ...
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle the A->B->A cast where a PHI node sits between the two casts.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // Skip AMX casts that are already dead.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Creating new phis and merging AMX casts can leave old phis and casts
  // without uses, so run a small DCE over them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // Vector -> tile cast that could not be combined with a load: spill the
    // vector to a stack slot and reload it as a tile.
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we could pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // Tile -> vector cast that could not be combined with a store: store the
    // tile to a stack slot and read it back as a vector.
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Src->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }
  return true;
}
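// Any AMX cast still alive after combineAMXcast is lowered here through a
// stack slot; this is the correctness fallback rather than the fast path, so
// combineLdSt and optimizeAMXCastFromPhi are expected to catch the common
// cases first.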
bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect all remaining AMX cast instructions, then lower each one.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}
class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // Lower any AMX casts that remain after combineAMXcast through memory.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at -O0: make all tile data
    // "volatile" so no tile value is live across basic blocks.
    if (TM->getOptLevel() == CodeGenOptLevel::None) {
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}