1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to transform <256 x i32> load/store
10/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
11/// only provides simple operations on x86_amx. Basic elementwise operations
12/// are not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
13/// and only AMX intrinsics can operate on the type, we need to transform
14/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
15/// cannot be combined with the load/store, we transform the bitcast to an amx
16/// load/store and a <256 x i32> store/load.
17///
18/// If the front end does not use O0 but the mid/back end does (e.g. "clang -O2
19/// -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is
20/// volatile, because that is necessary for AMX fast register allocation. (In
21/// fast register allocation, registers are allocated before spill/reload, so
22/// there is no additional register for amx to identify the step in a spill.)
23/// volatileTileData() handles this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will be transformed to -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
41#include "X86.h"
42#include "llvm/ADT/PostOrderIterator.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/Analysis/TargetLibraryInfo.h"
46#include "llvm/CodeGen/Passes.h"
47#include "llvm/CodeGen/TargetPassConfig.h"
48#include "llvm/CodeGen/ValueTypes.h"
49#include "llvm/IR/Analysis.h"
50#include "llvm/IR/DataLayout.h"
51#include "llvm/IR/Function.h"
52#include "llvm/IR/IRBuilder.h"
53#include "llvm/IR/Instructions.h"
54#include "llvm/IR/IntrinsicInst.h"
55#include "llvm/IR/IntrinsicsX86.h"
56#include "llvm/IR/PassManager.h"
57#include "llvm/IR/PatternMatch.h"
58#include "llvm/InitializePasses.h"
59#include "llvm/Pass.h"
60#include "llvm/Target/TargetMachine.h"
61#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62#include "llvm/Transforms/Utils/Local.h"
63
64#include <map>
65
66using namespace llvm;
67using namespace PatternMatch;
68
69#define DEBUG_TYPE "x86-lower-amx-type"
70
71 static bool isAMXCast(Instruction *II) {
72 return match(II,
73 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74 match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75}
76
77 static bool isAMXIntrinsic(Value *I) {
78 auto *II = dyn_cast<IntrinsicInst>(I);
79 if (!II)
80 return false;
81 if (isAMXCast(II))
82 return false;
83 // Check if return type or parameter is x86_amx. If it is x86_amx
84 // the intrinsic must be an x86 amx intrinsic.
85 if (II->getType()->isX86_AMXTy())
86 return true;
87 for (Value *V : II->args()) {
88 if (V->getType()->isX86_AMXTy())
89 return true;
90 }
91
92 return false;
93}
94
95static bool containsAMXCode(Function &F) {
96 for (BasicBlock &BB : F)
97 for (Instruction &I : BB)
98 if (I.getType()->isX86_AMXTy())
99 return true;
100 return false;
101}
102
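// Create an alloca of type Ty in the entry block of the function containing
// BB, aligned to the preferred alignment of x86_amx, so it can serve as
// scratch memory for tile load/store.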
103 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
104 Type *Ty) {
105 Function &F = *BB->getParent();
106 const DataLayout &DL = F.getDataLayout();
107
108 LLVMContext &Ctx = Builder.getContext();
109 auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
110 unsigned AllocaAS = DL.getAllocaAddrSpace();
111 AllocaInst *AllocaRes =
112 new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
113 AllocaRes->setAlignment(AllocaAlignment);
114 return AllocaRes;
115}
116
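// Return the first non-alloca instruction in the entry block; it is used as
// the insertion point for code that must follow all entry-block allocas.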
117 static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
118 for (Instruction &I : F.getEntryBlock())
119 if (!isa<AllocaInst>(&I))
120 return &I;
121 llvm_unreachable("No terminator in the entry block!");
122}
123
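// Derive a row value from a column value V by dividing it by Granularity
// (the element size in bytes), inserting the udiv where it will dominate all
// of its uses.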
124static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {
125 IRBuilder<> Builder(II);
126 Value *RealRow = nullptr;
127 if (isa<ConstantInt>(V))
128 RealRow =
129 Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
130 else if (isa<Instruction>(V)) {
131 // When it is not a const value and it is not a function argument, we
132 // create Row after the definition of V instead of
133 // before II. For example, II is %118, we try to getshape for %117:
134 // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
135 // i32> %115).
136 // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
137 // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
138 // %117).
139 // If we create %row = udiv i16 %106, 4 before %118 (aka II), then its
140 // definition is after its user (the new tileload for %117).
141 // So, the best choice is to create %row right after the definition of
142 // %106.
143 Builder.SetInsertPoint(cast<Instruction>(V));
144 RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
145 cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
146 } else {
147 // When it is not a const value and it is a function argument, we create
148 // Row at the entry bb.
149 IRBuilder<> NewBuilder(
150 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
151 RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
152 }
153 return RealRow;
154}
155
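// Return the {row, col} shape expected for operand OpNo of the AMX intrinsic
// II, taken from the intrinsic's row/col arguments.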
156// TODO: Refine the row and col-in-bytes of tile to row and col of matrix.
157std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
158 IRBuilder<> Builder(II);
159 Value *Row = nullptr, *Col = nullptr;
160 switch (II->getIntrinsicID()) {
161 default:
162 llvm_unreachable("Expect amx intrinsics");
163 case Intrinsic::x86_tileloadd64_internal:
164 case Intrinsic::x86_tileloaddt164_internal:
165 case Intrinsic::x86_tilestored64_internal:
166 case Intrinsic::x86_tileloaddrs64_internal:
167 case Intrinsic::x86_tileloaddrst164_internal: {
168 Row = II->getArgOperand(0);
169 Col = II->getArgOperand(1);
170 break;
171 }
172 // a * b + c
173 // The shape depends on which operand of the intrinsic is being queried.
174 case Intrinsic::x86_tcmmimfp16ps_internal:
175 case Intrinsic::x86_tcmmrlfp16ps_internal:
176 case Intrinsic::x86_tdpbssd_internal:
177 case Intrinsic::x86_tdpbsud_internal:
178 case Intrinsic::x86_tdpbusd_internal:
179 case Intrinsic::x86_tdpbuud_internal:
180 case Intrinsic::x86_tdpbf16ps_internal:
181 case Intrinsic::x86_tdpfp16ps_internal:
182 case Intrinsic::x86_tmmultf32ps_internal:
183 case Intrinsic::x86_tdpbf8ps_internal:
184 case Intrinsic::x86_tdpbhf8ps_internal:
185 case Intrinsic::x86_tdphbf8ps_internal:
186 case Intrinsic::x86_tdphf8ps_internal: {
187 switch (OpNo) {
188 case 3:
189 Row = II->getArgOperand(0);
190 Col = II->getArgOperand(1);
191 break;
192 case 4:
193 Row = II->getArgOperand(0);
194 Col = II->getArgOperand(2);
195 break;
196 case 5:
197 Row = getRowFromCol(II, II->getArgOperand(2), 4);
198 Col = II->getArgOperand(1);
199 break;
200 }
201 break;
202 }
203 case Intrinsic::x86_tcvtrowd2ps_internal:
204 case Intrinsic::x86_tcvtrowps2bf16h_internal:
205 case Intrinsic::x86_tcvtrowps2bf16l_internal:
206 case Intrinsic::x86_tcvtrowps2phh_internal:
207 case Intrinsic::x86_tcvtrowps2phl_internal:
208 case Intrinsic::x86_tilemovrow_internal: {
209 assert(OpNo == 2 && "Illegal Operand Number.");
210 Row = II->getArgOperand(0);
211 Col = II->getArgOperand(1);
212 break;
213 }
214 }
215
216 return std::make_pair(Row, Col);
217}
218
219static std::pair<Value *, Value *> getShape(PHINode *Phi) {
220 Use &U = *(Phi->use_begin());
221 unsigned OpNo = U.getOperandNo();
222 User *V = U.getUser();
223 // TODO: We don't traverse all users. To keep the algorithm simple, we only
224 // traverse the first user. If we can find the shape, we return it;
225 // otherwise we return nullptr and the optimization for undef/zero is
226 // abandoned.
227 while (V) {
228 if (isAMXCast(dyn_cast<Instruction>(V))) {
229 if (V->use_empty())
230 break;
231 Use &U = *(V->use_begin());
232 OpNo = U.getOperandNo();
233 V = U.getUser();
234 } else if (isAMXIntrinsic(V)) {
235 return getShape(cast<IntrinsicInst>(V), OpNo);
236 } else if (isa<PHINode>(V)) {
237 if (V->use_empty())
238 break;
239 Use &U = *(V->use_begin());
240 V = U.getUser();
241 } else {
242 break;
243 }
244 }
245
246 return std::make_pair(nullptr, nullptr);
247}
248
249namespace {
250class X86LowerAMXType {
251 Function &Func;
252
253 // In AMX intrinsics we let Shape = {Row, Col}, but the
254 // RealCol = Col / ElementSize. We may use the RealCol
255 // as a new Row for other newly created AMX intrinsics.
256 std::map<Value *, Value *> Col2Row;
257
258public:
259 X86LowerAMXType(Function &F) : Func(F) {}
260 bool visit();
261 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
262 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
263 bool transformBitcast(BitCastInst *Bitcast);
264};
265
266// %src = load <256 x i32>, <256 x i32>* %addr, align 64
267// %2 = bitcast <256 x i32> %src to x86_amx
268// -->
269// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
270// i8* %addr, i64 %stride64)
271void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
272 Value *Row = nullptr, *Col = nullptr;
273 Use &U = *(Bitcast->use_begin());
274 unsigned OpNo = U.getOperandNo();
275 auto *II = cast<IntrinsicInst>(U.getUser());
276 std::tie(Row, Col) = getShape(II, OpNo);
277 IRBuilder<> Builder(Bitcast);
278 // Use the maximum column as stride.
279 Value *Stride = Builder.getInt64(64);
280 Value *I8Ptr = LD->getOperand(0);
281 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
282
283 Value *NewInst =
284 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
285 Bitcast->replaceAllUsesWith(NewInst);
286}
287
288// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
289// %stride);
290// %13 = bitcast x86_amx %src to <256 x i32>
291// store <256 x i32> %13, <256 x i32>* %addr, align 64
292// -->
293// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
294// %stride64, %13)
295void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
296
297 Value *Tile = Bitcast->getOperand(0);
298 auto *II = cast<IntrinsicInst>(Tile);
299 // Tile is output from AMX intrinsic. The first operand of the
300 // intrinsic is row, the second operand of the intrinsic is column.
301 Value *Row = II->getOperand(0);
302 Value *Col = II->getOperand(1);
303 IRBuilder<> Builder(ST);
304 // Use the maximum column as stride. It must be the same as the load
305 // stride.
306 Value *Stride = Builder.getInt64(64);
307 Value *I8Ptr = ST->getOperand(1);
308 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
309 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
310 if (Bitcast->hasOneUse())
311 return;
312 // %13 = bitcast x86_amx %src to <256 x i32>
313 // store <256 x i32> %13, <256 x i32>* %addr, align 64
314 // %add = <256 x i32> %13, <256 x i32> %src2
315 // -->
316 // %13 = bitcast x86_amx %src to <256 x i32>
317 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
318 // %stride64, %13)
319 // %14 = load <256 x i32>, %addr
320 // %add = <256 x i32> %14, <256 x i32> %src2
321 Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
322 Bitcast->replaceAllUsesWith(Vec);
323}
324
325 // Transform the bitcast into a <store, load> pair of instructions.
326bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
327 IRBuilder<> Builder(Bitcast);
328 AllocaInst *AllocaAddr;
329 Value *I8Ptr, *Stride;
330 auto *Src = Bitcast->getOperand(0);
331
332 auto Prepare = [&](Type *MemTy) {
333 AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
334 I8Ptr = AllocaAddr;
335 Stride = Builder.getInt64(64);
336 };
337
338 if (Bitcast->getType()->isX86_AMXTy()) {
339 // %2 = bitcast <256 x i32> %src to x86_amx
340 // -->
341 // %addr = alloca <256 x i32>, align 64
342 // store <256 x i32> %src, <256 x i32>* %addr, align 64
343 // %addr2 = bitcast <256 x i32>* to i8*
344 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
345 // i8* %addr2,
346 // i64 64)
347 Use &U = *(Bitcast->use_begin());
348 unsigned OpNo = U.getOperandNo();
349 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
350 if (!II)
351 return false; // May be bitcast from x86amx to <256 x i32>.
352 Prepare(Bitcast->getOperand(0)->getType());
353 Builder.CreateStore(Src, AllocaAddr);
354 // TODO: we can pick a constant operand for the shape.
355 Value *Row = nullptr, *Col = nullptr;
356 std::tie(Row, Col) = getShape(II, OpNo);
357 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
358 Value *NewInst =
359 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
360 Bitcast->replaceAllUsesWith(NewInst);
361 } else {
362 // %2 = bitcast x86_amx %src to <256 x i32>
363 // -->
364 // %addr = alloca <256 x i32>, align 64
365 // %addr2 = bitcast <256 x i32>* to i8*
366 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
367 // i8* %addr2, i64 %stride)
368 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
369 auto *II = dyn_cast<IntrinsicInst>(Src);
370 if (!II)
371 return false; // May be bitcast from <256 x i32> to x86amx.
372 Prepare(Bitcast->getType());
373 Value *Row = II->getOperand(0);
374 Value *Col = II->getOperand(1);
375 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
376 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
377 Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
378 Bitcast->replaceAllUsesWith(NewInst);
379 }
380
381 return true;
382}
383
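// Walk the function and combine each bitcast between <256 x i32> and x86_amx
// with an adjacent load/store where possible; otherwise lower the bitcast
// through a stack slot via transformBitcast.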
384bool X86LowerAMXType::visit() {
385 SmallVector<Instruction *, 8> DeadInsts;
386 Col2Row.clear();
387
388 for (BasicBlock *BB : post_order(&Func)) {
389 for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
390 auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
391 if (!Bitcast)
392 continue;
393
394 Value *Src = Bitcast->getOperand(0);
395 if (Bitcast->getType()->isX86_AMXTy()) {
396 if (Bitcast->user_empty()) {
397 DeadInsts.push_back(Bitcast);
398 continue;
399 }
400 LoadInst *LD = dyn_cast<LoadInst>(Src);
401 if (!LD) {
402 if (transformBitcast(Bitcast))
403 DeadInsts.push_back(Bitcast);
404 continue;
405 }
406 // If the load has multiple users, keep the vector load and create an extra tile load.
407 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
408 // %2 = bitcast <256 x i32> %src to x86_amx
409 // %add = add <256 x i32> %src, <256 x i32> %src2
410 // -->
411 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
412 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
413 // i8* %addr, i64 %stride64)
414 // %add = add <256 x i32> %src, <256 x i32> %src2
415
416 // If the load has a single user, the load will be eliminated in DAG ISel.
417 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
418 // %2 = bitcast <256 x i32> %src to x86_amx
419 // -->
420 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
421 // i8* %addr, i64 %stride64)
422 combineLoadBitcast(LD, Bitcast);
423 DeadInsts.push_back(Bitcast);
424 if (LD->hasOneUse())
425 DeadInsts.push_back(LD);
426 } else if (Src->getType()->isX86_AMXTy()) {
427 if (Bitcast->user_empty()) {
428 DeadInsts.push_back(Bitcast);
429 continue;
430 }
431 StoreInst *ST = nullptr;
432 for (Use &U : Bitcast->uses()) {
433 ST = dyn_cast<StoreInst>(U.getUser());
434 if (ST)
435 break;
436 }
437 if (!ST) {
438 if (transformBitcast(Bitcast))
439 DeadInsts.push_back(Bitcast);
440 continue;
441 }
442 // If the bitcast (%13) has one use, combine the bitcast and store into an amx store.
443 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
444 // %stride);
445 // %13 = bitcast x86_amx %src to <256 x i32>
446 // store <256 x i32> %13, <256 x i32>* %addr, align 64
447 // -->
448 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
449 // %stride64, %13)
450 //
451 // If the bitcast (%13) has multiple uses, transform as below.
452 // %13 = bitcast x86_amx %src to <256 x i32>
453 // store <256 x i32> %13, <256 x i32>* %addr, align 64
454 // %add = <256 x i32> %13, <256 x i32> %src2
455 // -->
456 // %13 = bitcast x86_amx %src to <256 x i32>
457 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
458 // %stride64, %13)
459 // %14 = load <256 x i32>, %addr
460 // %add = <256 x i32> %14, <256 x i32> %src2
461 //
462 combineBitcastStore(Bitcast, ST);
463 // Delete user first.
464 DeadInsts.push_back(ST);
465 DeadInsts.push_back(Bitcast);
466 }
467 }
468 }
469
470 bool C = !DeadInsts.empty();
471
472 for (auto *Inst : DeadInsts)
473 Inst->eraseFromParent();
474
475 return C;
476}
477} // anonymous namespace
478
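// Allocate a <256 x i32> stack slot in the entry block and return its address;
// it is the memory that the volatile tile stores/loads go through.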
479 static Value *getAllocaPos(BasicBlock *BB) {
480 Function *F = BB->getParent();
481 IRBuilder<> Builder(&F->getEntryBlock().front());
482 const DataLayout &DL = F->getDataLayout();
483 unsigned AllocaAS = DL.getAllocaAddrSpace();
484 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
485 AllocaInst *AllocaRes =
486 new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
487 BasicBlock::iterator Iter = AllocaRes->getIterator();
488 ++Iter;
489 Builder.SetInsertPoint(&*Iter);
490 Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
491 return I8Ptr;
492}
493
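// Emit a tilestored64 of TileDef to Ptr right after the tile definition, e.g.
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   call void @llvm.x86.tilestored64.internal(%row, %col, %ptr, i64 64, %td)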
494 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
495 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
496 auto *II = cast<IntrinsicInst>(TileDef);
497
498 assert(II && "Not tile intrinsic!");
499 Value *Row = II->getOperand(0);
500 Value *Col = II->getOperand(1);
501
502 BasicBlock *BB = TileDef->getParent();
503 BasicBlock::iterator Iter = TileDef->getIterator();
504 IRBuilder<> Builder(BB, ++Iter);
505 Value *Stride = Builder.getInt64(64);
506 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
507
508 Instruction *TileStore =
509 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
510 return TileStore;
511}
512
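// Replace the tile value used at U with a fresh tileloadd64 from Ptr, emitted
// right before the user, e.g.
//   "use %td"  -->  %td2 = call x86_amx @llvm.x86.tileloadd64.internal(
//                            %row, %col, %ptr, i64 64)
//                   "use %td2"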
513static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
514 Value *V = U.get();
515 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
516
517 // Get tile shape.
518 IntrinsicInst *II = nullptr;
519 if (IsPHI) {
520 Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
521 II = cast<IntrinsicInst>(PhiOp);
522 } else {
523 II = cast<IntrinsicInst>(V);
524 }
525 Value *Row = II->getOperand(0);
526 Value *Col = II->getOperand(1);
527
528 Instruction *UserI = cast<Instruction>(U.getUser());
529 IRBuilder<> Builder(UserI);
530 Value *Stride = Builder.getInt64(64);
531 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
532
533 Value *TileLoad =
534 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
535 UserI->replaceUsesOfWith(V, TileLoad);
536}
537
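// Return true if I is used as an incoming value of any PHI node.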
538 static bool isIncomingOfPHI(Instruction *I) {
539 for (Use &U : I->uses()) {
540 User *V = U.getUser();
541 if (isa<PHINode>(V))
542 return true;
543 }
544 return false;
545}
546
547 // Let all AMX tile data become volatile data to shorten the live range
548 // of each tile register before fast register allocation.
549namespace {
550class X86VolatileTileData {
551 Function &F;
552
553public:
554 X86VolatileTileData(Function &Func) : F(Func) {}
555 Value *updatePhiIncomings(BasicBlock *BB,
556 SmallVector<Instruction *, 2> &Incomings);
557 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
558 bool volatileTileData();
559 void volatileTilePHI(PHINode *PHI);
560 void volatileTileNonPHI(Instruction *I);
561};
562
563Value *X86VolatileTileData::updatePhiIncomings(
564 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
565 Value *I8Ptr = getAllocaPos(BB);
566
567 for (auto *I : Incomings) {
568 User *Store = createTileStore(I, I8Ptr);
569
570 // All its uses (except phi) should load from stored mem.
571 for (Use &U : I->uses()) {
572 User *V = U.getUser();
573 if (isa<PHINode>(V) || V == Store)
574 continue;
575 replaceWithTileLoad(U, I8Ptr);
576 }
577 }
578 return I8Ptr;
579}
580
581void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
582 Value *StorePtr) {
583 for (Use &U : PHI->uses())
584 replaceWithTileLoad(U, StorePtr, true);
585 PHI->eraseFromParent();
586}
587
588 // Similar to volatileTileNonPHI, this function only handles PHI nodes
589 // and their related AMX intrinsics.
590 // 1) The PHI def is changed to a tileload.
591 // 2) Each PHI incoming value is tilestored right after its def.
592 // 3) The memory of these tileloads and tilestores must be the same.
593// e.g.
594// ------------------------------------------------------
595// bb_dom:
596// ...
597// br i1 %bool.cond, label %if.else, label %if.then
598//
599// if.then:
600// def %t0 = ...
601// ...
602// use %t0
603// ...
604// br label %if.end
605//
606// if.else:
607// def %t1 = ...
608// br label %if.end
609//
610// if.end:
611// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
612// ...
613// use %td
614// ------------------------------------------------------
615// -->
616// ------------------------------------------------------
617// bb_entry:
618// %mem = alloca <256 x i32>, align 1024 *
619// ...
620// bb_dom:
621// ...
622// br i1 %bool.cond, label %if.else, label %if.then
623//
624// if.then:
625// def %t0 = ...
626// call void @llvm.x86.tilestored64.internal(mem, %t0) *
627// ...
628// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
629// use %t0` *
630// ...
631// br label %if.end
632//
633// if.else:
634// def %t1 = ...
635// call void @llvm.x86.tilestored64.internal(mem, %t1) *
636// br label %if.end
637//
638// if.end:
639// ...
640// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
641// use %td
642// ------------------------------------------------------
643void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
644 BasicBlock *BB = PHI->getParent();
645 SmallVector<Instruction *, 2> Incomings;
646
647 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
648 Value *Op = PHI->getIncomingValue(I);
649 Instruction *Inst = dyn_cast<Instruction>(Op);
650 assert(Inst && "We shouldn't fold AMX instruction!");
651 Incomings.push_back(Inst);
652 }
653
654 Value *StorePtr = updatePhiIncomings(BB, Incomings);
655 replacePhiDefWithLoad(PHI, StorePtr);
656}
657
658// Store the defined tile and load it before use.
659 // None of its users is a PHI node.
660// e.g.
661// ------------------------------------------------------
662// def %td = ...
663// ...
664// "use %td"
665// ------------------------------------------------------
666// -->
667// ------------------------------------------------------
668// def %td = ...
669// call void @llvm.x86.tilestored64.internal(mem, %td)
670// ...
671// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
672// "use %td2"
673// ------------------------------------------------------
674void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
675 BasicBlock *BB = I->getParent();
676 Value *I8Ptr = getAllocaPos(BB);
677 User *Store = createTileStore(I, I8Ptr);
678
679 // All its uses should load from stored mem.
680 for (Use &U : I->uses()) {
681 User *V = U.getUser();
682 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
683 if (V != Store)
684 replaceWithTileLoad(U, I8Ptr);
685 }
686}
687
688// Volatile Tile Model:
689// 1) All the uses of tile data comes from tileload in time.
690// 2) All the defs of tile data tilestore into mem immediately.
691// For example:
692// --------------------------------------------------------------------------
693// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
694// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
695// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
696// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
697// call void @llvm.x86.tilestored64.internal(... td) area
698// --------------------------------------------------------------------------
699// 3) No terminator, call or other amx instructions in the key amx area.
700bool X86VolatileTileData::volatileTileData() {
701 bool Changed = false;
702 for (BasicBlock &BB : F) {
703 SmallVector<Instruction *, 2> PHIInsts;
704 SmallVector<Instruction *, 8> AMXDefInsts;
705
706 for (Instruction &I : BB) {
707 if (!I.getType()->isX86_AMXTy())
708 continue;
709 if (isa<PHINode>(&I))
710 PHIInsts.push_back(&I);
711 else
712 AMXDefInsts.push_back(&I);
713 }
714
715 // First we "volatile" the non-phi related amx intrinsics.
716 for (Instruction *I : AMXDefInsts) {
717 if (isIncomingOfPHI(I))
718 continue;
719 volatileTileNonPHI(I);
720 Changed = true;
721 }
722
723 for (Instruction *I : PHIInsts) {
724 volatileTilePHI(dyn_cast<PHINode>(I));
725 Changed = true;
726 }
727 }
728 return Changed;
729}
730
731} // anonymous namespace
732
733namespace {
734
735class X86LowerAMXCast {
736 Function &Func;
737 std::unique_ptr<DominatorTree> DT;
738
739public:
740 X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
741 bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
742 bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
743 bool combineTilezero(IntrinsicInst *Cast);
744 bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
745 bool combineAMXcast(TargetLibraryInfo *TLI);
746 bool transformAMXCast(IntrinsicInst *AMXCast);
747 bool transformAllAMXCast();
748 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
749 SmallSetVector<Instruction *, 16> &DeadInst);
750};
751
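// Erase I if it is trivially dead, and queue any operands that become
// trivially dead as a result so that they can be erased later.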
752static bool DCEInstruction(Instruction *I,
753 SmallSetVector<Instruction *, 16> &WorkList,
754 const TargetLibraryInfo *TLI) {
755 if (isInstructionTriviallyDead(I, TLI)) {
756 salvageDebugInfo(*I);
757 salvageKnowledge(I);
758
759 // Null out all of the instruction's operands to see if any operand becomes
760 // dead as we go.
761 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
762 Value *OpV = I->getOperand(i);
763 I->setOperand(i, nullptr);
764
765 if (!OpV->use_empty() || I == OpV)
766 continue;
767
768 // If the operand is an instruction that became dead as we nulled out the
769 // operand, and if it is 'trivially' dead, delete it in a future loop
770 // iteration.
771 if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
772 if (isInstructionTriviallyDead(OpI, TLI)) {
773 WorkList.insert(OpI);
774 }
775 }
776 }
777 I->eraseFromParent();
778 return true;
779 }
780 return false;
781}
782
783 /// This function handles the following case:
784///
785/// A -> B amxcast
786/// PHI
787/// B -> A amxcast
788///
789/// All the related PHI nodes can be replaced by new PHI nodes with type A.
790/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
791bool X86LowerAMXCast::optimizeAMXCastFromPhi(
792 IntrinsicInst *CI, PHINode *PN,
793 SmallSetVector<Instruction *, 16> &DeadInst) {
794 IRBuilder<> Builder(CI);
795 Value *Src = CI->getOperand(0);
796 Type *SrcTy = Src->getType(); // Type B
797 Type *DestTy = CI->getType(); // Type A
798
799 SmallVector<PHINode *, 4> PhiWorklist;
800 SmallSetVector<PHINode *, 4> OldPhiNodes;
801
802 // Find all of the A->B casts and PHI nodes.
803 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
804 // OldPhiNodes is used to track all known PHI nodes; before adding a new
805 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
806 PhiWorklist.push_back(PN);
807 OldPhiNodes.insert(PN);
808 while (!PhiWorklist.empty()) {
809 auto *OldPN = PhiWorklist.pop_back_val();
810 for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
811 Value *IncValue = OldPN->getIncomingValue(I);
812 // TODO: Currently, we ignore cases where the incoming value is a constant.
813 // In the future, we might support constants.
814 if (isa<Constant>(IncValue)) {
815 auto *IncConst = dyn_cast<Constant>(IncValue);
816 if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
817 return false;
818 Value *Row = nullptr, *Col = nullptr;
819 std::tie(Row, Col) = getShape(OldPN);
820 // TODO: If it is not a constant, the Row and Col must dominate the tilezero
821 // that we are going to create.
822 if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
823 return false;
824 // Create tilezero at the end of incoming block.
825 auto *Block = OldPN->getIncomingBlock(I);
826 BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
827 Instruction *NewInst = Builder.CreateIntrinsic(
828 Intrinsic::x86_tilezero_internal, {}, {Row, Col});
829 NewInst->moveBefore(Iter);
830 NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
831 {IncValue->getType()}, {NewInst});
832 NewInst->moveBefore(Iter);
833 // Replace InValue with new Value.
834 OldPN->setIncomingValue(I, NewInst);
835 IncValue = NewInst;
836 }
837
838 if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
839 if (OldPhiNodes.insert(PNode))
840 PhiWorklist.push_back(PNode);
841 continue;
842 }
843 Instruction *ACI = dyn_cast<Instruction>(IncValue);
844 if (ACI && isAMXCast(ACI)) {
845 // Verify it's a A->B cast.
846 Type *TyA = ACI->getOperand(0)->getType();
847 Type *TyB = ACI->getType();
848 if (TyA != DestTy || TyB != SrcTy)
849 return false;
850 continue;
851 }
852 return false;
853 }
854 }
855
856 // Check that each user of each old PHI node is something that we can
857 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
858 for (auto *OldPN : OldPhiNodes) {
859 for (User *V : OldPN->users()) {
860 Instruction *ACI = dyn_cast<Instruction>(V);
861 if (ACI && isAMXCast(ACI)) {
862 // Verify it's a B->A cast.
863 Type *TyB = ACI->getOperand(0)->getType();
864 Type *TyA = ACI->getType();
865 if (TyA != DestTy || TyB != SrcTy)
866 return false;
867 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
868 // As long as the user is another old PHI node, then even if we don't
869 // rewrite it, the PHI web we're considering won't have any users
870 // outside itself, so it'll be dead.
871 // example:
872 // bb.0:
873 // %0 = amxcast ...
874 // bb.1:
875 // %1 = amxcast ...
876 // bb.2:
877 // %goodphi = phi %0, %1
878 // %3 = amxcast %goodphi
879 // bb.3:
880 // %goodphi2 = phi %0, %goodphi
881 // %4 = amxcast %goodphi2
882 // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
883 // outside the phi-web, so the combination stops. When
884 // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
885 // will be done.
886 if (OldPhiNodes.count(PHI) == 0)
887 return false;
888 } else
889 return false;
890 }
891 }
892
893 // For each old PHI node, create a corresponding new PHI node with type A.
894 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
895 for (auto *OldPN : OldPhiNodes) {
896 Builder.SetInsertPoint(OldPN);
897 PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
898 NewPNodes[OldPN] = NewPN;
899 }
900
901 // Fill in the operands of new PHI nodes.
902 for (auto *OldPN : OldPhiNodes) {
903 PHINode *NewPN = NewPNodes[OldPN];
904 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
905 Value *V = OldPN->getOperand(j);
906 Value *NewV = nullptr;
907 Instruction *ACI = dyn_cast<Instruction>(V);
908 // There should not be an AMXcast from a const.
909 if (ACI && isAMXCast(ACI))
910 NewV = ACI->getOperand(0);
911 else if (auto *PrevPN = dyn_cast<PHINode>(V))
912 NewV = NewPNodes[PrevPN];
913 assert(NewV);
914 NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
915 }
916 }
917
918 // Traverse all accumulated PHI nodes and process their users,
919 // which are Stores and BitCasts. Without this processing
920 // NewPHI nodes could be replicated and could lead to extra
921 // moves generated after DeSSA.
922 // If there is a store with type B, change it to type A.
923
924 // Replace users of BitCast B->A with NewPHI. These will help
925 // later to get rid of a closure formed by OldPHI nodes.
926 for (auto *OldPN : OldPhiNodes) {
927 PHINode *NewPN = NewPNodes[OldPN];
928 for (User *V : make_early_inc_range(OldPN->users())) {
929 Instruction *ACI = dyn_cast<Instruction>(V);
930 if (ACI && isAMXCast(ACI)) {
931 Type *TyB = ACI->getOperand(0)->getType();
932 Type *TyA = ACI->getType();
933 assert(TyA == DestTy && TyB == SrcTy);
934 (void)TyA;
935 (void)TyB;
936 ACI->replaceAllUsesWith(NewPN);
937 DeadInst.insert(ACI);
938 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
939 // We don't need to push the PHINode into DeadInst since they are operands
940 // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
941 assert(OldPhiNodes.contains(PHI));
942 (void)PHI;
943 } else
944 llvm_unreachable("all uses should be handled");
945 }
946 }
947 return true;
948}
949
950// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
951// store <256 x i32> %43, <256 x i32>* %p, align 64
952// -->
953// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
954// i64 64, x86_amx %42)
955bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
956 Value *Tile = Cast->getOperand(0);
957
958 assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");
959
960 // TODO: Specially handle the multi-use case.
961 if (!Tile->hasOneUse())
962 return false;
963
964 auto *II = cast<IntrinsicInst>(Tile);
965 // Tile is output from AMX intrinsic. The first operand of the
966 // intrinsic is row, the second operand of the intrinsic is column.
967 Value *Row = II->getOperand(0);
968 Value *Col = II->getOperand(1);
969
970 IRBuilder<> Builder(ST);
971
972 // Stride should be equal to col (measured in bytes).
973 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
974 Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
975 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
976 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
977 return true;
978}
979
980// %65 = load <256 x i32>, <256 x i32>* %p, align 64
981// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
982// -->
983// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
984// i8* %p, i64 64)
985bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
986 bool EraseLoad = true;
987 Value *Row = nullptr, *Col = nullptr;
988 Use &U = *(Cast->use_begin());
989 unsigned OpNo = U.getOperandNo();
990 auto *II = cast<IntrinsicInst>(U.getUser());
991 // TODO: If it is cast intrinsic or phi node, we can propagate the
992 // shape information through def-use chain.
993 if (!isAMXIntrinsic(II))
994 return false;
995 std::tie(Row, Col) = getShape(II, OpNo);
996 IRBuilder<> Builder(LD);
997 // Stride should be equal to col (measured in bytes).
998 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
999 Value *I8Ptr;
1000
1001 // To save compile time, we create the dominator tree only when it is
1002 // really needed.
1003 if (!DT)
1004 DT.reset(new DominatorTree(Func));
1005 if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
1006 // Store the value to the stack and reload it from the stack before the cast.
1007 auto *AllocaAddr =
1008 createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
1009 Builder.SetInsertPoint(&*std::next(LD->getIterator()));
1010 Builder.CreateStore(LD, AllocaAddr);
1011
1012 Builder.SetInsertPoint(Cast);
1013 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1014 EraseLoad = false;
1015 } else {
1016 I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
1017 }
1018 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
1019
1020 Value *NewInst =
1021 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
1022 Cast->replaceAllUsesWith(NewInst);
1023
1024 return EraseLoad;
1025}
1026
1027// %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1028// -->
1029// %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1030bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
1031 Value *Row = nullptr, *Col = nullptr;
1032 Use &U = *(Cast->use_begin());
1033 unsigned OpNo = U.getOperandNo();
1034 auto *II = cast<IntrinsicInst>(U.getUser());
1035 if (!isAMXIntrinsic(II))
1036 return false;
1037
1038 std::tie(Row, Col) = getShape(II, OpNo);
1039
1040 IRBuilder<> Builder(Cast);
1041 Value *NewInst =
1042 Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal, {}, {Row, Col});
1043 Cast->replaceAllUsesWith(NewInst);
1044 return true;
1045}
1046
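// For each remaining amxcast, try to fold it into an adjacent load/store or a
// zero initializer: tile-to-vector casts feeding stores become tilestored64,
// vector-to-tile casts of zeroinitializer become tilezero, and vector-to-tile
// casts fed by single-use loads become tileloadd64.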
1047bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
1048 bool Change = false;
1049 for (auto *Cast : Casts) {
1050 auto *II = cast<IntrinsicInst>(Cast);
1051 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
1052 // store <256 x i32> %43, <256 x i32>* %p, align 64
1053 // -->
1054 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1055 // i64 64, x86_amx %42)
1056 if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1057 SmallVector<Instruction *, 2> DeadStores;
1058 for (User *U : Cast->users()) {
1059 StoreInst *Store = dyn_cast<StoreInst>(U);
1060 if (!Store)
1061 continue;
1062 if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
1063 DeadStores.push_back(Store);
1064 Change = true;
1065 }
1066 }
1067 for (auto *Store : DeadStores)
1068 Store->eraseFromParent();
1069 } else { // x86_cast_vector_to_tile
1070 // %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1071 // -->
1072 // %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1073 if (isa<ConstantAggregateZero>(Cast->getOperand(0))) {
1074 Change |= combineTilezero(cast<IntrinsicInst>(Cast));
1075 continue;
1076 }
1077
1078 auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1079 if (!Load || !Load->hasOneUse())
1080 continue;
1081 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1082 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1083 // -->
1084 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1085 // i8* %p, i64 64)
1086 if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
1087 // Set the operand to null so that the load instruction can be erased.
1088 Cast->setOperand(0, nullptr);
1089 Load->eraseFromParent();
1090 Change = true;
1091 }
1092 }
1093 }
1094 return Change;
1095}
1096
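// Combine amxcast intrinsics with the surrounding code:
// 1) eliminate vec2tile/tile2vec round trips,
// 2) fold the remaining casts into AMX load/store/tilezero (combineLdSt),
// 3) optimize A->B->A cast chains that go through PHI nodes, and
// 4) DCE the casts and phis that become dead.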
1097bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1098 bool Change = false;
1099 // Collect tile cast instructions.
1100 SmallVector<Instruction *, 8> Vec2TileInsts;
1101 SmallVector<Instruction *, 8> Tile2VecInsts;
1102 SmallVector<Instruction *, 8> PhiCastWorkList;
1103 SmallSetVector<Instruction *, 16> DeadInst;
1104 for (BasicBlock &BB : Func) {
1105 for (Instruction &I : BB) {
1106 Value *Vec;
1107 if (match(&I,
1108 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1109 Vec2TileInsts.push_back(&I);
1110 else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1111 m_Value(Vec))))
1112 Tile2VecInsts.push_back(&I);
1113 }
1114 }
1115
1116 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1117 for (auto *Inst : Insts) {
1118 for (User *U : Inst->users()) {
1119 IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1120 if (!II || II->getIntrinsicID() != IID)
1121 continue;
1122 // T1 = vec2tile V0
1123 // V2 = tile2vec T1
1124 // V3 = OP V2
1125 // -->
1126 // T1 = vec2tile V0
1127 // V2 = tile2vec T1
1128 // V3 = OP V0
1129 II->replaceAllUsesWith(Inst->getOperand(0));
1130 Change = true;
1131 }
1132 }
1133 };
1134
1135 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1136 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1137
1138 SmallVector<Instruction *, 8> LiveCasts;
1139 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1140 for (auto *Inst : Insts) {
1141 if (Inst->use_empty()) {
1142 Inst->eraseFromParent();
1143 Change = true;
1144 } else {
1145 LiveCasts.push_back(Inst);
1146 }
1147 }
1148 };
1149
1150 EraseInst(Vec2TileInsts);
1151 EraseInst(Tile2VecInsts);
1152 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1153 "Vec2Tile and Tile2Vec:\n";
1154 Func.dump());
1155 Change |= combineLdSt(LiveCasts);
1156 EraseInst(LiveCasts);
1157 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1158 "AMXCast and load/store:\n";
1159 Func.dump());
1160
1161 // Handle the A->B->A cast, and there is an intervening PHI node.
1162 for (BasicBlock &BB : Func) {
1163 for (Instruction &I : BB) {
1164 if (isAMXCast(&I)) {
1165 if (isa<PHINode>(I.getOperand(0)))
1166 PhiCastWorkList.push_back(&I);
1167 }
1168 }
1169 }
1170 for (auto *I : PhiCastWorkList) {
1171 // We skip the dead Amxcast.
1172 if (DeadInst.contains(I))
1173 continue;
1174 PHINode *PN = cast<PHINode>(I->getOperand(0));
1175 if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1176 DeadInst.insert(PN);
1177 Change = true;
1178 }
1179 }
1180
1181 // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
1182 // might have no uses. We do some dead code elimination for them.
1183 while (!DeadInst.empty()) {
1184 Instruction *I = DeadInst.pop_back_val();
1185 Change |= DCEInstruction(I, DeadInst, TLI);
1186 }
1187 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
1188 "optimizeAMXCastFromPhi:\n";
1189 Func.dump());
1190 return Change;
1191}
1192
1193 // There might be remaining AMXcasts after combineAMXcast, and they should
1194 // be handled elegantly.
1195bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1196 IRBuilder<> Builder(AMXCast);
1197 AllocaInst *AllocaAddr;
1198 Value *I8Ptr, *Stride;
1199 auto *Src = AMXCast->getOperand(0);
1200
1201 auto Prepare = [&](Type *MemTy) {
1202 AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1203 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1204 Stride = Builder.getInt64(64);
1205 };
1206
1207 if (AMXCast->getType()->isX86_AMXTy()) {
1208 // %2 = amxcast <225 x i32> %src to x86_amx
1209 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1210 // i8* %addr3, i64 60, x86_amx %2)
1211 // -->
1212 // %addr = alloca <225 x i32>, align 64
1213 // store <225 x i32> %src, <225 x i32>* %addr, align 64
1214 // %addr2 = bitcast <225 x i32>* %addr to i8*
1215 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1216 // i8* %addr2,
1217 // i64 60)
1218 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1219 // i8* %addr3, i64 60, x86_amx %2)
1220 if (AMXCast->use_empty()) {
1221 AMXCast->eraseFromParent();
1222 return true;
1223 }
1224 Use &U = *(AMXCast->use_begin());
1225 unsigned OpNo = U.getOperandNo();
1226 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1227 if (!II)
1228 return false; // May be bitcast from x86amx to <256 x i32>.
1229 Prepare(AMXCast->getOperand(0)->getType());
1230 Builder.CreateStore(Src, AllocaAddr);
1231 // TODO: we can pick a constant operand for the shape.
1232 Value *Row = nullptr, *Col = nullptr;
1233 std::tie(Row, Col) = getShape(II, OpNo);
1234 std::array<Value *, 4> Args = {
1235 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1236 Value *NewInst =
1237 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
1238 AMXCast->replaceAllUsesWith(NewInst);
1239 AMXCast->eraseFromParent();
1240 } else {
1241 // %2 = amxcast x86_amx %src to <225 x i32>
1242 // -->
1243 // %addr = alloca <225 x i32>, align 64
1244 // %addr2 = bitcast <225 x i32>* to i8*
1245 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1246 // i8* %addr2, i64 %stride)
1247 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1248 auto *II = dyn_cast<IntrinsicInst>(Src);
1249 if (!II)
1250 return false; // May be bitcast from <256 x i32> to x86amx.
1251 Prepare(AMXCast->getType());
1252 Value *Row = II->getOperand(0);
1253 Value *Col = II->getOperand(1);
1254 std::array<Value *, 5> Args = {
1255 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1256 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
1257 Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1258 AMXCast->replaceAllUsesWith(NewInst);
1259 AMXCast->eraseFromParent();
1260 }
1261
1262 return true;
1263}
1264
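// Lower every remaining amxcast in the function by spilling the value to a
// stack slot and reloading it in the other representation (transformAMXCast).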
1265bool X86LowerAMXCast::transformAllAMXCast() {
1266 bool Change = false;
1267 // Collect tile cast instructions.
1268 SmallVector<Instruction *, 8> WorkLists;
1269 for (BasicBlock &BB : Func) {
1270 for (Instruction &I : BB) {
1271 if (isAMXCast(&I))
1272 WorkLists.push_back(&I);
1273 }
1274 }
1275
1276 for (auto *Inst : WorkLists) {
1277 Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1278 }
1279
1280 return Change;
1281}
1282
1283bool lowerAmxType(Function &F, const TargetMachine *TM,
1284 TargetLibraryInfo *TLI) {
1285 // Performance optimization: most code doesn't use AMX, so return early if
1286 // there are no instructions that produce AMX values. This is sufficient, as
1287 // AMX arguments and constants are not allowed -- so any producer of an AMX
1288 // value must be an instruction.
1289 // TODO: find a cheaper way for this, without looking at all instructions.
1290 if (!containsAMXCode(F))
1291 return false;
1292
1293 bool C = false;
1294 X86LowerAMXCast LAC(F);
1295 C |= LAC.combineAMXcast(TLI);
1296 // There might be remaining AMXcasts after combineAMXcast, and they should
1297 // be handled elegantly.
1298 C |= LAC.transformAllAMXCast();
1299
1300 X86LowerAMXType LAT(F);
1301 C |= LAT.visit();
1302
1303 // Prepare for fast register allocation at O0.
1304 // TODO: It may be better to check the volatile model of the AMX code, not
1305 // just by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
1306 if (TM->getOptLevel() == CodeGenOptLevel::None) {
1307 // If the front end does not use O0 but the mid/back end does (e.g.
1308 // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1309 // sure the amx data is volatile, as that is necessary for AMX fast
1310 // register allocation.
1311 if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1312 X86VolatileTileData VTD(F);
1313 C = VTD.volatileTileData() || C;
1314 }
1315 }
1316
1317 return C;
1318}
1319
1320} // anonymous namespace
1321
1322 PreservedAnalyses X86LowerAMXTypePass::run(Function &F,
1323 FunctionAnalysisManager &FAM) {
1324 TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
1325 bool Changed = lowerAmxType(F, TM, &TLI);
1326 if (!Changed)
1327 return PreservedAnalyses::all();
1328
1329 PreservedAnalyses PA = PreservedAnalyses::none();
1330 PA.preserveSet<CFGAnalyses>();
1331 return PA;
1332}
1333
1334namespace {
1335
1336class X86LowerAMXTypeLegacyPass : public FunctionPass {
1337public:
1338 static char ID;
1339
1340 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}
1341
1342 bool runOnFunction(Function &F) override {
1343 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1344 TargetLibraryInfo *TLI =
1345 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1346 return lowerAmxType(F, TM, TLI);
1347 }
1348
1349 void getAnalysisUsage(AnalysisUsage &AU) const override {
1350 AU.setPreservesCFG();
1351 AU.addRequired<TargetPassConfig>();
1352 AU.addRequired<TargetLibraryInfoWrapperPass>();
1353 }
1354};
1355
1356} // anonymous namespace
1357
1358static const char PassName[] = "Lower AMX type for load/store";
1359char X86LowerAMXTypeLegacyPass::ID = 0;
1360INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1361 false)
1362 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1363 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1364 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1365 false)
1366
1367 FunctionPass *llvm::createX86LowerAMXTypeLegacyPass() {
1368 return new X86LowerAMXTypeLegacyPass();
1369}