LLVM 22.0.0git
X86LowerAMXIntrinsics.cpp
Go to the documentation of this file.
1//===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Pass to transform AMX intrinsics to scalar operations.
/// This pass is always scheduled, but it does nothing unless the function is
/// compiled at -O0 or carries the optnone attribute. With -O0 or optnone, the
/// definition of the shape used by an AMX intrinsic is close to the intrinsic
/// itself, so we are not able to find a point which post-dominates all the
/// shape definitions and dominates all AMX intrinsics. To decouple the
/// dependency on the shape, we transform AMX intrinsics to scalar operations
/// so that compilation doesn't fail. In the long term, we should improve fast
/// register allocation so that it can allocate AMX registers.
//===----------------------------------------------------------------------===//
17//===----------------------------------------------------------------------===//
18//
#include "X86.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Analysis.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
42
43using namespace llvm;
44using namespace PatternMatch;
45
46#define DEBUG_TYPE "x86-lower-amx-intrinsics"
47
48#ifndef NDEBUG
49static bool isV256I32Ty(Type *Ty) {
50 if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
51 return FVT->getNumElements() == 256 &&
52 FVT->getElementType()->isIntegerTy(32);
53 return false;
54}
55#endif
56
57static cl::opt<bool>
58 X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
59 cl::desc("X86: enable AMX scalarizition."));
60
61namespace {
62class X86LowerAMXIntrinsics {
63 Function &Func;
64
65public:
66 X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
67 : Func(F), DTU(DomTU), LI(LoopI) {}
68 bool visit();
69
70private:
71 DomTreeUpdater &DTU;
72 LoopInfo *LI;
73 BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
74 Value *Step, StringRef Name, IRBuilderBase &B,
75 Loop *L);
76 template <bool IsTileLoad>
77 Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
78 IRBuilderBase &B, Value *Row, Value *Col,
79 Value *Ptr, Value *Stride, Value *Tile);
80 template <Intrinsic::ID IntrID>
81 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
82 IntrID == Intrinsic::x86_tdpbsud_internal ||
83 IntrID == Intrinsic::x86_tdpbusd_internal ||
84 IntrID == Intrinsic::x86_tdpbuud_internal ||
85 IntrID == Intrinsic::x86_tdpbf16ps_internal,
86 Value *>
87 createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
88 Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
89 Value *RHS);
90 template <bool IsTileLoad>
91 bool lowerTileLoadStore(Instruction *TileLoadStore);
92 template <Intrinsic::ID IntrID>
93 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
94 IntrID == Intrinsic::x86_tdpbsud_internal ||
95 IntrID == Intrinsic::x86_tdpbusd_internal ||
96 IntrID == Intrinsic::x86_tdpbuud_internal ||
97 IntrID == Intrinsic::x86_tdpbf16ps_internal,
98 bool>
99 lowerTileDP(Instruction *TileDP);
100 bool lowerTileZero(Instruction *TileZero);
101};
102} // anonymous namespace
103
104BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
105 BasicBlock *Exit, Value *Bound,
106 Value *Step, StringRef Name,
107 IRBuilderBase &B, Loop *L) {
108 LLVMContext &Ctx = Preheader->getContext();
109 BasicBlock *Header =
110 BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
111 BasicBlock *Body =
112 BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
113 BasicBlock *Latch =
114 BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
115
116 Type *I16Ty = Type::getInt16Ty(Ctx);
117 BranchInst::Create(Body, Header);
118 BranchInst::Create(Latch, Body);
119 PHINode *IV =
120 PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator()->getIterator());
121 IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
122
123 B.SetInsertPoint(Latch);
124 Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
125 Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
126 BranchInst::Create(Header, Exit, Cond, Latch);
127 IV->addIncoming(Inc, Latch);
128
129 BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
130 BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
131 PreheaderBr->setSuccessor(0, Header);
133 {DominatorTree::Delete, Preheader, Tmp},
134 {DominatorTree::Insert, Header, Body},
135 {DominatorTree::Insert, Body, Latch},
136 {DominatorTree::Insert, Latch, Header},
137 {DominatorTree::Insert, Latch, Exit},
138 {DominatorTree::Insert, Preheader, Header},
139 });
140 if (LI) {
141 L->addBasicBlockToLoop(Header, *LI);
142 L->addBasicBlockToLoop(Body, *LI);
143 L->addBasicBlockToLoop(Latch, *LI);
144 }
145 return Body;
146}
147
148template <bool IsTileLoad>
149Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
150 BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
151 Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
152 std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
153 Loop *RowLoop = nullptr;
154 Loop *ColLoop = nullptr;
155 if (LI) {
156 RowLoop = LI->AllocateLoop();
157 ColLoop = LI->AllocateLoop();
158 RowLoop->addChildLoop(ColLoop);
159 if (Loop *ParentL = LI->getLoopFor(Start))
160 ParentL->addChildLoop(RowLoop);
161 else
162 LI->addTopLevelLoop(RowLoop);
163 }
164
165 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
166 IntrinName + ".scalarize.rows", B, RowLoop);
167 BasicBlock *RowLatch = RowBody->getSingleSuccessor();
168
169 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
170 IntrinName + ".scalarize.cols", B, ColLoop);
171
172 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
173 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
174 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
175 Value *CurrentRow = &*RowLoopHeader->begin();
176 Value *CurrentCol = &*ColLoopHeader->begin();
177 Type *EltTy = B.getInt32Ty();
178 FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);
179
180 // Common part for tileload and tilestore
181 // *.scalarize.cols.body:
182 // Calculate %idxmem and %idxvec
183 B.SetInsertPoint(ColBody->getTerminator());
184 Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
185 Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
186 Value *Offset =
187 B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
188 Value *EltPtr = B.CreateGEP(EltTy, Ptr, Offset);
189 Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
190 if (IsTileLoad) {
191 // tileload.scalarize.rows.header:
192 // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
193 // %tileload.scalarize.rows.latch ]
194 B.SetInsertPoint(RowLoopHeader->getTerminator());
195 Value *VecZero = Constant::getNullValue(V256I32Ty);
196 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
197 VecCPhiRowLoop->addIncoming(VecZero, Start);
198
199 // tileload.scalarize.cols.header:
200 // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
201 // ], [ %ResVec, %tileload.scalarize.cols.latch ]
202 B.SetInsertPoint(ColLoopHeader->getTerminator());
203 PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
204 VecPhi->addIncoming(VecCPhiRowLoop, RowBody);
205
206 // tileload.scalarize.cols.body:
207 // Calculate %idxmem and %idxvec
208 // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
209 // %elt = load i32, i32* %ptr
210 // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
211 B.SetInsertPoint(ColBody->getTerminator());
212 Value *Elt = B.CreateLoad(EltTy, EltPtr);
213 Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
214 VecPhi->addIncoming(ResVec, ColLoopLatch);
215 VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
216
217 return ResVec;
218 } else {
219 auto *BitCast = cast<BitCastInst>(Tile);
220 Value *Vec = BitCast->getOperand(0);
221 assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
222 // tilestore.scalarize.cols.body:
223 // %mul = mul i16 %row.iv, i16 16
224 // %idx = add i16 %mul, i16 %col.iv
225 // %vec = extractelement <16 x i32> %vec, i16 %idx
226 // store i32 %vec, i32* %ptr
227 B.SetInsertPoint(ColBody->getTerminator());
228 Value *Elt = B.CreateExtractElement(Vec, Idx);
229
230 B.CreateStore(Elt, EltPtr);
231 return nullptr;
232 }
233}
234
235template <Intrinsic::ID IntrID>
236std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
237 IntrID == Intrinsic::x86_tdpbsud_internal ||
238 IntrID == Intrinsic::x86_tdpbusd_internal ||
239 IntrID == Intrinsic::x86_tdpbuud_internal ||
240 IntrID == Intrinsic::x86_tdpbf16ps_internal,
241 Value *>
242X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
243 IRBuilderBase &B, Value *Row,
244 Value *Col, Value *K, Value *Acc,
245 Value *LHS, Value *RHS) {
246 std::string IntrinName;
247 switch (IntrID) {
248 case Intrinsic::x86_tdpbssd_internal:
249 IntrinName = "tiledpbssd";
250 break;
251 case Intrinsic::x86_tdpbsud_internal:
252 IntrinName = "tiledpbsud";
253 break;
254 case Intrinsic::x86_tdpbusd_internal:
255 IntrinName = "tiledpbusd";
256 break;
257 case Intrinsic::x86_tdpbuud_internal:
258 IntrinName = "tiledpbuud";
259 break;
260 case Intrinsic::x86_tdpbf16ps_internal:
261 IntrinName = "tiledpbf16ps";
262 break;
263 }
264 Loop *RowLoop = nullptr;
265 Loop *ColLoop = nullptr;
266 Loop *InnerLoop = nullptr;
267 if (LI) {
268 RowLoop = LI->AllocateLoop();
269 ColLoop = LI->AllocateLoop();
270 InnerLoop = LI->AllocateLoop();
271 ColLoop->addChildLoop(InnerLoop);
272 RowLoop->addChildLoop(ColLoop);
273 if (Loop *ParentL = LI->getLoopFor(Start))
274 ParentL->addChildLoop(RowLoop);
275 else
276 LI->addTopLevelLoop(RowLoop);
277 }
278
279 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
280 IntrinName + ".scalarize.rows", B, RowLoop);
281 BasicBlock *RowLatch = RowBody->getSingleSuccessor();
282
283 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
284 IntrinName + ".scalarize.cols", B, ColLoop);
285
286 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
287
288 B.SetInsertPoint(ColBody->getTerminator());
289 BasicBlock *InnerBody =
290 createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
291 IntrinName + ".scalarize.inner", B, InnerLoop);
292
293 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
294 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
295 BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
296 BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
297 Value *CurrentRow = &*RowLoopHeader->begin();
298 Value *CurrentCol = &*ColLoopHeader->begin();
299 Value *CurrentInner = &*InnerLoopHeader->begin();
300
301 FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
302 auto *BitCastAcc = cast<BitCastInst>(Acc);
303 Value *VecC = BitCastAcc->getOperand(0);
304 assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
305 // TODO else create BitCast from x86amx to v256i32.
306 // Store x86amx to memory, and reload from memory
307 // to vector. However with -O0, it doesn't happen.
308 auto *BitCastLHS = cast<BitCastInst>(LHS);
309 Value *VecA = BitCastLHS->getOperand(0);
310 assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
311 auto *BitCastRHS = cast<BitCastInst>(RHS);
312 Value *VecB = BitCastRHS->getOperand(0);
313 assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
314
315 // tiledpbssd.scalarize.rows.header:
316 // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
317 // %tiledpbssd.scalarize.rows.latch ]
318
319 // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
320 // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
321 B.SetInsertPoint(RowLoopHeader->getTerminator());
322 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
323 VecCPhiRowLoop->addIncoming(VecC, Start);
324 Value *VecZero = Constant::getNullValue(V256I32Ty);
325 PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
326 VecDPhiRowLoop->addIncoming(VecZero, Start);
327
328 // tiledpbssd.scalarize.cols.header:
329 // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
330 // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
331 // %tiledpbssd.scalarize.cols.latch ]
332
333 // %vec.d.phi.col = phi <256 x i32> [
334 // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
335 // %tiledpbssd.scalarize.cols.latch ]
336
337 // calculate idxc.
338 B.SetInsertPoint(ColLoopHeader->getTerminator());
339 PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
340 VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
341 PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
342 VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
343 Value *IdxC =
344 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
345
346 // tiledpbssd.scalarize.inner.header:
347 // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
348 // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
349 // %tiledpbssd.scalarize.inner.latch ]
350
351 B.SetInsertPoint(InnerLoopHeader->getTerminator());
352 PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
353 VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
354
355 B.SetInsertPoint(InnerBody->getTerminator());
356 Value *IdxA =
357 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
358 Value *IdxB =
359 B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
360 Value *NewVecC = nullptr;
361
362 if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
363 // tiledpbssd.scalarize.inner.body:
364 // calculate idxa, idxb
365 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
366 // %elta = extractelement <256 x i32> %veca, i16 %idxa
367 // %eltav4i8 = bitcast i32 %elta to <4 x i8>
368 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
369 // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
370 // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
371 // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
372 // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
373 // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
374 // %neweltc = add i32 %elt, %acc
375 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
376 // i16 %idxc
377 FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
378 FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
379 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
380 Value *EltA = B.CreateExtractElement(VecA, IdxA);
381 Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
382 Value *EltB = B.CreateExtractElement(VecB, IdxB);
383 Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
384 Value *SEXTSubVecB = nullptr;
385 Value *SEXTSubVecA = nullptr;
386 switch (IntrID) {
387 case Intrinsic::x86_tdpbssd_internal:
388 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
389 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
390 break;
391 case Intrinsic::x86_tdpbsud_internal:
392 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
393 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
394 break;
395 case Intrinsic::x86_tdpbusd_internal:
396 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
397 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
398 break;
399 case Intrinsic::x86_tdpbuud_internal:
400 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
401 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
402 break;
403 default:
404 llvm_unreachable("Invalid intrinsic ID!");
405 }
406 Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
407 Value *ResElt = B.CreateAdd(EltC, SubVecR);
408 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
409 } else {
410 // tiledpbf16ps.scalarize.inner.body:
411 // calculate idxa, idxb, idxc
412 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
413 // %eltcf32 = bitcast i32 %eltc to float
414 // %elta = extractelement <256 x i32> %veca, i16 %idxa
415 // %eltav2i16 = bitcast i32 %elta to <2 x i16>
416 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
417 // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
418 // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
419 // x i32> <i32 2, i32 0, i32 3, i32 1>
420 // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
421 // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
422 // i32> <i32 2, i32 0, i32 3, i32 1>
423 // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
424 // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
425 // %acc = call float
426 // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
427 // %neweltc = bitcast float %acc to i32
428 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
429 // i16 %idxc
430 // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
431 // i16 %idxc
432 FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
433 FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
434 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
435 Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
436 Value *EltA = B.CreateExtractElement(VecA, IdxA);
437 Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
438 Value *EltB = B.CreateExtractElement(VecB, IdxB);
439 Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
440 Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
441 int ShuffleMask[4] = {2, 0, 3, 1};
442 auto ShuffleArray = ArrayRef(ShuffleMask);
443 Value *AV2F32 = B.CreateBitCast(
444 B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
445 Value *BV2F32 = B.CreateBitCast(
446 B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
447 Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
448 Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
449 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
450 }
451
452 // tiledpbssd.scalarize.cols.latch:
453 // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
454 // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
455 // i16 %idxc
456 B.SetInsertPoint(ColLoopLatch->getTerminator());
457 Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
458 Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
459
460 VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
461 VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
462 VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
463 VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
464 VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
465
466 return NewVecD;
467}
468
469template <Intrinsic::ID IntrID>
470std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
471 IntrID == Intrinsic::x86_tdpbsud_internal ||
472 IntrID == Intrinsic::x86_tdpbusd_internal ||
473 IntrID == Intrinsic::x86_tdpbuud_internal ||
474 IntrID == Intrinsic::x86_tdpbf16ps_internal,
475 bool>
476X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
477 Value *M, *N, *K, *C, *A, *B;
479 m_Value(C), m_Value(A), m_Value(B)));
480 Instruction *InsertI = TileDP;
481 IRBuilder<> PreBuilder(TileDP);
482 PreBuilder.SetInsertPoint(TileDP);
483 // We visit the loop with (m, n/4, k/4):
484 // %n_dword = lshr i16 %n, 2
485 // %k_dword = lshr i16 %k, 2
486 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
487 Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
488 BasicBlock *Start = InsertI->getParent();
489 BasicBlock *End =
490 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
491 IRBuilder<> Builder(TileDP);
492 Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
493 KDWord, C, A, B);
494 // we cannot assume there always be bitcast after tiledpbssd. So we need to
495 // insert one bitcast as required
496 Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
497 Value *ResAMX =
498 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
499 // Delete TileDP intrinsic and do some clean-up.
500 for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
501 Instruction *I = cast<Instruction>(U.getUser());
502 Value *Vec;
503 if (match(I, m_BitCast(m_Value(Vec)))) {
504 I->replaceAllUsesWith(ResVec);
505 I->eraseFromParent();
506 }
507 }
508 TileDP->replaceAllUsesWith(ResAMX);
509 TileDP->eraseFromParent();
510 return true;
511}
512
513template <bool IsTileLoad>
514bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
515 Value *M, *N, *Ptr, *Stride, *Tile;
516 if (IsTileLoad)
517 match(TileLoadStore,
519 m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
520 else
522 m_Value(M), m_Value(N), m_Value(Ptr),
523 m_Value(Stride), m_Value(Tile)));
524
525 Instruction *InsertI = TileLoadStore;
526 IRBuilder<> PreBuilder(TileLoadStore);
527 PreBuilder.SetInsertPoint(TileLoadStore);
528 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
529 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
530 BasicBlock *Start = InsertI->getParent();
531 BasicBlock *End =
532 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
533 IRBuilder<> Builder(TileLoadStore);
534 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
535 Start, End, Builder, M, NDWord, Ptr, StrideDWord,
536 IsTileLoad ? nullptr : Tile);
537 if (IsTileLoad) {
538 // we cannot assume there always be bitcast after tileload. So we need to
539 // insert one bitcast as required
540 Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
541 Value *ResAMX =
542 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
543 // Delete tileloadd6 intrinsic and do some clean-up
544 for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
545 Instruction *I = cast<Instruction>(U.getUser());
546 Value *Vec;
547 if (match(I, m_BitCast(m_Value(Vec)))) {
548 I->replaceAllUsesWith(ResVec);
549 I->eraseFromParent();
550 }
551 }
552 TileLoadStore->replaceAllUsesWith(ResAMX);
553 }
554 TileLoadStore->eraseFromParent();
555 return true;
556}
557
558bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
559 IRBuilder<> Builder(TileZero);
560 FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
561 Value *VecZero = Constant::getNullValue(V256I32Ty);
562 for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
563 Instruction *I = cast<Instruction>(U.getUser());
564 Value *Vec;
565 if (match(I, m_BitCast(m_Value(Vec)))) {
566 I->replaceAllUsesWith(VecZero);
567 I->eraseFromParent();
568 }
569 }
570 TileZero->eraseFromParent();
571 return true;
572}
573
574bool X86LowerAMXIntrinsics::visit() {
575 bool C = false;
577 for (BasicBlock *BB : depth_first(&Func)) {
578 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
579 if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
580 switch (Inst->getIntrinsicID()) {
581 case Intrinsic::x86_tdpbssd_internal:
582 case Intrinsic::x86_tdpbsud_internal:
583 case Intrinsic::x86_tdpbusd_internal:
584 case Intrinsic::x86_tdpbuud_internal:
585 case Intrinsic::x86_tileloadd64_internal:
586 case Intrinsic::x86_tilestored64_internal:
587 case Intrinsic::x86_tilezero_internal:
588 case Intrinsic::x86_tdpbf16ps_internal:
589 WorkList.push_back(Inst);
590 break;
591 default:
592 break;
593 }
594 }
595 }
596 }
597
598 for (auto *Inst : WorkList) {
599 switch (Inst->getIntrinsicID()) {
600 case Intrinsic::x86_tdpbssd_internal:
601 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
602 break;
603 case Intrinsic::x86_tdpbsud_internal:
604 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
605 break;
606 case Intrinsic::x86_tdpbusd_internal:
607 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
608 break;
609 case Intrinsic::x86_tdpbuud_internal:
610 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
611 break;
612 case Intrinsic::x86_tdpbf16ps_internal:
613 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
614 break;
615 case Intrinsic::x86_tileloadd64_internal:
616 C = lowerTileLoadStore<true>(Inst) || C;
617 break;
618 case Intrinsic::x86_tilestored64_internal:
619 C = lowerTileLoadStore<false>(Inst) || C;
620 break;
621 case Intrinsic::x86_tilezero_internal:
622 C = lowerTileZero(Inst) || C;
623 break;
624 default:
625 llvm_unreachable("invalid amx intrinsics!");
626 }
627 }
628
629 return C;
630}
631
632namespace {
633bool shouldRunLowerAMXIntrinsics(const Function &F, const TargetMachine *TM) {
634 return X86ScalarizeAMX && (F.hasFnAttribute(Attribute::OptimizeNone) ||
635 TM->getOptLevel() == CodeGenOptLevel::None);
636}
637
638bool runLowerAMXIntrinsics(Function &F, DominatorTree *DT, LoopInfo *LI) {
639 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
640
641 X86LowerAMXIntrinsics LAT(F, DTU, LI);
642 return LAT.visit();
643}
644} // namespace
645
648 if (!shouldRunLowerAMXIntrinsics(F, TM))
649 return PreservedAnalyses::all();
650
651 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
652 LoopInfo &LI = FAM.getResult<LoopAnalysis>(F);
653 bool Changed = runLowerAMXIntrinsics(F, &DT, &LI);
654 if (!Changed)
655 return PreservedAnalyses::all();
656
660 return PA;
661}
662
663namespace {
664class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
665public:
666 static char ID;
667
668 X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {}
669
670 bool runOnFunction(Function &F) override {
671 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
672 if (!shouldRunLowerAMXIntrinsics(F, TM))
673 return false;
674
675 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
676 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
677 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
678 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
679 return runLowerAMXIntrinsics(F, DT, LI);
680 }
681 StringRef getPassName() const override { return "Lower AMX intrinsics"; }
682
683 void getAnalysisUsage(AnalysisUsage &AU) const override {
684 AU.addPreserved<DominatorTreeWrapperPass>();
685 AU.addPreserved<LoopInfoWrapperPass>();
686 AU.addRequired<TargetPassConfig>();
687 }
688};
689} // namespace
690
691static const char PassName[] = "Lower AMX intrinsics";
692char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
693INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
694 false, false)
696INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
698
700 return new X86LowerAMXIntrinsicsLegacyPass();
701}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool runOnFunction(Function &F, bool PostInlining)
#define DEBUG_TYPE
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
uint64_t IntrinsicInst * II
FunctionAnalysisManager FAM
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
const SmallVectorImpl< MachineOperand > & Cond
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
static cl::opt< bool > X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden, cl::desc("X86: enable AMX scalarizition."))
static bool isV256I32Ty(Type *Ty)
static const char PassName[]
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition blake3_impl.h:83
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
LLVM_ABI const BasicBlock * getSingleSuccessor() const
Return the successor of this block if it has a single successor.
InstListType::iterator iterator
Instruction iterators...
Definition BasicBlock.h:170
LLVM_ABI LLVMContext & getContext() const
Get the context in which this basic block lives.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
void setSuccessor(unsigned idx, BasicBlock *NewSucc)
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Analysis pass which computes a DominatorTree.
Definition Dominators.h:284
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:165
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition Type.cpp:803
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
void applyUpdatesPermissive(ArrayRef< UpdateT > Updates)
Submit updates to all available trees.
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
void addChildLoop(LoopT *NewChild)
Add the specified loop to be a child of this loop.
void addTopLevelLoop(LoopT *New)
This adds the specified loop to the collection of top-level loops.
LoopT * AllocateLoop(ArgsTy &&...Args)
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
void push_back(const T &Elt)
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
Primary interface to the complete machine description for the target machine.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:538
iterator_range< use_iterator > uses()
Definition Value.h:380
PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM)
const ParentTy * getParent() const
Definition ilist_node.h:34
Changed
Pass manager infrastructure for declaring and invalidating analyses.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
CastOperator_match< OpTy, Instruction::BitCast > m_BitCast(const OpTy &Op)
Matches BitCast.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
initializer< Ty > init(const Ty &Val)
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
@ Offset
Definition DWP.cpp:477
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
FunctionPass * createX86LowerAMXIntrinsicsLegacyPass()
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition STLExtras.h:632
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
iterator_range< df_iterator< T > > depth_first(const T &G)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
#define N