1 //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform <256 x i32> load/store
10 /// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set
11 /// only provides simple operations on x86_amx. Basic elementwise operations
12 /// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
13 /// and only AMX intrinsics can operate on the type, we need to transform
14 /// load/store <256 x i32> instructions to AMX load/store. If the bitcast
15 /// cannot be combined with the load/store, we transform the bitcast to an amx
16 /// load/store and a <256 x i32> store/load.
17 ///
18 /// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
19 /// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is
20 /// volatile, because that is necessary for AMX fast register allocation. (In
21 /// fast register allocation, registers are allocated before spill/reload, so
22 /// there is no additional register for amx to identify the step in spill.)
23 /// volatileTileData() handles this case.
24 /// e.g.
25 /// ----------------------------------------------------------
26 /// | def %td = ... |
27 /// | ... |
28 /// | "use %td" |
29 /// ----------------------------------------------------------
30 /// will transfer to -->
31 /// ----------------------------------------------------------
32 /// | def %td = ... |
33 /// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34 /// | ... |
35 /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36 /// | "use %td2" |
37 /// ----------------------------------------------------------
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 #include "X86.h"
43 #include "llvm/ADT/SetVector.h"
44 #include "llvm/ADT/SmallSet.h"
48 #include "llvm/CodeGen/Passes.h"
51 #include "llvm/IR/DataLayout.h"
52 #include "llvm/IR/Function.h"
53 #include "llvm/IR/IRBuilder.h"
54 #include "llvm/IR/Instructions.h"
55 #include "llvm/IR/IntrinsicInst.h"
56 #include "llvm/IR/IntrinsicsX86.h"
57 #include "llvm/IR/PatternMatch.h"
58 #include "llvm/InitializePasses.h"
59 #include "llvm/Pass.h"
63 
64 #include <map>
65 
66 using namespace llvm;
67 using namespace PatternMatch;
68 
69 #define DEBUG_TYPE "lower-amx-type"
70 
71 static bool isAMXCast(Instruction *II) {
72  return match(II,
73  m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74  match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75 }
76 
77 static bool isAMXIntrinsic(Value *I) {
78  auto *II = dyn_cast<IntrinsicInst>(I);
79  if (!II)
80  return false;
81  if (isAMXCast(II))
82  return false;
83  // Check if the return type or any parameter is x86_amx. If it is x86_amx,
84  // the intrinsic must be an x86 AMX intrinsic.
85  if (II->getType()->isX86_AMXTy())
86  return true;
87  for (Value *V : II->args()) {
88  if (V->getType()->isX86_AMXTy())
89  return true;
90  }
91 
92  return false;
93 }
94 
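// Create an alloca of type \p Ty at the top of the entry block, aligned to the
// preferred alignment of x86_amx, so tile data can be staged through memory.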
95 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
96  Type *Ty) {
97  Function &F = *BB->getParent();
98  Module *M = BB->getModule();
99  const DataLayout &DL = M->getDataLayout();
100 
101  LLVMContext &Ctx = Builder.getContext();
102  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
103  unsigned AllocaAS = DL.getAllocaAddrSpace();
104  AllocaInst *AllocaRes =
105  new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
106  AllocaRes->setAlignment(AllocaAlignment);
107  return AllocaRes;
108 }
109 
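// Return the first non-alloca instruction in the entry block; used as an
// insertion point so that new code goes after all entry-block allocas.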
110 static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
111  for (Instruction &I : F.getEntryBlock())
112  if (!isa<AllocaInst>(&I))
113  return &I;
114  llvm_unreachable("No terminator in the entry block!");
115 }
116 
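// Return the {Row, Col} shape of the tile referenced at operand \p OpNo of the
// AMX intrinsic \p II. For the tdp* multiply intrinsics the shape depends on
// which tile operand is queried; for operand 5 the row is operand 2 divided by 4.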
117 static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
118  IRBuilder<> Builder(II);
119  Value *Row = nullptr, *Col = nullptr;
120  switch (II->getIntrinsicID()) {
121  default:
122  llvm_unreachable("Expect amx intrinsics");
123  case Intrinsic::x86_tileloadd64_internal:
124  case Intrinsic::x86_tileloaddt164_internal:
125  case Intrinsic::x86_tilestored64_internal: {
126  Row = II->getArgOperand(0);
127  Col = II->getArgOperand(1);
128  break;
129  }
130  // a * b + c
131  // The shape depends on which operand.
132  case Intrinsic::x86_tdpbssd_internal:
133  case Intrinsic::x86_tdpbsud_internal:
134  case Intrinsic::x86_tdpbusd_internal:
135  case Intrinsic::x86_tdpbuud_internal:
136  case Intrinsic::x86_tdpbf16ps_internal:
137  case Intrinsic::x86_tdpfp16ps_internal: {
138  switch (OpNo) {
139  case 3:
140  Row = II->getArgOperand(0);
141  Col = II->getArgOperand(1);
142  break;
143  case 4:
144  Row = II->getArgOperand(0);
145  Col = II->getArgOperand(2);
146  break;
147  case 5:
148  if (isa<ConstantInt>(II->getArgOperand(2)))
149  Row = Builder.getInt16(
150  (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
151  else if (isa<Instruction>(II->getArgOperand(2))) {
152  // When it is not a const value and it is not a function argument, we
153  // create Row after the definition of II->getOperand(2) instead of
154  // before II. For example, II is %118, we try to getshape for %117:
155  // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
156  // i32> %115).
157  // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
158  // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
159  // %117).
160  // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
161  // definition is after its user(new tileload for %117).
162  // So, the best choice is to create %row right after the definition of
163  // %106.
164  Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
165  Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
166  cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
167  } else {
168  // When it is not a const value and it is a function argument, we create
169  // Row at the entry bb.
170  IRBuilder<> NewBuilder(
171  getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
172  Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
173  }
174  Col = II->getArgOperand(1);
175  break;
176  }
177  break;
178  }
179  }
180 
181  return std::make_pair(Row, Col);
182 }
183 
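// Deduce the {Row, Col} shape for a PHI of tile values by following its first
// user through AMX casts and PHIs until an AMX intrinsic is reached; returns
// {nullptr, nullptr} if no shape can be found.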
184 static std::pair<Value *, Value *> getShape(PHINode *Phi) {
185  Use &U = *(Phi->use_begin());
186  unsigned OpNo = U.getOperandNo();
187  User *V = U.getUser();
188  // TODO: We don't traverse all users. To keep the algorithm simple, here we
189  // just traverse the first user. If we can find the shape, then return the
190  // shape; otherwise just return nullptr and the optimization for undef/zero
191  // will be abandoned.
192  while (V) {
193  if (isAMXCast(dyn_cast<Instruction>(V))) {
194  if (V->use_empty())
195  break;
196  Use &U = *(V->use_begin());
197  OpNo = U.getOperandNo();
198  V = U.getUser();
199  } else if (isAMXIntrinsic(V)) {
200  return getShape(cast<IntrinsicInst>(V), OpNo);
201  } else if (isa<PHINode>(V)) {
202  if (V->use_empty())
203  break;
204  Use &U = *(V->use_begin());
205  V = U.getUser();
206  } else {
207  break;
208  }
209  }
210 
211  return std::make_pair(nullptr, nullptr);
212 }
213 
214 namespace {
215 class X86LowerAMXType {
216  Function &Func;
217 
218  // In AMX intrinsics we let Shape = {Row, Col}, but the
219  // RealCol = Col / ElementSize. We may use the RealCol
220  // as a new Row for other newly created AMX intrinsics.
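// e.g. with i32 elements, a Col of 64 (bytes) corresponds to RealCol = 64 / 4 = 16.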
221  std::map<Value *, Value *> Col2Row;
222 
223 public:
224  X86LowerAMXType(Function &F) : Func(F) {}
225  bool visit();
226  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
227  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
228  bool transformBitcast(BitCastInst *Bitcast);
229 };
230 
231 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
232 // %2 = bitcast <256 x i32> %src to x86_amx
233 // -->
234 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
235 // i8* %addr, i64 %stride64)
236 void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
237  Value *Row = nullptr, *Col = nullptr;
238  Use &U = *(Bitcast->use_begin());
239  unsigned OpNo = U.getOperandNo();
240  auto *II = cast<IntrinsicInst>(U.getUser());
241  std::tie(Row, Col) = getShape(II, OpNo);
242  IRBuilder<> Builder(LD);
243  // Use the maximum column as stride.
244  Value *Stride = Builder.getInt64(64);
245  Value *I8Ptr =
246  Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
247  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
248 
249  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
250  std::nullopt, Args);
251  Bitcast->replaceAllUsesWith(NewInst);
252 }
253 
254 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
255 // %stride);
256 // %13 = bitcast x86_amx %src to <256 x i32>
257 // store <256 x i32> %13, <256 x i32>* %addr, align 64
258 // -->
259 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
260 // %stride64, %13)
261 void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
262 
263  Value *Tile = Bitcast->getOperand(0);
264  auto *II = cast<IntrinsicInst>(Tile);
265  // Tile is the output of an AMX intrinsic. The first operand of the
266  // intrinsic is the row, the second operand is the column.
267  Value *Row = II->getOperand(0);
268  Value *Col = II->getOperand(1);
269  IRBuilder<> Builder(ST);
270  // Use the maximum column as stride. It must be the same as the load
271  // stride.
272  Value *Stride = Builder.getInt64(64);
273  Value *I8Ptr =
274  Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
275  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
276  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
277  Args);
278  if (Bitcast->hasOneUse())
279  return;
280  // %13 = bitcast x86_amx %src to <256 x i32>
281  // store <256 x i32> %13, <256 x i32>* %addr, align 64
282  // %add = <256 x i32> %13, <256 x i32> %src2
283  // -->
284  // %13 = bitcast x86_amx %src to <256 x i32>
285  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
286  // %stride64, %13)
287  // %14 = load <256 x i32>, %addr
288  // %add = <256 x i32> %14, <256 x i32> %src2
289  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
290  Bitcast->replaceAllUsesWith(Vec);
291 }
292 
293 // transform bitcast to <store, load> instructions.
294 bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
295  IRBuilder<> Builder(Bitcast);
296  AllocaInst *AllocaAddr;
297  Value *I8Ptr, *Stride;
298  auto *Src = Bitcast->getOperand(0);
299 
300  auto Prepare = [&](Type *MemTy) {
301  AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
302  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
303  Stride = Builder.getInt64(64);
304  };
305 
306  if (Bitcast->getType()->isX86_AMXTy()) {
307  // %2 = bitcast <256 x i32> %src to x86_amx
308  // -->
309  // %addr = alloca <256 x i32>, align 64
310  // store <256 x i32> %src, <256 x i32>* %addr, align 64
311  // %addr2 = bitcast <256 x i32>* to i8*
312  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
313  // i8* %addr2,
314  // i64 64)
315  Use &U = *(Bitcast->use_begin());
316  unsigned OpNo = U.getOperandNo();
317  auto *II = dyn_cast<IntrinsicInst>(U.getUser());
318  if (!II)
319  return false; // May be bitcast from x86amx to <256 x i32>.
320  Prepare(Bitcast->getOperand(0)->getType());
321  Builder.CreateStore(Src, AllocaAddr);
322  // TODO: we can pick a constant operand for the shape.
323  Value *Row = nullptr, *Col = nullptr;
324  std::tie(Row, Col) = getShape(II, OpNo);
325  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
326  Value *NewInst = Builder.CreateIntrinsic(
327  Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
328  Bitcast->replaceAllUsesWith(NewInst);
329  } else {
330  // %2 = bitcast x86_amx %src to <256 x i32>
331  // -->
332  // %addr = alloca <256 x i32>, align 64
333  // %addr2 = bitcast <256 x i32>* to i8*
334  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
335  // i8* %addr2, i64 %stride)
336  // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
337  auto *II = dyn_cast<IntrinsicInst>(Src);
338  if (!II)
339  return false; // May be bitcast from <256 x i32> to x86amx.
340  Prepare(Bitcast->getType());
341  Value *Row = II->getOperand(0);
342  Value *Col = II->getOperand(1);
343  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
344  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
345  Args);
346  Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
347  Bitcast->replaceAllUsesWith(NewInst);
348  }
349 
350  return true;
351 }
352 
353 bool X86LowerAMXType::visit() {
354  SmallVector<Instruction *, 8> DeadInsts;
355  Col2Row.clear();
356 
357  for (BasicBlock *BB : post_order(&Func)) {
358  for (Instruction &Inst : llvm::make_early_inc_range(reverse(*BB))) {
359  auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
360  if (!Bitcast)
361  continue;
362 
363  Value *Src = Bitcast->getOperand(0);
364  if (Bitcast->getType()->isX86_AMXTy()) {
365  if (Bitcast->user_empty()) {
366  DeadInsts.push_back(Bitcast);
367  continue;
368  }
369  LoadInst *LD = dyn_cast<LoadInst>(Src);
370  if (!LD) {
371  if (transformBitcast(Bitcast))
372  DeadInsts.push_back(Bitcast);
373  continue;
374  }
375  // If the load has multiple users, duplicate the load as a vector load.
376  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
377  // %2 = bitcast <256 x i32> %src to x86_amx
378  // %add = add <256 x i32> %src, <256 x i32> %src2
379  // -->
380  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
381  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
382  // i8* %addr, i64 %stride64)
383  // %add = add <256 x i32> %src, <256 x i32> %src2
384 
385  // If load has one user, the load will be eliminated in DAG ISel.
386  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
387  // %2 = bitcast <256 x i32> %src to x86_amx
388  // -->
389  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
390  // i8* %addr, i64 %stride64)
391  combineLoadBitcast(LD, Bitcast);
392  DeadInsts.push_back(Bitcast);
393  if (LD->hasOneUse())
394  DeadInsts.push_back(LD);
395  } else if (Src->getType()->isX86_AMXTy()) {
396  if (Bitcast->user_empty()) {
397  DeadInsts.push_back(Bitcast);
398  continue;
399  }
400  StoreInst *ST = nullptr;
401  for (Use &U : Bitcast->uses()) {
402  ST = dyn_cast<StoreInst>(U.getUser());
403  if (ST)
404  break;
405  }
406  if (!ST) {
407  if (transformBitcast(Bitcast))
408  DeadInsts.push_back(Bitcast);
409  continue;
410  }
411  // If bitcast (%13) has one use, combine bitcast and store to amx store.
412  // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
413  // %stride);
414  // %13 = bitcast x86_amx %src to <256 x i32>
415  // store <256 x i32> %13, <256 x i32>* %addr, align 64
416  // -->
417  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
418  // %stride64, %13)
419  //
420  // If bitcast (%13) has multi-use, transform as below.
421  // %13 = bitcast x86_amx %src to <256 x i32>
422  // store <256 x i32> %13, <256 x i32>* %addr, align 64
423  // %add = <256 x i32> %13, <256 x i32> %src2
424  // -->
425  // %13 = bitcast x86_amx %src to <256 x i32>
426  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
427  // %stride64, %13)
428  // %14 = load <256 x i32>, %addr
429  // %add = <256 x i32> %14, <256 x i32> %src2
430  //
431  combineBitcastStore(Bitcast, ST);
432  // Delete user first.
433  DeadInsts.push_back(ST);
434  DeadInsts.push_back(Bitcast);
435  }
436  }
437  }
438 
439  bool C = !DeadInsts.empty();
440 
441  for (auto *Inst : DeadInsts)
442  Inst->eraseFromParent();
443 
444  return C;
445 }
446 } // anonymous namespace
447 
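// Allocate a <256 x i32> slot in the entry block and return its address as an
// i8* that can be used as the memory operand of tile load/store intrinsics.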
448 static Value *getAllocaPos(BasicBlock *BB) {
449  Module *M = BB->getModule();
450  Function *F = BB->getParent();
451  IRBuilder<> Builder(&F->getEntryBlock().front());
452  const DataLayout &DL = M->getDataLayout();
453  unsigned AllocaAS = DL.getAllocaAddrSpace();
454  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
455  AllocaInst *AllocaRes =
456  new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
457  BasicBlock::iterator Iter = AllocaRes->getIterator();
458  ++Iter;
459  Builder.SetInsertPoint(&*Iter);
460  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
461  return I8Ptr;
462 }
463 
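// Store the tile defined by \p TileDef to \p Ptr with a tilestored64 inserted
// right after the definition; the shape is taken from the defining intrinsic's
// row/column operands.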
464 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
465  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
466  auto *II = cast<IntrinsicInst>(TileDef);
467  assert(II && "Not tile intrinsic!");
468  Value *Row = II->getOperand(0);
469  Value *Col = II->getOperand(1);
470 
471  BasicBlock *BB = TileDef->getParent();
472  BasicBlock::iterator Iter = TileDef->getIterator();
473  IRBuilder<> Builder(BB, ++Iter);
474  Value *Stride = Builder.getInt64(64);
475  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
476 
477  Instruction *TileStore = Builder.CreateIntrinsic(
478  Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
479  return TileStore;
480 }
481 
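// Replace the tile value used at \p U with a tileloadd64 from \p Ptr, emitted
// just before the user. When the used value is a PHI, the shape is taken from
// its first incoming value.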
482 static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
483  Value *V = U.get();
484  assert(V->getType()->isX86_AMXTy() && "Not define tile!");
485 
486  // Get tile shape.
487  IntrinsicInst *II = nullptr;
488  if (IsPHI) {
489  Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
490  II = cast<IntrinsicInst>(PhiOp);
491  } else {
492  II = cast<IntrinsicInst>(V);
493  }
494  Value *Row = II->getOperand(0);
495  Value *Col = II->getOperand(1);
496 
497  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
498  IRBuilder<> Builder(UserI);
499  Value *Stride = Builder.getInt64(64);
500  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
501 
502  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
503  std::nullopt, Args);
504  UserI->replaceUsesOfWith(V, TileLoad);
505 }
506 
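// Return true if \p I is used as an incoming value of some PHI node.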
507 static bool isIncomingOfPHI(Instruction *I) {
508  for (Use &U : I->uses()) {
509  User *V = U.getUser();
510  if (isa<PHINode>(V))
511  return true;
512  }
513  return false;
514 }
515 
516 // Make all AMX tile data volatile, which shortens the live range
517 // of each tile register before fast register allocation.
518 namespace {
519 class X86VolatileTileData {
520  Function &F;
521 
522 public:
523  X86VolatileTileData(Function &Func) : F(Func) {}
524  Value *updatePhiIncomings(BasicBlock *BB,
525  SmallVector<Instruction *, 2> &Incomings);
526  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
527  bool volatileTileData();
528  void volatileTilePHI(PHINode *Inst);
529  void volatileTileNonPHI(Instruction *I);
530 };
531 
532 Value *X86VolatileTileData::updatePhiIncomings(
533  BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
534  Value *I8Ptr = getAllocaPos(BB);
535 
536  for (auto *I : Incomings) {
537  User *Store = createTileStore(I, I8Ptr);
538 
539  // All its uses (except phi) should load from stored mem.
540  for (Use &U : I->uses()) {
541  User *V = U.getUser();
542  if (isa<PHINode>(V) || V == Store)
543  continue;
544  replaceWithTileLoad(U, I8Ptr);
545  }
546  }
547  return I8Ptr;
548 }
549 
550 void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
551  Value *StorePtr) {
552  for (Use &U : PHI->uses())
553  replaceWithTileLoad(U, StorePtr, true);
554  PHI->eraseFromParent();
555 }
556 
557 // Similar to volatileTileNonPHI, this function only handles PHI nodes
558 // and their related AMX intrinsics.
559 // 1) The PHI def should be changed to a tileload.
560 // 2) The PHI incoming values should be tilestored right after their defs.
561 // 3) The memory of these tileloads and tilestores should be the same.
562 // e.g.
563 // ------------------------------------------------------
564 // bb_dom:
565 // ...
566 // br i1 %bool.cond, label %if.else, label %if.then
567 //
568 // if.then:
569 // def %t0 = ...
570 // ...
571 // use %t0
572 // ...
573 // br label %if.end
574 //
575 // if.else:
576 // def %t1 = ...
577 // br label %if.end
578 //
579 // if.end:
580 // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
581 // ...
582 // use %td
583 // ------------------------------------------------------
584 // -->
585 // ------------------------------------------------------
586 // bb_entry:
587 // %mem = alloca <256 x i32>, align 1024 *
588 // ...
589 // bb_dom:
590 // ...
591 // br i1 %bool.cond, label %if.else, label %if.then
592 //
593 // if.then:
594 // def %t0 = ...
595 // call void @llvm.x86.tilestored64.internal(mem, %t0) *
596 // ...
597 // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
598 // use %t0` *
599 // ...
600 // br label %if.end
601 //
602 // if.else:
603 // def %t1 = ...
604 // call void @llvm.x86.tilestored64.internal(mem, %t1) *
605 // br label %if.end
606 //
607 // if.end:
608 // ...
609 // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
610 // use %td
611 // ------------------------------------------------------
612 void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
613  BasicBlock *BB = PHI->getParent();
614  SmallVector<Instruction *, 2> Incomings;
615 
616  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
617  Value *Op = PHI->getIncomingValue(I);
618  Instruction *Inst = dyn_cast<Instruction>(Op);
619  assert(Inst && "We shouldn't fold AMX instrution!");
620  Incomings.push_back(Inst);
621  }
622 
623  Value *StorePtr = updatePhiIncomings(BB, Incomings);
624  replacePhiDefWithLoad(PHI, StorePtr);
625 }
626 
627 // Store the defined tile and load it before use.
628 // All its users are not PHI.
629 // e.g.
630 // ------------------------------------------------------
631 // def %td = ...
632 // ...
633 // "use %td"
634 // ------------------------------------------------------
635 // -->
636 // ------------------------------------------------------
637 // def %td = ...
638 // call void @llvm.x86.tilestored64.internal(mem, %td)
639 // ...
640 // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
641 // "use %td2"
642 // ------------------------------------------------------
643 void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
644  BasicBlock *BB = I->getParent();
645  Value *I8Ptr = getAllocaPos(BB);
646  User *Store = createTileStore(I, I8Ptr);
647 
648  // All its uses should load from stored mem.
649  for (Use &U : I->uses()) {
650  User *V = U.getUser();
651  assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
652  if (V != Store)
653  replaceWithTileLoad(U, I8Ptr);
654  }
655 }
656 
657 // Volatile Tile Model:
658 // 1) All uses of tile data come from a tileload just in time.
659 // 2) All defs of tile data are tilestored into memory immediately.
660 // For example:
661 // --------------------------------------------------------------------------
662 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
663 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
664 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
665 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
666 // call void @llvm.x86.tilestored64.internal(... td) area
667 // --------------------------------------------------------------------------
668 // 3) No terminator, call or other amx instructions in the key amx area.
669 bool X86VolatileTileData::volatileTileData() {
670  bool Changed = false;
671  for (BasicBlock &BB : F) {
672  SmallVector<Instruction *, 2> PHIInsts;
673  SmallVector<Instruction *, 8> AMXDefInsts;
674 
675  for (Instruction &I : BB) {
676  if (!I.getType()->isX86_AMXTy())
677  continue;
678  if (isa<PHINode>(&I))
679  PHIInsts.push_back(&I);
680  else
681  AMXDefInsts.push_back(&I);
682  }
683 
684  // First we "volatile" the non-phi related amx intrinsics.
685  for (Instruction *I : AMXDefInsts) {
686  if (isIncomingOfPHI(I))
687  continue;
688  volatileTileNonPHI(I);
689  Changed = true;
690  }
691 
692  for (Instruction *I : PHIInsts) {
693  volatileTilePHI(dyn_cast<PHINode>(I));
694  Changed = true;
695  }
696  }
697  return Changed;
698 }
699 
700 } // anonymous namespace
701 
702 namespace {
703 
704 class X86LowerAMXCast {
705  Function &Func;
706  std::unique_ptr<DominatorTree> DT;
707 
708 public:
709  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
710  void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
711  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
712  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
713  bool combineAMXcast(TargetLibraryInfo *TLI);
714  bool transformAMXCast(IntrinsicInst *AMXCast);
715  bool transformAllAMXCast();
716  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
717  SmallSetVector<Instruction *, 16> &DeadInst);
718 };
719 
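// A local variant of trivial dead code elimination: if \p I is trivially dead,
// null out its operands, queue any operands that become trivially dead onto
// \p WorkList, and erase \p I.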
720 static bool DCEInstruction(Instruction *I,
721  SmallSetVector<Instruction *, 16> &WorkList,
722  const TargetLibraryInfo *TLI) {
723  if (isInstructionTriviallyDead(I, TLI)) {
724  salvageDebugInfo(*I);
725  salvageKnowledge(I);
726 
727  // Null out all of the instruction's operands to see if any operand becomes
728  // dead as we go.
729  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
730  Value *OpV = I->getOperand(i);
731  I->setOperand(i, nullptr);
732 
733  if (!OpV->use_empty() || I == OpV)
734  continue;
735 
736  // If the operand is an instruction that became dead as we nulled out the
737  // operand, and if it is 'trivially' dead, delete it in a future loop
738  // iteration.
739  if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
740  if (isInstructionTriviallyDead(OpI, TLI)) {
741  WorkList.insert(OpI);
742  }
743  }
744  }
745  I->eraseFromParent();
746  return true;
747  }
748  return false;
749 }
750 
751 /// This function handles following case
752 ///
753 /// A -> B amxcast
754 /// PHI
755 /// B -> A amxcast
756 ///
757 /// All the related PHI nodes can be replaced by new PHI nodes with type A.
758 /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
759 bool X86LowerAMXCast::optimizeAMXCastFromPhi(
760  IntrinsicInst *CI, PHINode *PN,
761  SmallSetVector<Instruction *, 16> &DeadInst) {
762  IRBuilder<> Builder(CI);
763  Value *Src = CI->getOperand(0);
764  Type *SrcTy = Src->getType(); // Type B
765  Type *DestTy = CI->getType(); // Type A
766 
767  SmallVector<PHINode *, 4> PhiWorklist;
768  SmallSetVector<PHINode *, 4> OldPhiNodes;
769 
770  // Find all of the A->B casts and PHI nodes.
771  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
772  // OldPhiNodes is used to track all known PHI nodes, before adding a new
773  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
774  PhiWorklist.push_back(PN);
775  OldPhiNodes.insert(PN);
776  while (!PhiWorklist.empty()) {
777  auto *OldPN = PhiWorklist.pop_back_val();
778  for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
779  Value *IncValue = OldPN->getIncomingValue(I);
780  // TODO: Currently we ignore cases where it is a constant. In the future, we
781  // might support constants.
782  if (isa<Constant>(IncValue)) {
783  auto *IncConst = dyn_cast<Constant>(IncValue);
784  if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
785  return false;
786  Value *Row = nullptr, *Col = nullptr;
787  std::tie(Row, Col) = getShape(OldPN);
788  // TODO: If it is not a constant, the Row and Col must dominate the tilezero
789  // that we are going to create.
790  if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
791  return false;
792  // Create tilezero at the end of incoming block.
793  auto *Block = OldPN->getIncomingBlock(I);
794  BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
795  Instruction *NewInst = Builder.CreateIntrinsic(
796  Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
797  NewInst->moveBefore(&*Iter);
798  NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
799  {IncValue->getType()}, {NewInst});
800  NewInst->moveBefore(&*Iter);
801  // Replace InValue with new Value.
802  OldPN->setIncomingValue(I, NewInst);
803  IncValue = NewInst;
804  }
805 
806  if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
807  if (OldPhiNodes.insert(PNode))
808  PhiWorklist.push_back(PNode);
809  continue;
810  }
811  Instruction *ACI = dyn_cast<Instruction>(IncValue);
812  if (ACI && isAMXCast(ACI)) {
813  // Verify it's a A->B cast.
814  Type *TyA = ACI->getOperand(0)->getType();
815  Type *TyB = ACI->getType();
816  if (TyA != DestTy || TyB != SrcTy)
817  return false;
818  continue;
819  }
820  return false;
821  }
822  }
823 
824  // Check that each user of each old PHI node is something that we can
825  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
826  for (auto *OldPN : OldPhiNodes) {
827  for (User *V : OldPN->users()) {
828  Instruction *ACI = dyn_cast<Instruction>(V);
829  if (ACI && isAMXCast(ACI)) {
830  // Verify it's a B->A cast.
831  Type *TyB = ACI->getOperand(0)->getType();
832  Type *TyA = ACI->getType();
833  if (TyA != DestTy || TyB != SrcTy)
834  return false;
835  } else if (auto *PHI = dyn_cast<PHINode>(V)) {
836  // As long as the user is another old PHI node, then even if we don't
837  // rewrite it, the PHI web we're considering won't have any users
838  // outside itself, so it'll be dead.
839  // example:
840  // bb.0:
841  // %0 = amxcast ...
842  // bb.1:
843  // %1 = amxcast ...
844  // bb.2:
845  // %goodphi = phi %0, %1
846  // %3 = amxcast %goodphi
847  // bb.3:
848  // %goodphi2 = phi %0, %goodphi
849  // %4 = amxcast %goodphi2
850  // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
851  // outside the phi-web, so the combination stops. When
852  // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
853  // will be done.
854  if (OldPhiNodes.count(PHI) == 0)
855  return false;
856  } else
857  return false;
858  }
859  }
860 
861  // For each old PHI node, create a corresponding new PHI node with a type A.
862  SmallDenseMap<PHINode *, PHINode *, 16> NewPNodes;
863  for (auto *OldPN : OldPhiNodes) {
864  Builder.SetInsertPoint(OldPN);
865  PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
866  NewPNodes[OldPN] = NewPN;
867  }
868 
869  // Fill in the operands of new PHI nodes.
870  for (auto *OldPN : OldPhiNodes) {
871  PHINode *NewPN = NewPNodes[OldPN];
872  for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
873  Value *V = OldPN->getOperand(j);
874  Value *NewV = nullptr;
875  Instruction *ACI = dyn_cast<Instruction>(V);
876  // There should not be an AMXcast from a constant.
877  if (ACI && isAMXCast(ACI))
878  NewV = ACI->getOperand(0);
879  else if (auto *PrevPN = dyn_cast<PHINode>(V))
880  NewV = NewPNodes[PrevPN];
881  assert(NewV);
882  NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
883  }
884  }
885 
886  // Traverse all accumulated PHI nodes and process their users,
887  // which are Stores and BitCasts. Without this processing,
888  // NewPHI nodes could be replicated and could lead to extra
889  // moves generated after DeSSA.
890  // If there is a store with type B, change it to type A.
891 
892  // Replace users of BitCast B->A with NewPHI. These will help
893  // later to get rid of a closure formed by OldPHI nodes.
894  for (auto *OldPN : OldPhiNodes) {
895  PHINode *NewPN = NewPNodes[OldPN];
896  for (User *V : make_early_inc_range(OldPN->users())) {
897  Instruction *ACI = dyn_cast<Instruction>(V);
898  if (ACI && isAMXCast(ACI)) {
899  Type *TyB = ACI->getOperand(0)->getType();
900  Type *TyA = ACI->getType();
901  assert(TyA == DestTy && TyB == SrcTy);
902  (void)TyA;
903  (void)TyB;
904  ACI->replaceAllUsesWith(NewPN);
905  DeadInst.insert(ACI);
906  } else if (auto *PHI = dyn_cast<PHINode>(V)) {
907  // We don't need to push the PHINode into DeadInst since they are operands
908  // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
909  assert(OldPhiNodes.contains(PHI));
910  (void)PHI;
911  } else
912  llvm_unreachable("all uses should be handled");
913  }
914  }
915  return true;
916 }
917 
918 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
919 // store <256 x i32> %43, <256 x i32>* %p, align 64
920 // -->
921 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
922 // i64 64, x86_amx %42)
923 void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
924  Value *Tile = Cast->getOperand(0);
925  // TODO: If it is cast intrinsic or phi node, we can propagate the
926  // shape information through def-use chain.
927  if (!isAMXIntrinsic(Tile))
928  return;
929  auto *II = cast<IntrinsicInst>(Tile);
930  // Tile is the output of an AMX intrinsic. The first operand of the
931  // intrinsic is the row, the second operand is the column.
932  Value *Row = II->getOperand(0);
933  Value *Col = II->getOperand(1);
934  IRBuilder<> Builder(ST);
935  // Use the maximum column as stride. It must be the same as the load
936  // stride.
937  Value *Stride = Builder.getInt64(64);
938  Value *I8Ptr =
939  Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
940  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
941  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
942  Args);
943 }
944 
945 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
946 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
947 // -->
948 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
949 // i8* %p, i64 64)
950 bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
951  bool EraseLoad = true;
952  Value *Row = nullptr, *Col = nullptr;
953  Use &U = *(Cast->use_begin());
954  unsigned OpNo = U.getOperandNo();
955  auto *II = cast<IntrinsicInst>(U.getUser());
956  // TODO: If it is cast intrinsic or phi node, we can propagate the
957  // shape information through def-use chain.
958  if (!isAMXIntrinsic(II))
959  return false;
960  std::tie(Row, Col) = getShape(II, OpNo);
961  IRBuilder<> Builder(LD);
962  // Use the maximum column as stride.
963  Value *Stride = Builder.getInt64(64);
964  Value *I8Ptr;
965 
966  // To save compile time, we create the dominator tree only when it is really
967  // needed.
968  if (!DT)
969  DT.reset(new DominatorTree(Func));
970  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
971  // store the value to stack and reload it from stack before cast.
972  auto *AllocaAddr =
973  createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
974  Builder.SetInsertPoint(&*std::next(LD->getIterator()));
975  Builder.CreateStore(LD, AllocaAddr);
976 
977  Builder.SetInsertPoint(Cast);
978  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
979  EraseLoad = false;
980  } else {
981  I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
982  }
983  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
984 
985  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
986  std::nullopt, Args);
987  Cast->replaceAllUsesWith(NewInst);
988 
989  return EraseLoad;
990 }
991 
992 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
993  bool Change = false;
994  for (auto *Cast : Casts) {
995  auto *II = cast<IntrinsicInst>(Cast);
996  // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
997  // store <256 x i32> %43, <256 x i32>* %p, align 64
998  // -->
999  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1000  // i64 64, x86_amx %42)
1001  if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1002  SmallVector<Instruction *, 2> DeadStores;
1003  for (User *U : Cast->users()) {
1004  StoreInst *Store = dyn_cast<StoreInst>(U);
1005  if (!Store)
1006  continue;
1007  combineCastStore(cast<IntrinsicInst>(Cast), Store);
1008  DeadStores.push_back(Store);
1009  Change = true;
1010  }
1011  for (auto *Store : DeadStores)
1012  Store->eraseFromParent();
1013  } else { // x86_cast_vector_to_tile
1015  auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1016  if (!Load || !Load->hasOneUse())
1017  continue;
1018  // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1019  // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1020  // -->
1021  // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1022  // i8* %p, i64 64)
1023  if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
1024  // Set the operand to null so that the load instruction can be erased.
1025  Cast->setOperand(0, nullptr);
1026  Load->eraseFromParent();
1027  }
1028  }
1029  }
1030  return Change;
1031 }
1032 
1033 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1034  bool Change = false;
1035  // Collect tile cast instructions.
1036  SmallVector<Instruction *, 8> Vec2TileInsts;
1037  SmallVector<Instruction *, 8> Tile2VecInsts;
1038  SmallVector<Instruction *, 8> PhiCastWorkList;
1039  SmallSetVector<Instruction *, 16> DeadInst;
1040  for (BasicBlock &BB : Func) {
1041  for (Instruction &I : BB) {
1042  Value *Vec;
1043  if (match(&I,
1044  m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1045  Vec2TileInsts.push_back(&I);
1046  else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1047  m_Value(Vec))))
1048  Tile2VecInsts.push_back(&I);
1049  }
1050  }
1051 
1052  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1053  for (auto *Inst : Insts) {
1054  for (User *U : Inst->users()) {
1055  IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1056  if (!II || II->getIntrinsicID() != IID)
1057  continue;
1058  // T1 = vec2tile V0
1059  // V2 = tile2vec T1
1060  // V3 = OP V2
1061  // -->
1062  // T1 = vec2tile V0
1063  // V2 = tile2vec T1
1064  // V3 = OP V0
1065  II->replaceAllUsesWith(Inst->getOperand(0));
1066  Change = true;
1067  }
1068  }
1069  };
1070 
1071  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1072  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1073 
1074  SmallVector<Instruction *, 8> LiveCasts;
1075  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1076  for (auto *Inst : Insts) {
1077  if (Inst->use_empty()) {
1078  Inst->eraseFromParent();
1079  Change = true;
1080  } else {
1081  LiveCasts.push_back(Inst);
1082  }
1083  }
1084  };
1085 
1086  EraseInst(Vec2TileInsts);
1087  EraseInst(Tile2VecInsts);
1088  Change |= combineLdSt(LiveCasts);
1089  EraseInst(LiveCasts);
1090 
1091  // Handle the A->B->A cast, and there is an intervening PHI node.
1092  for (BasicBlock &BB : Func) {
1093  for (Instruction &I : BB) {
1094  if (isAMXCast(&I)) {
1095  if (isa<PHINode>(I.getOperand(0)))
1096  PhiCastWorkList.push_back(&I);
1097  }
1098  }
1099  }
1100  for (auto *I : PhiCastWorkList) {
1101  // We skip the dead Amxcast.
1102  if (DeadInst.contains(I))
1103  continue;
1104  PHINode *PN = cast<PHINode>(I->getOperand(0));
1105  if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1106  DeadInst.insert(PN);
1107  Change = true;
1108  }
1109  }
1110 
1111  // Since we create new PHIs and merge AMXCasts, some old PHIs and AMXCasts
1112  // might have no uses. We do some dead code elimination for them.
1113  while (!DeadInst.empty()) {
1114  Instruction *I = DeadInst.pop_back_val();
1115  Change |= DCEInstruction(I, DeadInst, TLI);
1116  }
1117  return Change;
1118 }
1119 
1120 // There might be remaining AMXcasts after combineAMXcast, and they should be
1121 // handled elegantly.
1122 bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1123  IRBuilder<> Builder(AMXCast);
1124  AllocaInst *AllocaAddr;
1125  Value *I8Ptr, *Stride;
1126  auto *Src = AMXCast->getOperand(0);
1127 
1128  auto Prepare = [&](Type *MemTy) {
1129  AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1130  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
1131  Stride = Builder.getInt64(64);
1132  };
1133 
1134  if (AMXCast->getType()->isX86_AMXTy()) {
1135  // %2 = amxcast <225 x i32> %src to x86_amx
1136  // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1137  // i8* %addr3, i64 60, x86_amx %2)
1138  // -->
1139  // %addr = alloca <225 x i32>, align 64
1140  // store <225 x i32> %src, <225 x i32>* %addr, align 64
1141  // %addr2 = bitcast <225 x i32>* %addr to i8*
1142  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1143  // i8* %addr2,
1144  // i64 60)
1145  // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1146  // i8* %addr3, i64 60, x86_amx %2)
1147  if (AMXCast->use_empty()) {
1148  AMXCast->eraseFromParent();
1149  return true;
1150  }
1151  Use &U = *(AMXCast->use_begin());
1152  unsigned OpNo = U.getOperandNo();
1153  auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1154  if (!II)
1155  return false; // May be bitcast from x86amx to <256 x i32>.
1156  Prepare(AMXCast->getOperand(0)->getType());
1157  Builder.CreateStore(Src, AllocaAddr);
1158  // TODO: we can pick a constant operand for the shape.
1159  Value *Row = nullptr, *Col = nullptr;
1160  std::tie(Row, Col) = getShape(II, OpNo);
1161  std::array<Value *, 4> Args = {
1162  Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1163  Value *NewInst = Builder.CreateIntrinsic(
1164  Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
1165  AMXCast->replaceAllUsesWith(NewInst);
1166  AMXCast->eraseFromParent();
1167  } else {
1168  // %2 = amxcast x86_amx %src to <225 x i32>
1169  // -->
1170  // %addr = alloca <225 x i32>, align 64
1171  // %addr2 = bitcast <225 x i32>* to i8*
1172  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1173  // i8* %addr2, i64 %stride)
1174  // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1175  auto *II = dyn_cast<IntrinsicInst>(Src);
1176  if (!II)
1177  return false; // May be bitcast from <256 x i32> to x86amx.
1178  Prepare(AMXCast->getType());
1179  Value *Row = II->getOperand(0);
1180  Value *Col = II->getOperand(1);
1181  std::array<Value *, 5> Args = {
1182  Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1183  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
1184  Args);
1185  Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1186  AMXCast->replaceAllUsesWith(NewInst);
1187  AMXCast->eraseFromParent();
1188  }
1189 
1190  return true;
1191 }
1192 
1193 bool X86LowerAMXCast::transformAllAMXCast() {
1194  bool Change = false;
1195  // Collect tile cast instructions.
1196  SmallVector<Instruction *, 8> WorkLists;
1197  for (BasicBlock &BB : Func) {
1198  for (Instruction &I : BB) {
1199  if (isAMXCast(&I))
1200  WorkLists.push_back(&I);
1201  }
1202  }
1203 
1204  for (auto *Inst : WorkLists) {
1205  Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1206  }
1207 
1208  return Change;
1209 }
1210 
1211 } // anonymous namespace
1212 
1213 namespace {
1214 
1215 class X86LowerAMXTypeLegacyPass : public FunctionPass {
1216 public:
1217  static char ID;
1218 
1219  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1220  initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1221  }
1222 
1223  bool runOnFunction(Function &F) override {
1224  bool C = false;
1225  TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1226  TargetLibraryInfo *TLI =
1227  &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1228 
1229  X86LowerAMXCast LAC(F);
1230  C |= LAC.combineAMXcast(TLI);
1231  // There might be remaining AMXcasts after combineAMXcast, and they should
1232  // be handled elegantly.
1233  C |= LAC.transformAllAMXCast();
1234 
1235  X86LowerAMXType LAT(F);
1236  C |= LAT.visit();
1237 
1238  // Prepare for fast register allocation at O0.
1239  // TODO: It may be better to check the volatile model of the AMX code, not
1240  // just by checking Attribute::OptimizeNone and CodeGenOpt::None.
1241  if (TM->getOptLevel() == CodeGenOpt::None) {
1242  // If the front end does not use O0 but the mid/back end uses O0 (e.g.
1243  // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1244  // sure the amx data is volatile; that is necessary for AMX fast
1245  // register allocation.
1246  if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1247  X86VolatileTileData VTD(F);
1248  C = VTD.volatileTileData() || C;
1249  }
1250  }
1251 
1252  return C;
1253  }
1254 
1255  void getAnalysisUsage(AnalysisUsage &AU) const override {
1256  AU.setPreservesCFG();
1257  AU.addRequired<TargetPassConfig>();
1258  AU.addRequired<TargetLibraryInfoWrapperPass>();
1259  }
1260 };
1261 
1262 } // anonymous namespace
1263 
1264 static const char PassName[] = "Lower AMX type for load/store";
1265 char X86LowerAMXTypeLegacyPass::ID = 0;
1266 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1267  false)
1268 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1269 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1270 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1271  false)
1272 
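// Note: createX86LowerAMXTypePass() is typically added from the X86 target's
// IR pass pipeline (X86PassConfig::addIRPasses) so that this lowering runs
// before instruction selection.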
1273 FunctionPass *llvm::createX86LowerAMXTypePass() {
1274  return new X86LowerAMXTypeLegacyPass();
1275 }