X86LowerAMXType.cpp
//===- X86LowerAMXType.cpp - Lower AMX type for load/store -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions to AMX load/store. If the
/// bitcast cannot be combined with the load/store, we transform the bitcast
/// to an amx load/store and a <256 x i32> store/load.
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

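// Create a <256 x i32> stack slot in the entry block of the function that
// contains BB, aligned to the preferred alignment of x86_amx, so vector and
// tile code can exchange data through memory.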
static AllocaInst *CreateAllocaInst(IRBuilder<> &Builder, BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

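// Return the (row, col) shape of the tile at operand OpNo of the given AMX
// intrinsic. For tile load/store the shape is simply the first two
// arguments. For the dot-product intrinsics (c = a * b + c, where c is
// m x n, a is m x k and b is k x n), the shape depends on which operand the
// tile occupies: operand 3 (c) is m x n, operand 4 (a) is m x k, and
// operand 5 (b) is k x n, with m, n and k being arguments 0, 1 and 2.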
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape of the tile depends on which operand it occupies.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
static void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform the bitcast to <store, load> instructions.
static bool transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

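  // Prepare sets up a stack slot shared by both sides of the transform: the
  // vector side stores to or loads from the slot directly, while the tile
  // side goes through the AMX load/store intrinsics on the same memory.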
  auto Prepare = [&]() {
    AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86_amx to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86_amx.
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

namespace {
class X86LowerAMXType {
  Function &Func;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
};

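// Rewrite all bitcasts between <256 x i32> and x86_amx in the function.
// Blocks are visited in post order and instructions in reverse order, so a
// bitcast's users are seen before the bitcast itself. Replaced instructions
// are queued in DeadInsts and erased once the walk is complete.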
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate the vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has a single user, it will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col,
        //                                                    %addr, %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform it as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

namespace {

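// Legacy pass-manager wrapper that runs the X86LowerAMXType transform on
// every function, skipping optnone functions and -O0 builds.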
class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    if (F.hasFnAttribute(Attribute::OptimizeNone) ||
        TM->getOptLevel() == CodeGenOpt::None)
      return false;
    X86LowerAMXType LAT(F);
    bool C = LAT.visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}
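
// The X86 target adds this pass early in its IR pipeline (see
// X86PassConfig::addIRPasses in X86TargetMachine.cpp), so the bitcasts are
// lowered before instruction selection.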