//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are
/// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
/// cannot be combined with a load/store, we transform the bitcast to an AMX
/// load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spilling.)
/// The volatileTileData() function handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will transform to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

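// Return true if II is one of the two AMX cast intrinsics
// (x86_cast_vector_to_tile / x86_cast_tile_to_vector) that bridge
// <256 x i32> vectors and x86_amx tiles.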
static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If it is, the
  // intrinsic must be an x86 AMX intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

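// Create an alloca of type Ty at the front of the entry block of BB's
// function, aligned to the preferred alignment of x86_amx, so tile data can
// be staged through the stack.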
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

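// Return the {Row, Col} shape of the tile used as operand OpNo of the AMX
// intrinsic II. For the dot-product intrinsics (m, n, k, acc, a, b), the
// accumulator operand 3 is m x n, operand 4 is m x k and operand 5 is
// k/4 x n; the division by 4 reflects that k counts bytes while each packed
// element of operand 5 is 4 bytes wide. As an illustration (the value names
// below are made up):
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %a,
//                                                x86_amx %b)
// getShape(%d, 4) yields {%m, %k} and getShape(%d, 5) yields {%k / 4, %n}.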
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a constant value and it is not a function argument,
        // we create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118, and we try to get the shape for
        // %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //          i32> %115).
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %104,
        //          i16 %105, i16 %106, x86_amx %110, x86_amx %114,
        //          x86_amx %117).
        // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
        // definition is after its user (the new tileload for %117).
        // So the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a constant value but is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, we just
  // traverse the first user. If we can find the shape, return it; otherwise
  // return nullptr and the optimization for undef/zero is abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
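// Lower bitcasts between <256 x i32> and x86_amx: combine them with adjacent
// vector load/store instructions into AMX tile load/store intrinsics, and
// otherwise fall back to a stack slot (vector store + tile load, or tile
// store + vector load).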
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand of the intrinsic is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// transform bitcast to <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store into
        // an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

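// Allocate a <256 x i32> stack slot in the entry block and return its address
// as an i8* suitable for the tile load/store intrinsics.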
static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

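// Spill the tile defined by TileDef to Ptr by inserting a tilestored64 right
// after TileDef, reusing TileDef's row/column operands as the shape.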
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}

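// Replace the tile use U with a tileloadd64 from Ptr inserted right before the
// user, taking the shape from the tile definition (or from its first incoming
// value when the definition is a PHI node).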
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

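// Return true if any user of I is a PHI node.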
static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, shortening the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored just after their defs.
// 3) The memory used by these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                   *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)     *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %t0`                                                *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)     *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem)  *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// All its users are not PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from a tileload in time.
// 2) All the defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

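// Combine and lower the AMX cast intrinsics emitted by the front end: cancel
// vector->tile->vector round trips, merge casts with adjacent vector
// loads/stores, optimize casts across PHI webs, and lower whatever remains
// through a stack slot.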
class X86LowerAMXCast {
  Function &Func;

public:
  X86LowerAMXCast(Function &F) : Func(F) {}
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

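// Erase I if it is trivially dead, and queue any of its operands that become
// trivially dead in WorkList so the caller can process them too.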
static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
/// A -> B amxcast
/// PHI
/// B -> A amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently we ignore cases where the incoming value is a
      // constant. In the future, we might support constants.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not constant, the Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of the incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, None, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        // Replace InValue with the new Value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        //   bb.0:
        //      %0 = amxcast ...
        //   bb.1:
        //      %1 = amxcast ...
        //   bb.2:
        //      %goodphi = phi %0, %1
        //      %3 = amxcast %goodphi
        //   bb.3:
        //      %goodphi2 = phi %0, %goodphi
        //      %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with a type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since they are
        // operands of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
static void combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: If it is cast intrinsic or phi node, we can propagate the
  // shape information through def-use chain.
  if (!isAMXIntrinsic(Tile))
    return;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand of the intrinsic is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
static void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is cast intrinsic or phi node, we can propagate the
  // shape information through def-use chain.
  if (!isAMXIntrinsic(II))
    return;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Cast->replaceAllUsesWith(NewInst);
}

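// For each cast that is still live, try to fold it into an adjacent memory
// access: a tile->vector cast feeding a store becomes a tilestored64, and a
// single-use vector load feeding a vector->tile cast becomes a tileloadd64.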
static bool combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        combineCastStore(cast<IntrinsicInst>(Cast), Store);
        DeadStores.push_back(Store);
        Change = true;
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      combineLoadCast(cast<IntrinsicInst>(Cast), Load);
      // Set the operand to null so that the load instruction can be erased.
      Cast->setOperand(0, nullptr);
      Load->eraseFromParent();
    }
  }
  return Change;
}

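// Top-level cast combining: cancel vec2tile/tile2vec pairs, fold the surviving
// casts into adjacent loads/stores, optimize casts whose operand is a PHI
// node, and finally dead-code-eliminate anything left without uses.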
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip dead AMXcasts.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXcasts, some old phis and AMXcasts
  // might have no uses. We do some DeadCodeElimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  return Change;
}

// There might be remaining AMXcasts after combineAMXcast and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

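// Legacy pass driver: first combine and lower the AMX cast intrinsics, then
// lower the remaining <256 x i32>/x86_amx bitcasts, and at CodeGenOpt::None
// additionally make tile data volatile so that fast register allocation can
// spill and reload tiles through memory.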
class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of AMX code, not
    // just by checking Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}