LLVM 23.0.0git
SPIRVModuleAnalysis.cpp
Go to the documentation of this file.
1//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis collects instructions that should be output at the module level
10// and performs the global register numbering.
11//
12// The results of this analysis are used in AsmPrinter to rename registers
13// globally and to output required instructions at the module level.
14//
15//===----------------------------------------------------------------------===//
16
// TODO: uses of report_fatal_error (which is also deprecated) /
// ReportFatalUsageError in this file should be refactored, as per LLVM
// best practices, to rely on the Diagnostic infrastructure.
20
21#include "SPIRVModuleAnalysis.h"
24#include "SPIRV.h"
25#include "SPIRVSubtarget.h"
26#include "SPIRVTargetMachine.h"
27#include "SPIRVUtils.h"
28#include "llvm/ADT/STLExtras.h"
31
32using namespace llvm;
33
34#define DEBUG_TYPE "spirv-module-analysis"
35
// Debugging aid: when set, the analyzed MIR is dumped together with the
// SPIR-V dependency info this pass collects.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

// Capabilities the user asks the backend to avoid when an alternative
// capability can enable the same feature (see getSymbolicOperandRequirements).
// NOTE(review): the declaration head of AvoidCapabilities (its
// cl::list<...> template line) appears to be missing from this copy of the
// file — confirm against the upstream source.
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));
48// Use sets instead of cl::list to check "if contains" condition
53
55
// Register the analysis with the legacy pass manager infrastructure.
INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
58
59// Retrieve an unsigned from an MDNode with a list of them as operands.
60static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
61 unsigned DefaultVal = 0) {
62 if (MdNode && OpIndex < MdNode->getNumOperands()) {
63 const auto &Op = MdNode->getOperand(OpIndex);
64 return mdconst::extract<ConstantInt>(Op)->getZExtValue();
65 }
66 return DefaultVal;
67}
68
70getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
71 unsigned i, const SPIRVSubtarget &ST,
73 // A set of capabilities to avoid if there is another option.
74 AvoidCapabilitiesSet AvoidCaps;
75 if (!ST.isShader())
76 AvoidCaps.S.insert(SPIRV::Capability::Shader);
77 else
78 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
79
80 VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
81 VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
82 VersionTuple SPIRVVersion = ST.getSPIRVVersion();
83 bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
84 bool MaxVerOK =
85 ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
87 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
88 if (ReqCaps.empty()) {
89 if (ReqExts.empty()) {
90 if (MinVerOK && MaxVerOK)
91 return {true, {}, {}, ReqMinVer, ReqMaxVer};
92 return {false, {}, {}, VersionTuple(), VersionTuple()};
93 }
94 } else if (MinVerOK && MaxVerOK) {
95 if (ReqCaps.size() == 1) {
96 auto Cap = ReqCaps[0];
97 if (Reqs.isCapabilityAvailable(Cap)) {
99 SPIRV::OperandCategory::CapabilityOperand, Cap));
100 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
101 }
102 } else {
103 // By SPIR-V specification: "If an instruction, enumerant, or other
104 // feature specifies multiple enabling capabilities, only one such
105 // capability needs to be declared to use the feature." However, one
106 // capability may be preferred over another. We use command line
107 // argument(s) and AvoidCapabilities to avoid selection of certain
108 // capabilities if there are other options.
109 CapabilityList UseCaps;
110 for (auto Cap : ReqCaps)
111 if (Reqs.isCapabilityAvailable(Cap))
112 UseCaps.push_back(Cap);
113 for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
114 auto Cap = UseCaps[i];
115 if (i == Sz - 1 || !AvoidCaps.S.contains(Cap)) {
117 SPIRV::OperandCategory::CapabilityOperand, Cap));
118 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
119 }
120 }
121 }
122 }
123 // If there are no capabilities, or we can't satisfy the version or
124 // capability requirements, use the list of extensions (if the subtarget
125 // can handle them all).
126 if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
127 return ST.canUseExtension(Ext);
128 })) {
129 return {true,
130 {},
131 std::move(ReqExts),
132 VersionTuple(),
133 VersionTuple()}; // TODO: add versions to extensions.
134 }
135 return {false, {}, {}, VersionTuple(), VersionTuple()};
136}
137
138void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
139 MAI.MaxID = 0;
140 for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
141 MAI.MS[i].clear();
142 MAI.RegisterAliasTable.clear();
143 MAI.InstrsToDelete.clear();
144 MAI.GlobalObjMap.clear();
145 MAI.GlobalVarList.clear();
146 MAI.ExtInstSetMap.clear();
147 MAI.Reqs.clear();
148 MAI.Reqs.initAvailableCapabilities(*ST);
149
150 // TODO: determine memory model and source language from the configuratoin.
151 if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
152 auto MemMD = MemModel->getOperand(0);
153 MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
154 getMetadataUInt(MemMD, 0));
155 MAI.Mem =
156 static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
157 } else {
158 // TODO: Add support for VulkanMemoryModel.
159 MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
160 : SPIRV::MemoryModel::OpenCL;
161 if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
162 unsigned PtrSize = ST->getPointerSize();
163 MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
164 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
165 : SPIRV::AddressingModel::Logical;
166 } else {
167 // TODO: Add support for PhysicalStorageBufferAddress.
168 MAI.Addr = SPIRV::AddressingModel::Logical;
169 }
170 }
171 // Get the OpenCL version number from metadata.
172 // TODO: support other source languages.
173 if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
174 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
175 // Construct version literal in accordance with SPIRV-LLVM-Translator.
176 // TODO: support multiple OCL version metadata.
177 assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
178 auto VersionMD = VerNode->getOperand(0);
179 unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
180 unsigned MinorNum = getMetadataUInt(VersionMD, 1);
181 unsigned RevNum = getMetadataUInt(VersionMD, 2);
182 // Prevent Major part of OpenCL version to be 0
183 MAI.SrcLangVersion =
184 (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
185 // When opencl.cxx.version is also present, validate compatibility
186 // and use C++ for OpenCL as source language with the C++ version.
187 if (auto *CxxVerNode = M.getNamedMetadata("opencl.cxx.version")) {
188 assert(CxxVerNode->getNumOperands() > 0 && "Invalid SPIR");
189 auto *CxxMD = CxxVerNode->getOperand(0);
190 unsigned CxxVer =
191 (getMetadataUInt(CxxMD, 0) * 100 + getMetadataUInt(CxxMD, 1)) * 1000 +
192 getMetadataUInt(CxxMD, 2);
193 if ((MAI.SrcLangVersion == 200000 && CxxVer == 100000) ||
194 (MAI.SrcLangVersion == 300000 && CxxVer == 202100000)) {
195 MAI.SrcLang = SPIRV::SourceLanguage::CPP_for_OpenCL;
196 MAI.SrcLangVersion = CxxVer;
197 } else {
199 "opencl cxx version is not compatible with opencl c version!");
200 }
201 }
202 } else {
203 // If there is no information about OpenCL version we are forced to generate
204 // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
205 // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
206 // Translator avoids potential issues with run-times in a similar manner.
207 if (!ST->isShader()) {
208 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
209 MAI.SrcLangVersion = 100000;
210 } else {
211 MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
212 MAI.SrcLangVersion = 0;
213 }
214 }
215
216 if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
217 for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
218 MDNode *MD = ExtNode->getOperand(I);
219 if (!MD || MD->getNumOperands() == 0)
220 continue;
221 for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
222 MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
223 }
224 }
225
226 // Update required capabilities for this memory model, addressing model and
227 // source language.
228 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
229 MAI.Mem, *ST);
230 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
231 MAI.SrcLang, *ST);
232 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
233 MAI.Addr, *ST);
234
235 if (MAI.Mem == SPIRV::MemoryModel::VulkanKHR)
236 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_vulkan_memory_model);
237
238 if (!ST->isShader()) {
239 // TODO: check if it's required by default.
240 MAI.ExtInstSetMap[static_cast<unsigned>(
241 SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
242 }
243}
244
245// Appends the signature of the decoration instructions that decorate R to
246// Signature.
247static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
248 InstrSignature &Signature) {
249 for (MachineInstr &UseMI : MRI.use_instructions(R)) {
250 // We don't handle OpDecorateId because getting the register alias for the
251 // ID can cause problems, and we do not need it for now.
252 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
253 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
254 continue;
255
256 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
257 const MachineOperand &MO = UseMI.getOperand(I);
258 if (MO.isReg())
259 continue;
260 Signature.push_back(hash_value(MO));
261 }
262 }
263}
264
265// Returns a representation of an instruction as a vector of MachineOperand
266// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
267// This creates a signature of the instruction with the same content
268// that MachineOperand::isIdenticalTo uses for comparison.
269static InstrSignature instrToSignature(const MachineInstr &MI,
271 bool UseDefReg) {
272 Register DefReg;
273 InstrSignature Signature{MI.getOpcode()};
274 for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
275 // The only decorations that can be applied more than once to a given <id>
276 // or structure member are FuncParamAttr (38), UserSemantic (5635),
277 // CacheControlLoadINTEL (6442), and CacheControlStoreINTEL (6443). For all
278 // the rest of decorations, we will only add to the signature the Opcode,
279 // the id to which it applies, and the decoration id, disregarding any
280 // decoration flags. This will ensure that any subsequent decoration with
281 // the same id will be deemed as a duplicate. Then, at the call site, we
282 // will be able to handle duplicates in the best way.
283 unsigned Opcode = MI.getOpcode();
284 if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
285 unsigned DecorationID = MI.getOperand(1).getImm();
286 if (DecorationID != SPIRV::Decoration::FuncParamAttr &&
287 DecorationID != SPIRV::Decoration::UserSemantic &&
288 DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
289 DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
290 continue;
291 }
292 const MachineOperand &MO = MI.getOperand(i);
293 size_t h;
294 if (MO.isReg()) {
295 if (!UseDefReg && MO.isDef()) {
296 assert(!DefReg.isValid() && "Multiple def registers.");
297 DefReg = MO.getReg();
298 continue;
299 }
300 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
301 if (!RegAlias.isValid()) {
302 LLVM_DEBUG({
303 dbgs() << "Unexpectedly, no global id found for the operand ";
304 MO.print(dbgs());
305 dbgs() << "\nInstruction: ";
306 MI.print(dbgs());
307 dbgs() << "\n";
308 });
309 report_fatal_error("All v-regs must have been mapped to global id's");
310 }
311 // mimic llvm::hash_value(const MachineOperand &MO)
312 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
313 MO.isDef());
314 } else {
315 h = hash_value(MO);
316 }
317 Signature.push_back(h);
318 }
319
320 if (DefReg.isValid()) {
321 // Decorations change the semantics of the current instruction. So two
322 // identical instruction with different decorations cannot be merged. That
323 // is why we add the decorations to the signature.
324 appendDecorationsForReg(MI.getMF()->getRegInfo(), DefReg, Signature);
325 }
326 return Signature;
327}
328
329bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
330 const MachineInstr &MI) {
331 unsigned Opcode = MI.getOpcode();
332 switch (Opcode) {
333 case SPIRV::OpTypeForwardPointer:
334 // omit now, collect later
335 return false;
336 case SPIRV::OpVariable:
337 return static_cast<SPIRV::StorageClass::StorageClass>(
338 MI.getOperand(2).getImm()) != SPIRV::StorageClass::Function;
339 case SPIRV::OpFunction:
340 case SPIRV::OpFunctionParameter:
341 return true;
342 }
343 if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
344 Register DefReg = MI.getOperand(0).getReg();
345 for (MachineInstr &UseMI : MRI.use_instructions(DefReg)) {
346 if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
347 continue;
348 // it's a dummy definition, FP constant refers to a function,
349 // and this is resolved in another way; let's skip this definition
350 assert(UseMI.getOperand(2).isReg() &&
351 UseMI.getOperand(2).getReg() == DefReg);
352 MAI.setSkipEmission(&MI);
353 return false;
354 }
355 }
356 return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
357 TII->isInlineAsmDefInstr(MI);
358}
359
360// This is a special case of a function pointer refering to a possibly
361// forward function declaration. The operand is a dummy OpUndef that
362// requires a special treatment.
363void SPIRVModuleAnalysis::visitFunPtrUse(
364 Register OpReg, InstrGRegsMap &SignatureToGReg,
365 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
366 const MachineInstr &MI) {
367 const MachineOperand *OpFunDef =
368 GR->getFunctionDefinitionByUse(&MI.getOperand(2));
369 assert(OpFunDef && OpFunDef->isReg());
370 // find the actual function definition and number it globally in advance
371 const MachineInstr *OpDefMI = OpFunDef->getParent();
372 assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
373 const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
374 const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
375 do {
376 visitDecl(FunDefMRI, SignatureToGReg, GlobalToGReg, FunDefMF, *OpDefMI);
377 OpDefMI = OpDefMI->getNextNode();
378 } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
379 OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
380 // associate the function pointer with the newly assigned global number
381 MCRegister GlobalFunDefReg =
382 MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
383 assert(GlobalFunDefReg.isValid() &&
384 "Function definition must refer to a global register");
385 MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg);
386}
387
388// Depth first recursive traversal of dependencies. Repeated visits are guarded
389// by MAI.hasRegisterAlias().
390void SPIRVModuleAnalysis::visitDecl(
391 const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
392 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
393 const MachineInstr &MI) {
394 unsigned Opcode = MI.getOpcode();
395
396 // Process each operand of the instruction to resolve dependencies
397 for (const MachineOperand &MO : MI.operands()) {
398 if (!MO.isReg() || MO.isDef())
399 continue;
400 Register OpReg = MO.getReg();
401 // Handle function pointers special case
402 if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
403 MRI.getRegClass(OpReg) == &SPIRV::pIDRegClass) {
404 visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI);
405 continue;
406 }
407 // Skip already processed instructions
408 if (MAI.hasRegisterAlias(MF, MO.getReg()))
409 continue;
410 // Recursively visit dependencies
411 if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(OpReg)) {
412 if (isDeclSection(MRI, *OpDefMI))
413 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, *OpDefMI);
414 continue;
415 }
416 // Handle the unexpected case of no unique definition for the SPIR-V
417 // instruction
418 LLVM_DEBUG({
419 dbgs() << "Unexpectedly, no unique definition for the operand ";
420 MO.print(dbgs());
421 dbgs() << "\nInstruction: ";
422 MI.print(dbgs());
423 dbgs() << "\n";
424 });
426 "No unique definition is found for the virtual register");
427 }
428
429 MCRegister GReg;
430 bool IsFunDef = false;
431 if (TII->isSpecConstantInstr(MI)) {
432 GReg = MAI.getNextIDRegister();
433 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
434 } else if (Opcode == SPIRV::OpFunction ||
435 Opcode == SPIRV::OpFunctionParameter) {
436 GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
437 } else if (Opcode == SPIRV::OpTypeStruct ||
438 Opcode == SPIRV::OpConstantComposite) {
439 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
440 const MachineInstr *NextInstr = MI.getNextNode();
441 while (NextInstr &&
442 ((Opcode == SPIRV::OpTypeStruct &&
443 NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
444 (Opcode == SPIRV::OpConstantComposite &&
445 NextInstr->getOpcode() ==
446 SPIRV::OpConstantCompositeContinuedINTEL))) {
447 MCRegister Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
448 MAI.setRegisterAlias(MF, NextInstr->getOperand(0).getReg(), Tmp);
449 MAI.setSkipEmission(NextInstr);
450 NextInstr = NextInstr->getNextNode();
451 }
452 } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
453 TII->isInlineAsmDefInstr(MI)) {
454 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
455 } else if (Opcode == SPIRV::OpVariable) {
456 GReg = handleVariable(MF, MI, GlobalToGReg);
457 } else {
458 LLVM_DEBUG({
459 dbgs() << "\nInstruction: ";
460 MI.print(dbgs());
461 dbgs() << "\n";
462 });
463 llvm_unreachable("Unexpected instruction is visited");
464 }
465 MAI.setRegisterAlias(MF, MI.getOperand(0).getReg(), GReg);
466 if (!IsFunDef)
467 MAI.setSkipEmission(&MI);
468}
469
470MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
471 const MachineFunction *MF, const MachineInstr &MI,
472 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
473 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
474 assert(GObj && "Unregistered global definition");
475 const Function *F = dyn_cast<Function>(GObj);
476 if (!F)
477 F = dyn_cast<Argument>(GObj)->getParent();
478 assert(F && "Expected a reference to a function or an argument");
479 IsFunDef = !F->isDeclaration();
480 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
481 if (!Inserted)
482 return It->second;
483 MCRegister GReg = MAI.getNextIDRegister();
484 It->second = GReg;
485 if (!IsFunDef)
486 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI);
487 return GReg;
488}
489
491SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
492 InstrGRegsMap &SignatureToGReg) {
493 InstrSignature MISign = instrToSignature(MI, MAI, false);
494 auto [It, Inserted] = SignatureToGReg.try_emplace(MISign);
495 if (!Inserted)
496 return It->second;
497 MCRegister GReg = MAI.getNextIDRegister();
498 It->second = GReg;
499 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
500 return GReg;
501}
502
503MCRegister SPIRVModuleAnalysis::handleVariable(
504 const MachineFunction *MF, const MachineInstr &MI,
505 std::map<const Value *, unsigned> &GlobalToGReg) {
506 MAI.GlobalVarList.push_back(&MI);
507 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
508 assert(GObj && "Unregistered global definition");
509 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
510 if (!Inserted)
511 return It->second;
512 MCRegister GReg = MAI.getNextIDRegister();
513 It->second = GReg;
514 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
515 if (const auto *GV = dyn_cast<GlobalVariable>(GObj))
516 MAI.GlobalObjMap[GV] = GReg;
517 return GReg;
518}
519
// Walks every machine function and collects all declaration-section
// instructions (types, constants, globals, function headers) into the module
// sections, numbering them globally via visitDecl. OpExtension/OpCapability
// instructions are absorbed into MAI.Reqs and skipped from per-function
// emission.
void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
  InstrGRegsMap SignatureToGReg;
  std::map<const Value *, unsigned> GlobalToGReg;
  for (const Function &F : M) {
    MachineFunction *MF = MMI->getMachineFunction(F);
    if (!MF)
      continue;
    const MachineRegisterInfo &MRI = MF->getRegInfo();
    // PastHeader tracks where we are relative to this MF's own header:
    //   0 — before the function's own OpFunction,
    //   1 — the function's OpFunction was seen (its parameters follow),
    //   2 — past the header (presumably any further OpFunction /
    //       OpFunctionParameter belong to other declarations and must be
    //       processed — TODO confirm against upstream intent).
    unsigned PastHeader = 0;
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getNumOperands() == 0)
          continue;
        unsigned Opcode = MI.getOpcode();
        if (Opcode == SPIRV::OpFunction) {
          // The first OpFunction is this MF's own header: skip it here.
          if (PastHeader == 0) {
            PastHeader = 1;
            continue;
          }
        } else if (Opcode == SPIRV::OpFunctionParameter) {
          // Parameters of the function's own header are skipped too.
          if (PastHeader < 2)
            continue;
        } else if (PastHeader > 0) {
          // Any other instruction after the header ends the header region.
          PastHeader = 2;
        }

        const MachineOperand &DefMO = MI.getOperand(0);
        switch (Opcode) {
        case SPIRV::OpExtension:
          // Extension requirements are aggregated module-wide; do not emit
          // the per-function instruction.
          MAI.Reqs.addExtension(SPIRV::Extension::Extension(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          break;
        case SPIRV::OpCapability:
          MAI.Reqs.addCapability(SPIRV::Capability::Capability(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          if (PastHeader > 0)
            PastHeader = 2;
          break;
        default:
          // Only visit declaration-section defs not yet numbered globally.
          if (DefMO.isReg() && isDeclSection(MRI, MI) &&
              !MAI.hasRegisterAlias(MF, DefMO.getReg()))
            visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
        }
      }
    }
  }
}
567
568// Look for IDs declared with Import linkage, and map the corresponding function
569// to the register defining that variable (which will usually be the result of
570// an OpFunction). This lets us call externally imported functions using
571// the correct ID registers.
572void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
573 const Function *F) {
574 if (MI.getOpcode() == SPIRV::OpDecorate) {
575 // If it's got Import linkage.
576 auto Dec = MI.getOperand(1).getImm();
577 if (Dec == SPIRV::Decoration::LinkageAttributes) {
578 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
579 if (Lnk == SPIRV::LinkageType::Import) {
580 // Map imported function name to function ID register.
581 const Function *ImportedFunc =
582 F->getParent()->getFunction(getStringImm(MI, 2));
583 Register Target = MI.getOperand(0).getReg();
584 MAI.GlobalObjMap[ImportedFunc] =
585 MAI.getRegisterAlias(MI.getMF(), Target);
586 }
587 }
588 } else if (MI.getOpcode() == SPIRV::OpFunction) {
589 // Record all internal OpFunction declarations.
590 Register Reg = MI.defs().begin()->getReg();
591 MCRegister GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
592 assert(GlobalReg.isValid());
593 MAI.GlobalObjMap[F] = GlobalReg;
594 }
595}
596
597// Collect the given instruction in the specified MS. We assume global register
598// numbering has already occurred by this point. We can directly compare reg
599// arguments when detecting duplicates.
600static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
602 bool Append = true) {
603 MAI.setSkipEmission(&MI);
604 InstrSignature MISign = instrToSignature(MI, MAI, true);
605 auto FoundMI = IS.insert(std::move(MISign));
606 if (!FoundMI.second) {
607 if (MI.getOpcode() == SPIRV::OpDecorate) {
608 assert(MI.getNumOperands() >= 2 &&
609 "Decoration instructions must have at least 2 operands");
610 assert(MSType == SPIRV::MB_Annotations &&
611 "Only OpDecorate instructions can be duplicates");
612 // For FPFastMathMode decoration, we need to merge the flags of the
613 // duplicate decoration with the original one, so we need to find the
614 // original instruction that has the same signature. For the rest of
615 // instructions, we will simply skip the duplicate.
616 if (MI.getOperand(1).getImm() != SPIRV::Decoration::FPFastMathMode)
617 return; // Skip duplicates of other decorations.
618
619 const SPIRV::InstrList &Decorations = MAI.MS[MSType];
620 for (const MachineInstr *OrigMI : Decorations) {
621 if (instrToSignature(*OrigMI, MAI, true) == MISign) {
622 assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
623 "Original instruction must have the same number of operands");
624 assert(
625 OrigMI->getNumOperands() == 3 &&
626 "FPFastMathMode decoration must have 3 operands for OpDecorate");
627 unsigned OrigFlags = OrigMI->getOperand(2).getImm();
628 unsigned NewFlags = MI.getOperand(2).getImm();
629 if (OrigFlags == NewFlags)
630 return; // No need to merge, the flags are the same.
631
632 // Emit warning about possible conflict between flags.
633 unsigned FinalFlags = OrigFlags | NewFlags;
634 llvm::errs()
635 << "Warning: Conflicting FPFastMathMode decoration flags "
636 "in instruction: "
637 << *OrigMI << "Original flags: " << OrigFlags
638 << ", new flags: " << NewFlags
639 << ". They will be merged on a best effort basis, but not "
640 "validated. Final flags: "
641 << FinalFlags << "\n";
642 MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
643 MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2);
644 OrigFlagsOp = MachineOperand::CreateImm(FinalFlags);
645 return; // Merge done, so we found a duplicate; don't add it to MAI.MS
646 }
647 }
648 assert(false && "No original instruction found for the duplicate "
649 "OpDecorate, but we found one in IS.");
650 }
651 return; // insert failed, so we found a duplicate; don't add it to MAI.MS
652 }
653 // No duplicates, so add it.
654 if (Append)
655 MAI.MS[MSType].push_back(&MI);
656 else
657 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
658}
659
660// Some global instructions make reference to function-local ID regs, so cannot
661// be correctly collected until these registers are globally numbered.
662void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
663 InstrTraces IS;
664 for (const Function &F : M) {
665 if (F.isDeclaration())
666 continue;
667 MachineFunction *MF = MMI->getMachineFunction(F);
668 assert(MF);
669
670 for (MachineBasicBlock &MBB : *MF)
671 for (MachineInstr &MI : MBB) {
672 if (MAI.getSkipEmission(&MI))
673 continue;
674 const unsigned OpCode = MI.getOpcode();
675 if (OpCode == SPIRV::OpString) {
676 collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
677 } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
678 MI.getOperand(2).getImm() ==
679 SPIRV::InstructionSet::
680 NonSemantic_Shader_DebugInfo_100) {
681 // TODO: This branch is dead. SPIRVNonSemanticDebugHandler emits NSDI
682 // instructions directly as MCInsts at print time; no
683 // MachineInstructions with the NSDI ext set are created anymore.
684 // Remove this block and
685 // MB_NonSemanticGlobalDI once per-function NSDI emission is confirmed
686 // not to need MIR routing.
687 MachineOperand Ins = MI.getOperand(3);
688 namespace NS = SPIRV::NonSemanticExtInst;
689 static constexpr int64_t GlobalNonSemanticDITy[] = {
690 NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
691 NS::DebugTypeBasic, NS::DebugTypePointer};
692 bool IsGlobalDI = false;
693 for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
694 IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
695 if (IsGlobalDI)
696 collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
697 } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
698 collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
699 } else if (OpCode == SPIRV::OpEntryPoint) {
700 collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
701 } else if (TII->isAliasingInstr(MI)) {
702 collectOtherInstr(MI, MAI, SPIRV::MB_AliasingInsts, IS);
703 } else if (TII->isDecorationInstr(MI)) {
704 collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
705 collectFuncNames(MI, &F);
706 } else if (TII->isConstantInstr(MI)) {
707 // Now OpSpecConstant*s are not in DT,
708 // but they need to be collected anyway.
709 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
710 } else if (OpCode == SPIRV::OpFunction) {
711 collectFuncNames(MI, &F);
712 } else if (OpCode == SPIRV::OpTypeForwardPointer) {
713 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
714 }
715 }
716 }
717}
718
719// Number registers in all functions globally from 0 onwards and store
720// the result in global register alias table. Some registers are already
721// numbered.
722void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
723 for (const Function &F : M) {
724 if (F.isDeclaration())
725 continue;
726 MachineFunction *MF = MMI->getMachineFunction(F);
727 assert(MF);
728 for (MachineBasicBlock &MBB : *MF) {
729 for (MachineInstr &MI : MBB) {
730 for (MachineOperand &Op : MI.operands()) {
731 if (!Op.isReg())
732 continue;
733 Register Reg = Op.getReg();
734 if (MAI.hasRegisterAlias(MF, Reg))
735 continue;
736 MCRegister NewReg = MAI.getNextIDRegister();
737 MAI.setRegisterAlias(MF, Reg, NewReg);
738 }
739 if (MI.getOpcode() != SPIRV::OpExtInst)
740 continue;
741 auto Set = MI.getOperand(2).getImm();
742 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Set);
743 if (Inserted)
744 It->second = MAI.getNextIDRegister();
745 }
746 }
747 }
748}
749
750// RequirementHandler implementations.
752 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
753 const SPIRVSubtarget &ST) {
754 addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
755}
756
757void SPIRV::RequirementHandler::recursiveAddCapabilities(
758 const CapabilityList &ToPrune) {
759 for (const auto &Cap : ToPrune) {
760 AllCaps.insert(Cap);
761 CapabilityList ImplicitDecls =
762 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
763 recursiveAddCapabilities(ImplicitDecls);
764 }
765}
766
768 for (const auto &Cap : ToAdd) {
769 bool IsNewlyInserted = AllCaps.insert(Cap).second;
770 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
771 continue;
772 CapabilityList ImplicitDecls =
773 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
774 recursiveAddCapabilities(ImplicitDecls);
775 MinimalCaps.push_back(Cap);
776 }
777}
778
780 const SPIRV::Requirements &Req) {
781 if (!Req.IsSatisfiable)
782 report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
783
784 if (Req.Cap.has_value())
785 addCapabilities({Req.Cap.value()});
786
787 addExtensions(Req.Exts);
788
789 if (!Req.MinVer.empty()) {
790 if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
791 LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
792 << " and <= " << MaxVersion << "\n");
793 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
794 }
795
796 if (MinVersion.empty() || Req.MinVer > MinVersion)
797 MinVersion = Req.MinVer;
798 }
799
800 if (!Req.MaxVer.empty()) {
801 if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
802 LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
803 << " and >= " << MinVersion << "\n");
804 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
805 }
806
807 if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
808 MaxVersion = Req.MaxVer;
809 }
810}
811
813 const SPIRVSubtarget &ST) const {
814 // Report as many errors as possible before aborting the compilation.
815 bool IsSatisfiable = true;
816 auto TargetVer = ST.getSPIRVVersion();
817
818 if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
820 dbgs() << "Target SPIR-V version too high for required features\n"
821 << "Required max version: " << MaxVersion << " target version "
822 << TargetVer << "\n");
823 IsSatisfiable = false;
824 }
825
826 if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
827 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
828 << "Required min version: " << MinVersion
829 << " target version " << TargetVer << "\n");
830 IsSatisfiable = false;
831 }
832
833 if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
835 dbgs()
836 << "Version is too low for some features and too high for others.\n"
837 << "Required SPIR-V min version: " << MinVersion
838 << " required SPIR-V max version " << MaxVersion << "\n");
839 IsSatisfiable = false;
840 }
841
842 AvoidCapabilitiesSet AvoidCaps;
843 if (!ST.isShader())
844 AvoidCaps.S.insert(SPIRV::Capability::Shader);
845 else
846 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
847
848 for (auto Cap : MinimalCaps) {
849 if (AvailableCaps.contains(Cap) && !AvoidCaps.S.contains(Cap))
850 continue;
851 LLVM_DEBUG(dbgs() << "Capability not supported: "
853 OperandCategory::CapabilityOperand, Cap)
854 << "\n");
855 IsSatisfiable = false;
856 }
857
858 for (auto Ext : AllExtensions) {
859 if (ST.canUseExtension(Ext))
860 continue;
861 LLVM_DEBUG(dbgs() << "Extension not supported: "
863 OperandCategory::ExtensionOperand, Ext)
864 << "\n");
865 IsSatisfiable = false;
866 }
867
868 if (!IsSatisfiable)
869 report_fatal_error("Unable to meet SPIR-V requirements for this target.");
870}
871
872// Add the given capabilities and all their implicitly defined capabilities too.
874 for (const auto Cap : ToAdd)
875 if (AvailableCaps.insert(Cap).second)
876 addAvailableCaps(getSymbolicOperandCapabilities(
877 SPIRV::OperandCategory::CapabilityOperand, Cap));
878}
879
881 const Capability::Capability ToRemove,
882 const Capability::Capability IfPresent) {
883 if (AllCaps.contains(IfPresent))
884 AllCaps.erase(ToRemove);
885}
886
887namespace llvm {
888namespace SPIRV {
889void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
890 // Provided by both all supported Vulkan versions and OpenCl.
891 addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
892 Capability::Int16});
893
894 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
895 addAvailableCaps({Capability::GroupNonUniform,
896 Capability::GroupNonUniformVote,
897 Capability::GroupNonUniformArithmetic,
898 Capability::GroupNonUniformBallot,
899 Capability::GroupNonUniformClustered,
900 Capability::GroupNonUniformShuffle,
901 Capability::GroupNonUniformShuffleRelative,
902 Capability::GroupNonUniformQuad});
903
904 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
905 addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
906 Capability::DotProductInput4x8Bit,
907 Capability::DotProductInput4x8BitPacked,
908 Capability::DemoteToHelperInvocation});
909
910 // Add capabilities enabled by extensions.
911 for (auto Extension : ST.getAllAvailableExtensions()) {
912 CapabilityList EnabledCapabilities =
914 addAvailableCaps(EnabledCapabilities);
915 }
916
917 if (!ST.isShader()) {
918 initAvailableCapabilitiesForOpenCL(ST);
919 return;
920 }
921
922 if (ST.isShader()) {
923 initAvailableCapabilitiesForVulkan(ST);
924 return;
925 }
926
927 report_fatal_error("Unimplemented environment for SPIR-V generation.");
928}
929
930void RequirementHandler::initAvailableCapabilitiesForOpenCL(
931 const SPIRVSubtarget &ST) {
932 // Add the min requirements for different OpenCL and SPIR-V versions.
933 addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
934 Capability::Kernel, Capability::Vector16,
935 Capability::Groups, Capability::GenericPointer,
936 Capability::StorageImageWriteWithoutFormat,
937 Capability::StorageImageReadWithoutFormat});
938 if (ST.hasOpenCLFullProfile())
939 addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
940 if (ST.hasOpenCLImageSupport()) {
941 addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
942 Capability::Image1D, Capability::SampledBuffer,
943 Capability::ImageBuffer});
944 if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
945 addAvailableCaps({Capability::ImageReadWrite});
946 }
947 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
948 ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
949 addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
950 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
951 addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
952 Capability::SignedZeroInfNanPreserve,
953 Capability::RoundingModeRTE,
954 Capability::RoundingModeRTZ});
955 // TODO: verify if this needs some checks.
956 addAvailableCaps({Capability::Float16, Capability::Float64});
957
958 // TODO: add OpenCL extensions.
959}
960
961void RequirementHandler::initAvailableCapabilitiesForVulkan(
962 const SPIRVSubtarget &ST) {
963
964 // Core in Vulkan 1.1 and earlier.
965 addAvailableCaps({Capability::Int64,
966 Capability::Float16,
967 Capability::Float64,
968 Capability::GroupNonUniform,
969 Capability::Image1D,
970 Capability::SampledBuffer,
971 Capability::ImageBuffer,
972 Capability::UniformBufferArrayDynamicIndexing,
973 Capability::SampledImageArrayDynamicIndexing,
974 Capability::StorageBufferArrayDynamicIndexing,
975 Capability::StorageImageArrayDynamicIndexing,
976 Capability::DerivativeControl,
977 Capability::MinLod,
978 Capability::ImageQuery,
979 Capability::ImageGatherExtended,
980 Capability::Addresses,
981 Capability::VulkanMemoryModelKHR,
982 Capability::StorageImageExtendedFormats,
983 Capability::StorageImageMultisample,
984 Capability::ImageMSArray});
985
986 // Became core in Vulkan 1.2
987 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
989 {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
990 Capability::InputAttachmentArrayDynamicIndexingEXT,
991 Capability::UniformTexelBufferArrayDynamicIndexingEXT,
992 Capability::StorageTexelBufferArrayDynamicIndexingEXT,
993 Capability::UniformBufferArrayNonUniformIndexingEXT,
994 Capability::SampledImageArrayNonUniformIndexingEXT,
995 Capability::StorageBufferArrayNonUniformIndexingEXT,
996 Capability::StorageImageArrayNonUniformIndexingEXT,
997 Capability::InputAttachmentArrayNonUniformIndexingEXT,
998 Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
999 Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
1000 }
1001
1002 // Became core in Vulkan 1.3
1003 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
1004 addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
1005 Capability::StorageImageReadWithoutFormat});
1006}
1007
1008} // namespace SPIRV
1009} // namespace llvm
1010
1011// Add the required capabilities from a decoration instruction (including
1012// BuiltIns).
1013static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
1015 const SPIRVSubtarget &ST) {
1016 int64_t DecOp = MI.getOperand(DecIndex).getImm();
1017 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
1018 Reqs.addRequirements(getSymbolicOperandRequirements(
1019 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
1020
1021 if (Dec == SPIRV::Decoration::BuiltIn) {
1022 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
1023 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
1024 Reqs.addRequirements(getSymbolicOperandRequirements(
1025 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
1026 } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
1027 int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
1028 SPIRV::LinkageType::LinkageType LnkType =
1029 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
1030 if (LnkType == SPIRV::LinkageType::LinkOnceODR)
1031 Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
1032 else if (LnkType == SPIRV::LinkageType::Weak)
1033 Reqs.addExtension(SPIRV::Extension::SPV_AMD_weak_linkage);
1034 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
1035 Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
1036 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
1037 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
1038 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
1039 } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
1040 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
1041 Reqs.addExtension(
1042 SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
1043 } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
1044 Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT);
1045 } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
1046 Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL);
1047 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error);
1048 } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
1049 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
1050 Reqs.addRequirements(SPIRV::Capability::FloatControls2);
1051 Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
1052 }
1053 }
1054}
1055
1056// Add requirements for image handling.
1057static void addOpTypeImageReqs(const MachineInstr &MI,
1059 const SPIRVSubtarget &ST) {
1060 assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
1061 // The operand indices used here are based on the OpTypeImage layout, which
1062 // the MachineInstr follows as well.
1063 int64_t ImgFormatOp = MI.getOperand(7).getImm();
1064 auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
1065 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
1066 ImgFormat, ST);
1067
1068 bool IsArrayed = MI.getOperand(4).getImm() == 1;
1069 bool IsMultisampled = MI.getOperand(5).getImm() == 1;
1070 bool NoSampler = MI.getOperand(6).getImm() == 2;
1071 // Add dimension requirements.
1072 assert(MI.getOperand(2).isImm());
1073 switch (MI.getOperand(2).getImm()) {
1074 case SPIRV::Dim::DIM_1D:
1075 Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
1076 : SPIRV::Capability::Sampled1D);
1077 break;
1078 case SPIRV::Dim::DIM_2D:
1079 if (IsMultisampled && NoSampler)
1080 Reqs.addRequirements(SPIRV::Capability::StorageImageMultisample);
1081 if (IsMultisampled && IsArrayed)
1082 Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
1083 break;
1084 case SPIRV::Dim::DIM_3D:
1085 break;
1086 case SPIRV::Dim::DIM_Cube:
1087 Reqs.addRequirements(SPIRV::Capability::Shader);
1088 if (IsArrayed)
1089 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
1090 : SPIRV::Capability::SampledCubeArray);
1091 break;
1092 case SPIRV::Dim::DIM_Rect:
1093 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
1094 : SPIRV::Capability::SampledRect);
1095 break;
1096 case SPIRV::Dim::DIM_Buffer:
1097 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
1098 : SPIRV::Capability::SampledBuffer);
1099 break;
1100 case SPIRV::Dim::DIM_SubpassData:
1101 Reqs.addRequirements(SPIRV::Capability::InputAttachment);
1102 break;
1103 }
1104
1105 // Check if the sampled type is a 64-bit integer, which requires
1106 // Int64ImageEXT capability.
1107 assert(MI.getOperand(1).isReg());
1108 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1109 SPIRVTypeInst SampledTypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1110 if (SampledTypeDef.isTypeIntN(64)) {
1111 Reqs.addCapability(SPIRV::Capability::Int64ImageEXT);
1112 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_image_int64);
1113 }
1114
1115 // Has optional access qualifier.
1116 if (!ST.isShader()) {
1117 if (MI.getNumOperands() > 8 &&
1118 MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
1119 Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
1120 else
1121 Reqs.addRequirements(SPIRV::Capability::ImageBasic);
1122 }
1123}
1124
1125static bool isBFloat16Type(SPIRVTypeInst TypeDef) {
1126 return TypeDef && TypeDef->getNumOperands() == 3 &&
1127 TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1128 TypeDef->getOperand(1).getImm() == 16 &&
1129 TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1130}
1131
// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
  "The atomic float instruction requires the following SPIR-V "               \
  "extension: SPV_EXT_shader_atomic_float" ExtName
1136static void AddAtomicVectorFloatRequirements(const MachineInstr &MI,
1138 const SPIRVSubtarget &ST) {
1139 SPIRVTypeInst VecTypeDef =
1140 MI.getMF()->getRegInfo().getVRegDef(MI.getOperand(1).getReg());
1141
1142 const unsigned Rank = VecTypeDef->getOperand(2).getImm();
1143 if (Rank != 2 && Rank != 4)
1144 reportFatalUsageError("Result type of an atomic vector float instruction "
1145 "must be a 2-component or 4 component vector");
1146
1147 SPIRVTypeInst EltTypeDef =
1148 MI.getMF()->getRegInfo().getVRegDef(VecTypeDef->getOperand(1).getReg());
1149
1150 if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat ||
1151 EltTypeDef->getOperand(1).getImm() != 16)
1153 "The element type for the result type of an atomic vector float "
1154 "instruction must be a 16-bit floating-point scalar");
1155
1156 if (isBFloat16Type(EltTypeDef))
1158 "The element type for the result type of an atomic vector float "
1159 "instruction cannot be a bfloat16 scalar");
1160 if (!ST.canUseExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector))
1162 "The atomic float16 vector instruction requires the following SPIR-V "
1163 "extension: SPV_NV_shader_atomic_fp16_vector");
1164
1165 Reqs.addExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector);
1166 Reqs.addCapability(SPIRV::Capability::AtomicFloat16VectorNV);
1167}
1168
1169static void AddAtomicFloatRequirements(const MachineInstr &MI,
1171 const SPIRVSubtarget &ST) {
1172 assert(MI.getOperand(1).isReg() &&
1173 "Expect register operand in atomic float instruction");
1174 Register TypeReg = MI.getOperand(1).getReg();
1175 SPIRVTypeInst TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
1176
1177 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
1178 return AddAtomicVectorFloatRequirements(MI, Reqs, ST);
1179
1180 if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
1181 report_fatal_error("Result type of an atomic float instruction must be a "
1182 "floating-point type scalar");
1183
1184 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1185 unsigned Op = MI.getOpcode();
1186 if (Op == SPIRV::OpAtomicFAddEXT) {
1187 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
1189 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
1190 switch (BitWidth) {
1191 case 16:
1192 if (isBFloat16Type(TypeDef)) {
1193 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1195 "The atomic bfloat16 instruction requires the following SPIR-V "
1196 "extension: SPV_INTEL_16bit_atomics",
1197 false);
1198 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1199 Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL);
1200 } else {
1201 if (!ST.canUseExtension(
1202 SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1203 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
1204 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1205 Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
1206 }
1207 break;
1208 case 32:
1209 Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
1210 break;
1211 case 64:
1212 Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
1213 break;
1214 default:
1216 "Unexpected floating-point type width in atomic float instruction");
1217 }
1218 } else {
1219 if (!ST.canUseExtension(
1220 SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
1221 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
1222 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
1223 switch (BitWidth) {
1224 case 16:
1225 if (isBFloat16Type(TypeDef)) {
1226 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1228 "The atomic bfloat16 instruction requires the following SPIR-V "
1229 "extension: SPV_INTEL_16bit_atomics",
1230 false);
1231 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1232 Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
1233 } else {
1234 Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
1235 }
1236 break;
1237 case 32:
1238 Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
1239 break;
1240 case 64:
1241 Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
1242 break;
1243 default:
1245 "Unexpected floating-point type width in atomic float instruction");
1246 }
1247 }
1248}
1249
1250bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1251 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1252 return false;
1253 uint32_t Dim = ImageInst->getOperand(2).getImm();
1254 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1255 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1256}
1257
1258bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1259 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1260 return false;
1261 uint32_t Dim = ImageInst->getOperand(2).getImm();
1262 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1263 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1264}
1265
1266bool isSampledImage(MachineInstr *ImageInst) {
1267 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1268 return false;
1269 uint32_t Dim = ImageInst->getOperand(2).getImm();
1270 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1271 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1272}
1273
1274bool isInputAttachment(MachineInstr *ImageInst) {
1275 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1276 return false;
1277 uint32_t Dim = ImageInst->getOperand(2).getImm();
1278 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1279 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1280}
1281
1282bool isStorageImage(MachineInstr *ImageInst) {
1283 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1284 return false;
1285 uint32_t Dim = ImageInst->getOperand(2).getImm();
1286 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1287 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1288}
1289
1290bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1291 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1292 return false;
1293
1294 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1295 Register ImageReg = SampledImageInst->getOperand(1).getReg();
1296 auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
1297 return isSampledImage(ImageInst);
1298}
1299
1300bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1301 for (const auto &MI : MRI.reg_instructions(Reg)) {
1302 if (MI.getOpcode() != SPIRV::OpDecorate)
1303 continue;
1304
1305 uint32_t Dec = MI.getOperand(1).getImm();
1306 if (Dec == SPIRV::Decoration::NonUniformEXT)
1307 return true;
1308 }
1309 return false;
1310}
1311
1312void addOpAccessChainReqs(const MachineInstr &Instr,
1314 const SPIRVSubtarget &Subtarget) {
1315 const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
1316 // Get the result type. If it is an image type, then the shader uses
1317 // descriptor indexing. The appropriate capabilities will be added based
1318 // on the specifics of the image.
1319 Register ResTypeReg = Instr.getOperand(1).getReg();
1320 MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);
1321
1322 assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
1323 uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
1324 if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
1325 StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
1326 StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
1327 return;
1328 }
1329
1330 bool IsNonUniform =
1331 hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
1332
1333 auto FirstIndexReg = Instr.getOperand(3).getReg();
1334 bool FirstIndexIsConstant =
1335 Subtarget.getInstrInfo()->isConstantInstr(*MRI.getVRegDef(FirstIndexReg));
1336
1337 if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
1338 if (IsNonUniform)
1339 Handler.addRequirements(
1340 SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
1341 else if (!FirstIndexIsConstant)
1342 Handler.addRequirements(
1343 SPIRV::Capability::StorageBufferArrayDynamicIndexing);
1344 return;
1345 }
1346
1347 Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
1348 MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
1349 if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
1350 PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
1351 PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
1352 return;
1353 }
1354
1355 if (isUniformTexelBuffer(PointeeType)) {
1356 if (IsNonUniform)
1357 Handler.addRequirements(
1358 SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
1359 else if (!FirstIndexIsConstant)
1360 Handler.addRequirements(
1361 SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
1362 } else if (isInputAttachment(PointeeType)) {
1363 if (IsNonUniform)
1364 Handler.addRequirements(
1365 SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
1366 else if (!FirstIndexIsConstant)
1367 Handler.addRequirements(
1368 SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
1369 } else if (isStorageTexelBuffer(PointeeType)) {
1370 if (IsNonUniform)
1371 Handler.addRequirements(
1372 SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
1373 else if (!FirstIndexIsConstant)
1374 Handler.addRequirements(
1375 SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
1376 } else if (isSampledImage(PointeeType) ||
1377 isCombinedImageSampler(PointeeType) ||
1378 PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
1379 if (IsNonUniform)
1380 Handler.addRequirements(
1381 SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
1382 else if (!FirstIndexIsConstant)
1383 Handler.addRequirements(
1384 SPIRV::Capability::SampledImageArrayDynamicIndexing);
1385 } else if (isStorageImage(PointeeType)) {
1386 if (IsNonUniform)
1387 Handler.addRequirements(
1388 SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
1389 else if (!FirstIndexIsConstant)
1390 Handler.addRequirements(
1391 SPIRV::Capability::StorageImageArrayDynamicIndexing);
1392 }
1393}
1394
1395static bool isImageTypeWithUnknownFormat(SPIRVTypeInst TypeInst) {
1396 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1397 return false;
1398 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1399 return TypeInst->getOperand(7).getImm() == 0;
1400}
1401
1402static void AddDotProductRequirements(const MachineInstr &MI,
1404 const SPIRVSubtarget &ST) {
1405 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1406 Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1407 Reqs.addCapability(SPIRV::Capability::DotProduct);
1408
1409 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1410 assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1411 // We do not consider what the previous instruction is. This is just used
1412 // to get the input register and to check the type.
1413 const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
1414 assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
1415 Register InputReg = Input->getOperand(1).getReg();
1416
1417 SPIRVTypeInst TypeDef = MRI.getVRegDef(InputReg);
1418 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1419 assert(TypeDef->getOperand(1).getImm() == 32);
1420 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
1421 } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1422 SPIRVTypeInst ScalarTypeDef =
1423 MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1424 assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1425 if (ScalarTypeDef->getOperand(1).getImm() == 8) {
1426 assert(TypeDef->getOperand(2).getImm() == 4 &&
1427 "Dot operand of 8-bit integer type requires 4 components");
1428 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
1429 } else {
1430 Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
1431 }
1432 }
1433}
1434
1435void addPrintfRequirements(const MachineInstr &MI,
1437 const SPIRVSubtarget &ST) {
1438 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1439 SPIRVTypeInst PtrType = GR->getSPIRVTypeForVReg(MI.getOperand(4).getReg());
1440 if (PtrType) {
1441 MachineOperand ASOp = PtrType->getOperand(1);
1442 if (ASOp.isImm()) {
1443 unsigned AddrSpace = ASOp.getImm();
1444 if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
1445 if (!ST.canUseExtension(
1447 SPV_EXT_relaxed_printf_string_address_space)) {
1448 report_fatal_error("SPV_EXT_relaxed_printf_string_address_space is "
1449 "required because printf uses a format string not "
1450 "in constant address space.",
1451 false);
1452 }
1453 Reqs.addExtension(
1454 SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
1455 }
1456 }
1457 }
1458}
1459
1460static void addImageOperandReqs(const MachineInstr &MI,
1462 const SPIRVSubtarget &ST, unsigned OpIdx) {
1463 if (MI.getNumOperands() <= OpIdx)
1464 return;
1465 uint32_t Mask = MI.getOperand(OpIdx).getImm();
1466 for (uint32_t I = 0; I < 32; ++I)
1467 if (Mask & (1U << I))
1468 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageOperandOperand,
1469 1U << I, ST);
1470}
1471
1472void addInstrRequirements(const MachineInstr &MI,
1474 const SPIRVSubtarget &ST) {
1475 SPIRV::RequirementHandler &Reqs = MAI.Reqs;
1476 unsigned Op = MI.getOpcode();
1477 switch (Op) {
1478 case SPIRV::OpMemoryModel: {
1479 int64_t Addr = MI.getOperand(0).getImm();
1480 Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
1481 Addr, ST);
1482 int64_t Mem = MI.getOperand(1).getImm();
1483 Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
1484 ST);
1485 break;
1486 }
1487 case SPIRV::OpEntryPoint: {
1488 int64_t Exe = MI.getOperand(0).getImm();
1489 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
1490 Exe, ST);
1491 break;
1492 }
1493 case SPIRV::OpExecutionMode:
1494 case SPIRV::OpExecutionModeId: {
1495 int64_t Exe = MI.getOperand(1).getImm();
1496 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
1497 Exe, ST);
1498 break;
1499 }
1500 case SPIRV::OpTypeMatrix:
1501 Reqs.addCapability(SPIRV::Capability::Matrix);
1502 break;
1503 case SPIRV::OpTypeInt: {
1504 unsigned BitWidth = MI.getOperand(1).getImm();
1505 if (BitWidth == 64)
1506 Reqs.addCapability(SPIRV::Capability::Int64);
1507 else if (BitWidth == 16)
1508 Reqs.addCapability(SPIRV::Capability::Int16);
1509 else if (BitWidth == 8)
1510 Reqs.addCapability(SPIRV::Capability::Int8);
1511 else if (BitWidth == 4 &&
1512 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
1513 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_int4);
1514 Reqs.addCapability(SPIRV::Capability::Int4TypeINTEL);
1515 } else if (BitWidth != 32) {
1516 if (!ST.canUseExtension(
1517 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
1519 "OpTypeInt type with a width other than 8, 16, 32 or 64 bits "
1520 "requires the following SPIR-V extension: "
1521 "SPV_ALTERA_arbitrary_precision_integers");
1522 Reqs.addExtension(
1523 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
1524 Reqs.addCapability(SPIRV::Capability::ArbitraryPrecisionIntegersALTERA);
1525 }
1526 break;
1527 }
1528 case SPIRV::OpDot: {
1529 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1530 SPIRVTypeInst TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1531 if (isBFloat16Type(TypeDef))
1532 Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
1533 break;
1534 }
1535 case SPIRV::OpTypeFloat: {
1536 unsigned BitWidth = MI.getOperand(1).getImm();
1537 if (BitWidth == 64)
1538 Reqs.addCapability(SPIRV::Capability::Float64);
1539 else if (BitWidth == 16) {
1540 if (isBFloat16Type(&MI)) {
1541 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
1542 report_fatal_error("OpTypeFloat type with bfloat requires the "
1543 "following SPIR-V extension: SPV_KHR_bfloat16",
1544 false);
1545 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
1546 Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
1547 } else {
1548 Reqs.addCapability(SPIRV::Capability::Float16);
1549 }
1550 }
1551 break;
1552 }
1553 case SPIRV::OpTypeVector: {
1554 unsigned NumComponents = MI.getOperand(2).getImm();
1555 if (NumComponents == 8 || NumComponents == 16)
1556 Reqs.addCapability(SPIRV::Capability::Vector16);
1557
1558 assert(MI.getOperand(1).isReg());
1559 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1560 SPIRVTypeInst ElemTypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1561 if (ElemTypeDef->getOpcode() == SPIRV::OpTypePointer &&
1562 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
1563 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
1564 Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
1565 }
1566 break;
1567 }
1568 case SPIRV::OpTypePointer: {
1569 auto SC = MI.getOperand(1).getImm();
1570 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
1571 ST);
1572 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1573 // capability.
1574 if (ST.isShader())
1575 break;
1576 assert(MI.getOperand(2).isReg());
1577 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1578 SPIRVTypeInst TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1579 if ((TypeDef->getNumOperands() == 2) &&
1580 (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1581 (TypeDef->getOperand(1).getImm() == 16))
1582 Reqs.addCapability(SPIRV::Capability::Float16Buffer);
1583 break;
1584 }
1585 case SPIRV::OpExtInst: {
1586 if (MI.getOperand(2).getImm() ==
1587 static_cast<int64_t>(
1588 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1589 Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
1590 break;
1591 }
1592 if (MI.getOperand(3).getImm() ==
1593 static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
1594 addPrintfRequirements(MI, Reqs, ST);
1595 break;
1596 }
1597 // TODO: handle bfloat16 extended instructions when
1598 // SPV_INTEL_bfloat16_arithmetic is enabled.
1599 break;
1600 }
1601 case SPIRV::OpAliasDomainDeclINTEL:
1602 case SPIRV::OpAliasScopeDeclINTEL:
1603 case SPIRV::OpAliasScopeListDeclINTEL: {
1604 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1605 Reqs.addCapability(SPIRV::Capability::MemoryAccessAliasingINTEL);
1606 break;
1607 }
1608 case SPIRV::OpBitReverse:
1609 case SPIRV::OpBitFieldInsert:
1610 case SPIRV::OpBitFieldSExtract:
1611 case SPIRV::OpBitFieldUExtract:
1612 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
1613 Reqs.addCapability(SPIRV::Capability::Shader);
1614 break;
1615 }
1616 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
1617 Reqs.addCapability(SPIRV::Capability::BitInstructions);
1618 break;
1619 case SPIRV::OpTypeRuntimeArray:
1620 Reqs.addCapability(SPIRV::Capability::Shader);
1621 break;
1622 case SPIRV::OpTypeOpaque:
1623 case SPIRV::OpTypeEvent:
1624 Reqs.addCapability(SPIRV::Capability::Kernel);
1625 break;
1626 case SPIRV::OpTypePipe:
1627 case SPIRV::OpTypeReserveId:
1628 Reqs.addCapability(SPIRV::Capability::Pipes);
1629 break;
1630 case SPIRV::OpTypeDeviceEvent:
1631 case SPIRV::OpTypeQueue:
1632 case SPIRV::OpBuildNDRange:
1633 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
1634 break;
1635 case SPIRV::OpDecorate:
1636 case SPIRV::OpDecorateId:
1637 case SPIRV::OpDecorateString:
1638 addOpDecorateReqs(MI, 1, Reqs, ST);
1639 break;
1640 case SPIRV::OpMemberDecorate:
1641 case SPIRV::OpMemberDecorateString:
1642 addOpDecorateReqs(MI, 2, Reqs, ST);
1643 break;
1644 case SPIRV::OpInBoundsPtrAccessChain:
1645 Reqs.addCapability(SPIRV::Capability::Addresses);
1646 break;
1647 case SPIRV::OpConstantSampler:
1648 Reqs.addCapability(SPIRV::Capability::LiteralSampler);
1649 break;
1650 case SPIRV::OpInBoundsAccessChain:
1651 case SPIRV::OpAccessChain:
1652 addOpAccessChainReqs(MI, Reqs, ST);
1653 break;
1654 case SPIRV::OpTypeImage:
1655 addOpTypeImageReqs(MI, Reqs, ST);
1656 break;
1657 case SPIRV::OpTypeSampler:
1658 if (!ST.isShader()) {
1659 Reqs.addCapability(SPIRV::Capability::ImageBasic);
1660 }
1661 break;
1662 case SPIRV::OpTypeForwardPointer:
1663 // TODO: check if it's OpenCL's kernel.
1664 Reqs.addCapability(SPIRV::Capability::Addresses);
1665 break;
1666 case SPIRV::OpAtomicFlagTestAndSet:
1667 case SPIRV::OpAtomicLoad:
1668 case SPIRV::OpAtomicStore:
1669 case SPIRV::OpAtomicExchange:
1670 case SPIRV::OpAtomicCompareExchange:
1671 case SPIRV::OpAtomicIIncrement:
1672 case SPIRV::OpAtomicIDecrement:
1673 case SPIRV::OpAtomicIAdd:
1674 case SPIRV::OpAtomicISub:
1675 case SPIRV::OpAtomicUMin:
1676 case SPIRV::OpAtomicUMax:
1677 case SPIRV::OpAtomicSMin:
1678 case SPIRV::OpAtomicSMax:
1679 case SPIRV::OpAtomicAnd:
1680 case SPIRV::OpAtomicOr:
1681 case SPIRV::OpAtomicXor: {
1682 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1683 const MachineInstr *InstrPtr = &MI;
1684 if (Op == SPIRV::OpAtomicStore) {
1685 assert(MI.getOperand(3).isReg());
1686 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
1687 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1688 }
1689 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1690 Register TypeReg = InstrPtr->getOperand(1).getReg();
1691 SPIRVTypeInst TypeDef = MRI.getVRegDef(TypeReg);
1692
1693 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1694 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1695 if (BitWidth == 64)
1696 Reqs.addCapability(SPIRV::Capability::Int64Atomics);
1697 else if (BitWidth == 16) {
1698 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1700 "16-bit integer atomic operations require the following SPIR-V "
1701 "extension: SPV_INTEL_16bit_atomics",
1702 false);
1703 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1704 switch (Op) {
1705 case SPIRV::OpAtomicLoad:
1706 case SPIRV::OpAtomicStore:
1707 case SPIRV::OpAtomicExchange:
1708 case SPIRV::OpAtomicCompareExchange:
1709 case SPIRV::OpAtomicCompareExchangeWeak:
1710 Reqs.addCapability(
1711 SPIRV::Capability::AtomicInt16CompareExchangeINTEL);
1712 break;
1713 default:
1714 Reqs.addCapability(SPIRV::Capability::Int16AtomicsINTEL);
1715 break;
1716 }
1717 }
1718 } else if (isBFloat16Type(TypeDef)) {
1719 if (is_contained({SPIRV::OpAtomicLoad, SPIRV::OpAtomicStore,
1720 SPIRV::OpAtomicExchange},
1721 Op)) {
1722 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1724 "The atomic bfloat16 instruction requires the following SPIR-V "
1725 "extension: SPV_INTEL_16bit_atomics",
1726 false);
1727 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1728 Reqs.addCapability(SPIRV::Capability::AtomicBFloat16LoadStoreINTEL);
1729 }
1730 }
1731 break;
1732 }
1733 case SPIRV::OpGroupNonUniformIAdd:
1734 case SPIRV::OpGroupNonUniformFAdd:
1735 case SPIRV::OpGroupNonUniformIMul:
1736 case SPIRV::OpGroupNonUniformFMul:
1737 case SPIRV::OpGroupNonUniformSMin:
1738 case SPIRV::OpGroupNonUniformUMin:
1739 case SPIRV::OpGroupNonUniformFMin:
1740 case SPIRV::OpGroupNonUniformSMax:
1741 case SPIRV::OpGroupNonUniformUMax:
1742 case SPIRV::OpGroupNonUniformFMax:
1743 case SPIRV::OpGroupNonUniformBitwiseAnd:
1744 case SPIRV::OpGroupNonUniformBitwiseOr:
1745 case SPIRV::OpGroupNonUniformBitwiseXor:
1746 case SPIRV::OpGroupNonUniformLogicalAnd:
1747 case SPIRV::OpGroupNonUniformLogicalOr:
1748 case SPIRV::OpGroupNonUniformLogicalXor: {
1749 assert(MI.getOperand(3).isImm());
1750 int64_t GroupOp = MI.getOperand(3).getImm();
1751 switch (GroupOp) {
1752 case SPIRV::GroupOperation::Reduce:
1753 case SPIRV::GroupOperation::InclusiveScan:
1754 case SPIRV::GroupOperation::ExclusiveScan:
1755 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1756 break;
1757 case SPIRV::GroupOperation::ClusteredReduce:
1758 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1759 break;
1760 case SPIRV::GroupOperation::PartitionedReduceNV:
1761 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1762 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1763 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1764 break;
1765 }
1766 break;
1767 }
1768 case SPIRV::OpGroupNonUniformQuadSwap:
1769 Reqs.addCapability(SPIRV::Capability::GroupNonUniformQuad);
1770 break;
1771 case SPIRV::OpImageQueryLod:
1772 Reqs.addCapability(SPIRV::Capability::ImageQuery);
1773 break;
1774 case SPIRV::OpImageQuerySize:
1775 case SPIRV::OpImageQuerySizeLod:
1776 case SPIRV::OpImageQueryLevels:
1777 case SPIRV::OpImageQuerySamples:
1778 if (ST.isShader())
1779 Reqs.addCapability(SPIRV::Capability::ImageQuery);
1780 break;
1781 case SPIRV::OpImageQueryFormat: {
1782 Register ResultReg = MI.getOperand(0).getReg();
1783 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1784 static const unsigned CompareOps[] = {
1785 SPIRV::OpIEqual, SPIRV::OpINotEqual,
1786 SPIRV::OpUGreaterThan, SPIRV::OpUGreaterThanEqual,
1787 SPIRV::OpULessThan, SPIRV::OpULessThanEqual,
1788 SPIRV::OpSGreaterThan, SPIRV::OpSGreaterThanEqual,
1789 SPIRV::OpSLessThan, SPIRV::OpSLessThanEqual};
1790
1791 auto CheckAndAddExtension = [&](int64_t ImmVal) {
1792 if (ImmVal == 4323 || ImmVal == 4324) {
1793 if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_image_raw10_raw12))
1794 Reqs.addExtension(SPIRV::Extension::SPV_EXT_image_raw10_raw12);
1795 else
1796 report_fatal_error("This requires the "
1797 "SPV_EXT_image_raw10_raw12 extension");
1798 }
1799 };
1800
1801 for (MachineInstr &UseInst : MRI.use_instructions(ResultReg)) {
1802 unsigned Opc = UseInst.getOpcode();
1803
1804 if (Opc == SPIRV::OpSwitch) {
1805 for (const MachineOperand &Op : UseInst.operands())
1806 if (Op.isImm())
1807 CheckAndAddExtension(Op.getImm());
1808 } else if (llvm::is_contained(CompareOps, Opc)) {
1809 for (unsigned i = 1; i < UseInst.getNumOperands(); ++i) {
1810 Register UseReg = UseInst.getOperand(i).getReg();
1811 MachineInstr *ConstInst = MRI.getVRegDef(UseReg);
1812 if (ConstInst && ConstInst->getOpcode() == SPIRV::OpConstantI) {
1813 int64_t ImmVal = ConstInst->getOperand(2).getImm();
1814 if (ImmVal)
1815 CheckAndAddExtension(ImmVal);
1816 }
1817 }
1818 }
1819 }
1820 break;
1821 }
1822
1823 case SPIRV::OpGroupNonUniformShuffle:
1824 case SPIRV::OpGroupNonUniformShuffleXor:
1825 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1826 break;
1827 case SPIRV::OpGroupNonUniformShuffleUp:
1828 case SPIRV::OpGroupNonUniformShuffleDown:
1829 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1830 break;
1831 case SPIRV::OpGroupAll:
1832 case SPIRV::OpGroupAny:
1833 case SPIRV::OpGroupBroadcast:
1834 case SPIRV::OpGroupIAdd:
1835 case SPIRV::OpGroupFAdd:
1836 case SPIRV::OpGroupFMin:
1837 case SPIRV::OpGroupUMin:
1838 case SPIRV::OpGroupSMin:
1839 case SPIRV::OpGroupFMax:
1840 case SPIRV::OpGroupUMax:
1841 case SPIRV::OpGroupSMax:
1842 Reqs.addCapability(SPIRV::Capability::Groups);
1843 break;
1844 case SPIRV::OpGroupNonUniformElect:
1845 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1846 break;
1847 case SPIRV::OpGroupNonUniformAll:
1848 case SPIRV::OpGroupNonUniformAny:
1849 case SPIRV::OpGroupNonUniformAllEqual:
1850 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1851 break;
1852 case SPIRV::OpGroupNonUniformBroadcast:
1853 case SPIRV::OpGroupNonUniformBroadcastFirst:
1854 case SPIRV::OpGroupNonUniformBallot:
1855 case SPIRV::OpGroupNonUniformInverseBallot:
1856 case SPIRV::OpGroupNonUniformBallotBitExtract:
1857 case SPIRV::OpGroupNonUniformBallotBitCount:
1858 case SPIRV::OpGroupNonUniformBallotFindLSB:
1859 case SPIRV::OpGroupNonUniformBallotFindMSB:
1860 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1861 break;
1862 case SPIRV::OpSubgroupShuffleINTEL:
1863 case SPIRV::OpSubgroupShuffleDownINTEL:
1864 case SPIRV::OpSubgroupShuffleUpINTEL:
1865 case SPIRV::OpSubgroupShuffleXorINTEL:
1866 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1867 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1868 Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1869 }
1870 break;
1871 case SPIRV::OpSubgroupBlockReadINTEL:
1872 case SPIRV::OpSubgroupBlockWriteINTEL:
1873 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1874 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1875 Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1876 }
1877 break;
1878 case SPIRV::OpSubgroupImageBlockReadINTEL:
1879 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1880 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1881 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1882 Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1883 }
1884 break;
1885 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1886 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1887 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
1888 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
1889 Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1890 }
1891 break;
1892 case SPIRV::OpAssumeTrueKHR:
1893 case SPIRV::OpExpectKHR:
1894 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1895 Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1896 Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1897 }
1898 break;
1899 case SPIRV::OpFmaKHR:
1900 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) {
1901 Reqs.addExtension(SPIRV::Extension::SPV_KHR_fma);
1902 Reqs.addCapability(SPIRV::Capability::FmaKHR);
1903 }
1904 break;
1905 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1906 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1907 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1908 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1909 Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1910 }
1911 break;
1912 case SPIRV::OpConstantFunctionPointerINTEL:
1913 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1914 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1915 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1916 }
1917 break;
1918 case SPIRV::OpGroupNonUniformRotateKHR:
1919 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1920 report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1921 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1922 false);
1923 Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1924 Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1925 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1926 break;
1927 case SPIRV::OpFixedCosALTERA:
1928 case SPIRV::OpFixedSinALTERA:
1929 case SPIRV::OpFixedCosPiALTERA:
1930 case SPIRV::OpFixedSinPiALTERA:
1931 case SPIRV::OpFixedExpALTERA:
1932 case SPIRV::OpFixedLogALTERA:
1933 case SPIRV::OpFixedRecipALTERA:
1934 case SPIRV::OpFixedSqrtALTERA:
1935 case SPIRV::OpFixedSinCosALTERA:
1936 case SPIRV::OpFixedSinCosPiALTERA:
1937 case SPIRV::OpFixedRsqrtALTERA:
1938 if (!ST.canUseExtension(
1939 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point))
1940 report_fatal_error("This instruction requires the "
1941 "following SPIR-V extension: "
1942 "SPV_ALTERA_arbitrary_precision_fixed_point",
1943 false);
1944 Reqs.addExtension(
1945 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point);
1946 Reqs.addCapability(SPIRV::Capability::ArbitraryPrecisionFixedPointALTERA);
1947 break;
1948 case SPIRV::OpGroupIMulKHR:
1949 case SPIRV::OpGroupFMulKHR:
1950 case SPIRV::OpGroupBitwiseAndKHR:
1951 case SPIRV::OpGroupBitwiseOrKHR:
1952 case SPIRV::OpGroupBitwiseXorKHR:
1953 case SPIRV::OpGroupLogicalAndKHR:
1954 case SPIRV::OpGroupLogicalOrKHR:
1955 case SPIRV::OpGroupLogicalXorKHR:
1956 if (ST.canUseExtension(
1957 SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1958 Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1959 Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1960 }
1961 break;
1962 case SPIRV::OpReadClockKHR:
1963 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1964 report_fatal_error("OpReadClockKHR instruction requires the "
1965 "following SPIR-V extension: SPV_KHR_shader_clock",
1966 false);
1967 Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1968 Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1969 break;
1970 case SPIRV::OpFunctionPointerCallINTEL:
1971 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1972 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1973 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1974 }
1975 break;
1976 case SPIRV::OpAtomicFAddEXT:
1977 case SPIRV::OpAtomicFMinEXT:
1978 case SPIRV::OpAtomicFMaxEXT:
1979 AddAtomicFloatRequirements(MI, Reqs, ST);
1980 break;
1981 case SPIRV::OpConvertBF16ToFINTEL:
1982 case SPIRV::OpConvertFToBF16INTEL:
1983 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1984 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1985 Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1986 }
1987 break;
1988 case SPIRV::OpRoundFToTF32INTEL:
1989 if (ST.canUseExtension(
1990 SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
1991 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1992 Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
1993 }
1994 break;
1995 case SPIRV::OpVariableLengthArrayINTEL:
1996 case SPIRV::OpSaveMemoryINTEL:
1997 case SPIRV::OpRestoreMemoryINTEL:
1998 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1999 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
2000 Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
2001 }
2002 break;
2003 case SPIRV::OpAsmTargetINTEL:
2004 case SPIRV::OpAsmINTEL:
2005 case SPIRV::OpAsmCallINTEL:
2006 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
2007 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
2008 Reqs.addCapability(SPIRV::Capability::AsmINTEL);
2009 }
2010 break;
2011 case SPIRV::OpTypeCooperativeMatrixKHR: {
2012 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
2014 "OpTypeCooperativeMatrixKHR type requires the "
2015 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
2016 false);
2017 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
2018 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
2019 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2020 SPIRVTypeInst TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
2021 if (isBFloat16Type(TypeDef))
2022 Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
2023 break;
2024 }
2025 case SPIRV::OpArithmeticFenceEXT:
2026 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
2027 report_fatal_error("OpArithmeticFenceEXT requires the "
2028 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
2029 false);
2030 Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
2031 Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
2032 break;
2033 case SPIRV::OpControlBarrierArriveINTEL:
2034 case SPIRV::OpControlBarrierWaitINTEL:
2035 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
2036 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
2037 Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
2038 }
2039 break;
2040 case SPIRV::OpCooperativeMatrixMulAddKHR: {
2041 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
2042 report_fatal_error("Cooperative matrix instructions require the "
2043 "following SPIR-V extension: "
2044 "SPV_KHR_cooperative_matrix",
2045 false);
2046 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
2047 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
2048 constexpr unsigned MulAddMaxSize = 6;
2049 if (MI.getNumOperands() != MulAddMaxSize)
2050 break;
2051 const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
2052 if (CoopOperands &
2053 SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
2054 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2055 report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
2056 "require the following SPIR-V extension: "
2057 "SPV_INTEL_joint_matrix",
2058 false);
2059 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2060 Reqs.addCapability(
2061 SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
2062 }
2063 if (CoopOperands & SPIRV::CooperativeMatrixOperands::
2064 MatrixAAndBBFloat16ComponentsINTEL ||
2065 CoopOperands &
2066 SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
2067 CoopOperands & SPIRV::CooperativeMatrixOperands::
2068 MatrixResultBFloat16ComponentsINTEL) {
2069 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2070 report_fatal_error("***BF16ComponentsINTEL type interpretations "
2071 "require the following SPIR-V extension: "
2072 "SPV_INTEL_joint_matrix",
2073 false);
2074 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2075 Reqs.addCapability(
2076 SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
2077 }
2078 break;
2079 }
2080 case SPIRV::OpCooperativeMatrixLoadKHR:
2081 case SPIRV::OpCooperativeMatrixStoreKHR:
2082 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2083 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2084 case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
2085 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
2086 report_fatal_error("Cooperative matrix instructions require the "
2087 "following SPIR-V extension: "
2088 "SPV_KHR_cooperative_matrix",
2089 false);
2090 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
2091 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
2092
2093 // Check Layout operand in case if it's not a standard one and add the
2094 // appropriate capability.
2095 std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
2096 {SPIRV::OpCooperativeMatrixLoadKHR, 3},
2097 {SPIRV::OpCooperativeMatrixStoreKHR, 2},
2098 {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
2099 {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
2100 {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
2101
2102 const unsigned LayoutNum = LayoutToInstMap[Op];
2103 Register RegLayout = MI.getOperand(LayoutNum).getReg();
2104 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2105 MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
2106 if (MILayout->getOpcode() == SPIRV::OpConstantI) {
2107 const unsigned LayoutVal = MILayout->getOperand(2).getImm();
2108 if (LayoutVal ==
2109 static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
2110 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2111 report_fatal_error("PackedINTEL layout require the following SPIR-V "
2112 "extension: SPV_INTEL_joint_matrix",
2113 false);
2114 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2115 Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
2116 }
2117 }
2118
2119 // Nothing to do.
2120 if (Op == SPIRV::OpCooperativeMatrixLoadKHR ||
2121 Op == SPIRV::OpCooperativeMatrixStoreKHR)
2122 break;
2123
2124 std::string InstName;
2125 switch (Op) {
2126 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
2127 InstName = "OpCooperativeMatrixPrefetchINTEL";
2128 break;
2129 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2130 InstName = "OpCooperativeMatrixLoadCheckedINTEL";
2131 break;
2132 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2133 InstName = "OpCooperativeMatrixStoreCheckedINTEL";
2134 break;
2135 }
2136
2137 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
2138 const std::string ErrorMsg =
2139 InstName + " instruction requires the "
2140 "following SPIR-V extension: SPV_INTEL_joint_matrix";
2141 report_fatal_error(ErrorMsg.c_str(), false);
2142 }
2143 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2144 if (Op == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
2145 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
2146 break;
2147 }
2148 Reqs.addCapability(
2149 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2150 break;
2151 }
2152 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
2153 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2154 report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
2155 "instructions require the following SPIR-V extension: "
2156 "SPV_INTEL_joint_matrix",
2157 false);
2158 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2159 Reqs.addCapability(
2160 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2161 break;
2162 case SPIRV::OpReadPipeBlockingALTERA:
2163 case SPIRV::OpWritePipeBlockingALTERA:
2164 if (ST.canUseExtension(SPIRV::Extension::SPV_ALTERA_blocking_pipes)) {
2165 Reqs.addExtension(SPIRV::Extension::SPV_ALTERA_blocking_pipes);
2166 Reqs.addCapability(SPIRV::Capability::BlockingPipesALTERA);
2167 }
2168 break;
2169 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
2170 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2171 report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
2172 "following SPIR-V extension: SPV_INTEL_joint_matrix",
2173 false);
2174 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2175 Reqs.addCapability(
2176 SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
2177 break;
2178 case SPIRV::OpConvertHandleToImageINTEL:
2179 case SPIRV::OpConvertHandleToSamplerINTEL:
2180 case SPIRV::OpConvertHandleToSampledImageINTEL: {
2181 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bindless_images))
2182 report_fatal_error("OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
2183 "instructions require the following SPIR-V extension: "
2184 "SPV_INTEL_bindless_images",
2185 false);
2186 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
2187 SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
2188 SPIRVTypeInst TyDef = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
2189 if (Op == SPIRV::OpConvertHandleToImageINTEL &&
2190 TyDef->getOpcode() != SPIRV::OpTypeImage) {
2191 report_fatal_error("Incorrect return type for the instruction "
2192 "OpConvertHandleToImageINTEL",
2193 false);
2194 } else if (Op == SPIRV::OpConvertHandleToSamplerINTEL &&
2195 TyDef->getOpcode() != SPIRV::OpTypeSampler) {
2196 report_fatal_error("Incorrect return type for the instruction "
2197 "OpConvertHandleToSamplerINTEL",
2198 false);
2199 } else if (Op == SPIRV::OpConvertHandleToSampledImageINTEL &&
2200 TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
2201 report_fatal_error("Incorrect return type for the instruction "
2202 "OpConvertHandleToSampledImageINTEL",
2203 false);
2204 }
2205 SPIRVTypeInst SpvTy = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg());
2206 unsigned Bitwidth = GR->getScalarOrVectorBitWidth(SpvTy);
2207 if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
2208 !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
2210 "Parameter value must be a 32-bit scalar in case of "
2211 "Physical32 addressing model or a 64-bit scalar in case of "
2212 "Physical64 addressing model",
2213 false);
2214 }
2215 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bindless_images);
2216 Reqs.addCapability(SPIRV::Capability::BindlessImagesINTEL);
2217 break;
2218 }
2219 case SPIRV::OpSubgroup2DBlockLoadINTEL:
2220 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
2221 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
2222 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
2223 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
2224 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_2d_block_io))
2225 report_fatal_error("OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
2226 "Prefetch/Store]INTEL instructions require the "
2227 "following SPIR-V extension: SPV_INTEL_2d_block_io",
2228 false);
2229 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_2d_block_io);
2230 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockIOINTEL);
2231
2232 if (Op == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
2233 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
2234 break;
2235 }
2236 if (Op == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
2237 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransformINTEL);
2238 break;
2239 }
2240 break;
2241 }
2242 case SPIRV::OpKill: {
2243 Reqs.addCapability(SPIRV::Capability::Shader);
2244 } break;
2245 case SPIRV::OpDemoteToHelperInvocation:
2246 Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);
2247
2248 if (ST.canUseExtension(
2249 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
2250 if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
2251 Reqs.addExtension(
2252 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
2253 }
2254 break;
2255 case SPIRV::OpSDot:
2256 case SPIRV::OpUDot:
2257 case SPIRV::OpSUDot:
2258 case SPIRV::OpSDotAccSat:
2259 case SPIRV::OpUDotAccSat:
2260 case SPIRV::OpSUDotAccSat:
2261 AddDotProductRequirements(MI, Reqs, ST);
2262 break;
2263 case SPIRV::OpImageSampleImplicitLod:
2264 Reqs.addCapability(SPIRV::Capability::Shader);
2265 addImageOperandReqs(MI, Reqs, ST, 4);
2266 break;
2267 case SPIRV::OpImageSampleExplicitLod:
2268 addImageOperandReqs(MI, Reqs, ST, 4);
2269 break;
2270 case SPIRV::OpImageSampleDrefImplicitLod:
2271 Reqs.addCapability(SPIRV::Capability::Shader);
2272 addImageOperandReqs(MI, Reqs, ST, 5);
2273 break;
2274 case SPIRV::OpImageSampleDrefExplicitLod:
2275 Reqs.addCapability(SPIRV::Capability::Shader);
2276 addImageOperandReqs(MI, Reqs, ST, 5);
2277 break;
2278 case SPIRV::OpImageFetch:
2279 Reqs.addCapability(SPIRV::Capability::Shader);
2280 addImageOperandReqs(MI, Reqs, ST, 4);
2281 break;
2282 case SPIRV::OpImageDrefGather:
2283 case SPIRV::OpImageGather:
2284 Reqs.addCapability(SPIRV::Capability::Shader);
2285 addImageOperandReqs(MI, Reqs, ST, 5);
2286 break;
2287 case SPIRV::OpImageRead: {
2288 Register ImageReg = MI.getOperand(2).getReg();
2289 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2290 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
2291 // OpImageRead and OpImageWrite can use Unknown Image Formats
2292 // when the Kernel capability is declared. In the OpenCL environment we are
2293 // not allowed to produce
2294 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2295 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2296
2297 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
2298 Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
2299 break;
2300 }
2301 case SPIRV::OpImageWrite: {
2302 Register ImageReg = MI.getOperand(0).getReg();
2303 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2304 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
2305 // OpImageRead and OpImageWrite can use Unknown Image Formats
2306 // when the Kernel capability is declared. In the OpenCL environment we are
2307 // not allowed to produce
2308 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2309 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2310
2311 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
2312 Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
2313 break;
2314 }
2315 case SPIRV::OpTypeStructContinuedINTEL:
2316 case SPIRV::OpConstantCompositeContinuedINTEL:
2317 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
2318 case SPIRV::OpCompositeConstructContinuedINTEL: {
2319 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_long_composites))
2321 "Continued instructions require the "
2322 "following SPIR-V extension: SPV_INTEL_long_composites",
2323 false);
2324 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_long_composites);
2325 Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
2326 break;
2327 }
2328 case SPIRV::OpArbitraryFloatEQALTERA:
2329 case SPIRV::OpArbitraryFloatGEALTERA:
2330 case SPIRV::OpArbitraryFloatGTALTERA:
2331 case SPIRV::OpArbitraryFloatLEALTERA:
2332 case SPIRV::OpArbitraryFloatLTALTERA:
2333 case SPIRV::OpArbitraryFloatCbrtALTERA:
2334 case SPIRV::OpArbitraryFloatCosALTERA:
2335 case SPIRV::OpArbitraryFloatCosPiALTERA:
2336 case SPIRV::OpArbitraryFloatExp10ALTERA:
2337 case SPIRV::OpArbitraryFloatExp2ALTERA:
2338 case SPIRV::OpArbitraryFloatExpALTERA:
2339 case SPIRV::OpArbitraryFloatExpm1ALTERA:
2340 case SPIRV::OpArbitraryFloatHypotALTERA:
2341 case SPIRV::OpArbitraryFloatLog10ALTERA:
2342 case SPIRV::OpArbitraryFloatLog1pALTERA:
2343 case SPIRV::OpArbitraryFloatLog2ALTERA:
2344 case SPIRV::OpArbitraryFloatLogALTERA:
2345 case SPIRV::OpArbitraryFloatRecipALTERA:
2346 case SPIRV::OpArbitraryFloatSinCosALTERA:
2347 case SPIRV::OpArbitraryFloatSinCosPiALTERA:
2348 case SPIRV::OpArbitraryFloatSinALTERA:
2349 case SPIRV::OpArbitraryFloatSinPiALTERA:
2350 case SPIRV::OpArbitraryFloatSqrtALTERA:
2351 case SPIRV::OpArbitraryFloatACosALTERA:
2352 case SPIRV::OpArbitraryFloatACosPiALTERA:
2353 case SPIRV::OpArbitraryFloatAddALTERA:
2354 case SPIRV::OpArbitraryFloatASinALTERA:
2355 case SPIRV::OpArbitraryFloatASinPiALTERA:
2356 case SPIRV::OpArbitraryFloatATan2ALTERA:
2357 case SPIRV::OpArbitraryFloatATanALTERA:
2358 case SPIRV::OpArbitraryFloatATanPiALTERA:
2359 case SPIRV::OpArbitraryFloatCastFromIntALTERA:
2360 case SPIRV::OpArbitraryFloatCastALTERA:
2361 case SPIRV::OpArbitraryFloatCastToIntALTERA:
2362 case SPIRV::OpArbitraryFloatDivALTERA:
2363 case SPIRV::OpArbitraryFloatMulALTERA:
2364 case SPIRV::OpArbitraryFloatPowALTERA:
2365 case SPIRV::OpArbitraryFloatPowNALTERA:
2366 case SPIRV::OpArbitraryFloatPowRALTERA:
2367 case SPIRV::OpArbitraryFloatRSqrtALTERA:
2368 case SPIRV::OpArbitraryFloatSubALTERA: {
2369 if (!ST.canUseExtension(
2370 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point))
2372 "Floating point instructions can't be translated correctly without "
2373 "enabled SPV_ALTERA_arbitrary_precision_floating_point extension!",
2374 false);
2375 Reqs.addExtension(
2376 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point);
2377 Reqs.addCapability(
2378 SPIRV::Capability::ArbitraryPrecisionFloatingPointALTERA);
2379 break;
2380 }
2381 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
2382 if (!ST.canUseExtension(
2383 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
2385 "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
2386 "following SPIR-V "
2387 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
2388 false);
2389 Reqs.addExtension(
2390 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
2391 Reqs.addCapability(
2392 SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
2393 break;
2394 }
2395 case SPIRV::OpBitwiseFunctionINTEL: {
2396 if (!ST.canUseExtension(
2397 SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
2399 "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
2400 "extension: SPV_INTEL_ternary_bitwise_function",
2401 false);
2402 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
2403 Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
2404 break;
2405 }
2406 case SPIRV::OpCopyMemorySized: {
2407 Reqs.addCapability(SPIRV::Capability::Addresses);
2408 // TODO: Add UntypedPointersKHR when implemented.
2409 break;
2410 }
2411 case SPIRV::OpPredicatedLoadINTEL:
2412 case SPIRV::OpPredicatedStoreINTEL: {
2413 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_predicated_io))
2415 "OpPredicated[Load/Store]INTEL instructions require "
2416 "the following SPIR-V extension: SPV_INTEL_predicated_io",
2417 false);
2418 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_predicated_io);
2419 Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
2420 break;
2421 }
2422 case SPIRV::OpFAddS:
2423 case SPIRV::OpFSubS:
2424 case SPIRV::OpFMulS:
2425 case SPIRV::OpFDivS:
2426 case SPIRV::OpFRemS:
2427 case SPIRV::OpFMod:
2428 case SPIRV::OpFNegate:
2429 case SPIRV::OpFAddV:
2430 case SPIRV::OpFSubV:
2431 case SPIRV::OpFMulV:
2432 case SPIRV::OpFDivV:
2433 case SPIRV::OpFRemV:
2434 case SPIRV::OpFNegateV: {
2435 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2436 SPIRVTypeInst TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
2437 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2438 TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2439 if (isBFloat16Type(TypeDef)) {
2440 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2442 "Arithmetic instructions with bfloat16 arguments require the "
2443 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2444 false);
2445 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2446 Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2447 }
2448 break;
2449 }
2450 case SPIRV::OpOrdered:
2451 case SPIRV::OpUnordered:
2452 case SPIRV::OpFOrdEqual:
2453 case SPIRV::OpFOrdNotEqual:
2454 case SPIRV::OpFOrdLessThan:
2455 case SPIRV::OpFOrdLessThanEqual:
2456 case SPIRV::OpFOrdGreaterThan:
2457 case SPIRV::OpFOrdGreaterThanEqual:
2458 case SPIRV::OpFUnordEqual:
2459 case SPIRV::OpFUnordNotEqual:
2460 case SPIRV::OpFUnordLessThan:
2461 case SPIRV::OpFUnordLessThanEqual:
2462 case SPIRV::OpFUnordGreaterThan:
2463 case SPIRV::OpFUnordGreaterThanEqual: {
2464 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2465 MachineInstr *OperandDef = MRI.getVRegDef(MI.getOperand(2).getReg());
2466 SPIRVTypeInst TypeDef = MRI.getVRegDef(OperandDef->getOperand(1).getReg());
2467 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2468 TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2469 if (isBFloat16Type(TypeDef)) {
2470 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2472 "Relational instructions with bfloat16 arguments require the "
2473 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2474 false);
2475 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2476 Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2477 }
2478 break;
2479 }
2480 case SPIRV::OpDPdxCoarse:
2481 case SPIRV::OpDPdyCoarse:
2482 case SPIRV::OpDPdxFine:
2483 case SPIRV::OpDPdyFine: {
2484 Reqs.addCapability(SPIRV::Capability::DerivativeControl);
2485 break;
2486 }
2487 case SPIRV::OpLoopControlINTEL: {
2488 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_unstructured_loop_controls);
2489 Reqs.addCapability(SPIRV::Capability::UnstructuredLoopControlsINTEL);
2490 break;
2491 }
2492
2493 default:
2494 break;
2495 }
2496
2497 // If we require capability Shader, then we can remove the requirement for
2498 // the BitInstructions capability, since Shader is a superset capability
2499 // of BitInstructions.
2500 Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
2501 SPIRV::Capability::Shader);
2502}
2503
// Collect all SPIR-V extension/capability requirements for the module by
// scanning (1) every machine instruction, (2) the "spirv.ExecutionMode"
// named metadata, and (3) per-function metadata/attributes such as
// reqd_work_group_size, hlsl.numthreads, and optnone.
// NOTE(review): this rendering of the source drops several lines (e.g. the
// definition of MF at source line 2508 and the heads of multiple
// MAI.Reqs.getAndAddRequirements(...) calls); confirm against the original
// SPIRVModuleAnalysis.cpp before editing.
2504 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
2505 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
2506 // Collect requirements for existing instructions.
2507 for (const Function &F : M) {
// MF: the MachineFunction for F (defining line not visible here;
// presumably MMI->getMachineFunction(F) -- confirm). Skips IR functions
// without machine code.
2509 if (!MF)
2510 continue;
2511 for (const MachineBasicBlock &MBB : *MF)
2512 for (const MachineInstr &MI : MBB)
2513 addInstrRequirements(MI, MAI, ST);
2514 }
2515 // Collect requirements for OpExecutionMode instructions.
2516 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2517 if (Node) {
// Track which float-controls extensions are implied by the execution
// modes found below; they are added once after the scan.
2518 bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
2519 RequireKHRFloatControls2 = false,
2520 VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
2521 bool HasIntelFloatControls2 =
2522 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2523 bool HasKHRFloatControls2 =
2524 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2525 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
// Each metadata node is (function, execution-mode-id, ...); operand 1 is
// the execution mode constant.
2526 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2527 const MDOperand &MDOp = MDN->getOperand(1);
2528 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
2529 Constant *C = CMeta->getValue();
2530 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
2531 auto EM = Const->getZExtValue();
2532 // SPV_KHR_float_controls is not available until v1.4:
2533 // add SPV_KHR_float_controls if the version is too low
2534 switch (EM) {
2535 case SPIRV::ExecutionMode::DenormPreserve:
2536 case SPIRV::ExecutionMode::DenormFlushToZero:
2537 case SPIRV::ExecutionMode::RoundingModeRTE:
2538 case SPIRV::ExecutionMode::RoundingModeRTZ:
2539 RequireFloatControls = VerLower14;
// The head of this MAI.Reqs.getAndAddRequirements(...) call (and of the
// similar calls below) is not visible in this rendering.
2541 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2542 break;
2543 case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
2544 case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
2545 case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
2546 case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
// Intel-specific modes are only honored when the Intel extension is
// available; otherwise they are silently ignored.
2547 if (HasIntelFloatControls2) {
2548 RequireIntelFloatControls2 = true;
2550 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2551 }
2552 break;
2553 case SPIRV::ExecutionMode::FPFastMathDefault: {
2554 if (HasKHRFloatControls2) {
2555 RequireKHRFloatControls2 = true;
2557 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2558 }
2559 break;
2560 }
2561 case SPIRV::ExecutionMode::ContractionOff:
2562 case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
// With SPV_KHR_float_controls2 these deprecated modes are expressed via
// FPFastMathDefault instead; otherwise they are emitted as-is.
2563 if (HasKHRFloatControls2) {
2564 RequireKHRFloatControls2 = true;
2566 SPIRV::OperandCategory::ExecutionModeOperand,
2567 SPIRV::ExecutionMode::FPFastMathDefault, ST);
2568 } else {
2570 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2571 }
2572 break;
2573 default:
2575 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2576 }
2577 }
2578 }
2579 }
// Add the float-controls extensions implied by the modes seen above.
2580 if (RequireFloatControls &&
2581 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
2582 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
2583 if (RequireIntelFloatControls2)
2584 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2585 if (RequireKHRFloatControls2)
2586 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2587 }
// Per-function metadata/attributes that map to execution modes.
2588 for (const Function &F : M) {
2589 if (F.isDeclaration())
2590 continue;
2591 if (F.getMetadata("reqd_work_group_size"))
2593 SPIRV::OperandCategory::ExecutionModeOperand,
2594 SPIRV::ExecutionMode::LocalSize, ST);
2595 if (F.getFnAttribute("hlsl.numthreads").isValid()) {
2597 SPIRV::OperandCategory::ExecutionModeOperand,
2598 SPIRV::ExecutionMode::LocalSize, ST);
2599 }
2600 if (F.getFnAttribute("enable-maximal-reconvergence").getValueAsBool()) {
2601 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence);
2602 }
2603 if (F.getMetadata("work_group_size_hint"))
2605 SPIRV::OperandCategory::ExecutionModeOperand,
2606 SPIRV::ExecutionMode::LocalSizeHint, ST);
2607 if (F.getMetadata("intel_reqd_sub_group_size"))
2609 SPIRV::OperandCategory::ExecutionModeOperand,
2610 SPIRV::ExecutionMode::SubgroupSize, ST);
2611 if (F.getMetadata("max_work_group_size"))
2613 SPIRV::OperandCategory::ExecutionModeOperand,
2614 SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, ST);
2615 if (F.getMetadata("vec_type_hint"))
2617 SPIRV::OperandCategory::ExecutionModeOperand,
2618 SPIRV::ExecutionMode::VecTypeHint, ST);
2619
// optnone prefers the Intel extension when available, else the EXT one.
2620 if (F.hasOptNone()) {
2621 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
2622 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
2623 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
2624 } else if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) {
2625 MAI.Reqs.addExtension(SPIRV::Extension::SPV_EXT_optnone);
2626 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneEXT);
2627 }
2628 }
2629 }
2630 }
2631
2632static unsigned getFastMathFlags(const MachineInstr &I,
2633 const SPIRVSubtarget &ST) {
2634 unsigned Flags = SPIRV::FPFastMathMode::None;
2635 bool CanUseKHRFloatControls2 =
2636 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2637 if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
2638 Flags |= SPIRV::FPFastMathMode::NotNaN;
2639 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
2640 Flags |= SPIRV::FPFastMathMode::NotInf;
2641 if (I.getFlag(MachineInstr::MIFlag::FmNsz))
2642 Flags |= SPIRV::FPFastMathMode::NSZ;
2643 if (I.getFlag(MachineInstr::MIFlag::FmArcp))
2644 Flags |= SPIRV::FPFastMathMode::AllowRecip;
2645 if (I.getFlag(MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
2646 Flags |= SPIRV::FPFastMathMode::AllowContract;
2647 if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) {
2648 if (CanUseKHRFloatControls2)
2649 // LLVM reassoc maps to SPIRV transform, see
2650 // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
2651 // Because we are enabling AllowTransform, we must enable AllowReassoc and
2652 // AllowContract too, as required by SPIRV spec. Also, we used to map
2653 // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
2654 // replaced by turning all the other bits instead. Therefore, we're
2655 // enabling every bit here except None and Fast.
2656 Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
2657 SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
2658 SPIRV::FPFastMathMode::AllowTransform |
2659 SPIRV::FPFastMathMode::AllowReassoc |
2660 SPIRV::FPFastMathMode::AllowContract;
2661 else
2662 Flags |= SPIRV::FPFastMathMode::Fast;
2663 }
2664
2665 if (CanUseKHRFloatControls2) {
2666 // Error out if SPIRV::FPFastMathMode::Fast is enabled.
2667 assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
2668 "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
2669 "anymore.");
2670
2671 // Error out if AllowTransform is enabled without AllowReassoc and
2672 // AllowContract.
2673 assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
2674 ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
2675 Flags & SPIRV::FPFastMathMode::AllowContract))) &&
2676 "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
2677 "AllowContract flags to be enabled as well.");
2678 }
2679
2680 return Flags;
2681}
2682
2683static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
2684 if (ST.isKernel())
2685 return true;
2686 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
2687 return false;
2688 return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2689}
2690
// Emit decorations for a single instruction based on its MI flags:
// NoSignedWrap/NoUnsignedWrap for nsw/nuw, and FPFastMathMode for
// fast-math flags (possibly supplied by FPFastMathDefault information when
// the instruction itself carries no flags).
// NOTE(review): the parameter line between source lines 2692 and 2694
// (declaring at least `Reqs` and `GR`, both used below) is not visible in
// this rendering -- confirm against the original source.
2691 static void handleMIFlagDecoration(
2692 MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
2694 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
// nsw -> NoSignedWrap decoration, only if the requirement is satisfiable.
2695 if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
2696 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2697 SPIRV::Decoration::NoSignedWrap, ST, Reqs)
2698 .IsSatisfiable) {
2699 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2700 SPIRV::Decoration::NoSignedWrap, {});
2701 }
// nuw -> NoUnsignedWrap decoration, under the same satisfiability check.
2702 if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
2703 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2704 SPIRV::Decoration::NoUnsignedWrap, ST,
2705 Reqs)
2706 .IsSatisfiable) {
2707 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2708 SPIRV::Decoration::NoUnsignedWrap, {});
2709 }
2710 // In Kernel environments, FPFastMathMode on OpExtInst is valid per core
2711 // spec. For other instruction types, SPV_KHR_float_controls2 is required.
2712 bool CanUseFM =
2713 TII.canUseFastMathFlags(
2714 I, ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) ||
2715 (ST.isKernel() && I.getOpcode() == SPIRV::OpExtInst);
2716 if (!CanUseFM)
2717 return;
2718
2719 unsigned FMFlags = getFastMathFlags(I, ST);
2720 if (FMFlags == SPIRV::FPFastMathMode::None) {
2721 // We also need to check if any FPFastMathDefault info was set for the
2722 // types used in this instruction.
2723 if (FPFastMathDefaultInfoVec.empty())
2724 return;
2725
2726 // There are three types of instructions that can use fast math flags:
2727 // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
2728 // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
2729 // 3. Extended instructions (ExtInst)
2730 // For arithmetic instructions, the floating point type can be in the
2731 // result type or in the operands, but they all must be the same.
2732 // For the relational and logical instructions, the floating point type
2733 // can only be in the operands 1 and 2, not the result type. Also, the
2734 // operands must have the same type. For the extended instructions, the
2735 // floating point type can be in the result type or in the operands. It's
2736 // unclear if the operands and the result type must be the same. Let's
2737 // assume they must be. Therefore, for 1. and 2., we can check the first
2738 // operand type, and for 3. we can check the result type.
2739 assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
2740 Register ResReg = I.getOpcode() == SPIRV::OpExtInst
2741 ? I.getOperand(1).getReg()
2742 : I.getOperand(2).getReg();
// Reduce vector types to their scalar element before matching against the
// per-scalar-type default-info table.
2743 SPIRVTypeInst ResType = GR->getSPIRVTypeForVReg(ResReg, I.getMF());
2744 const Type *Ty = GR->getTypeForSPIRVType(ResType);
2745 Ty = Ty->isVectorTy() ? cast<VectorType>(Ty)->getElementType() : Ty;
2746
2747 // Match instruction type with the FPFastMathDefaultInfoVec.
2748 bool Emit = false;
2749 for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
2750 if (Ty == Elem.Ty) {
2751 FMFlags = Elem.FastMathFlags;
// Emit even a None mask if a deprecated mode (ContractionOff /
// SignedZeroInfNanPreserve) or an explicit FPFastMathDefault applies.
2752 Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
2753 Elem.FPFastMathDefault;
2754 break;
2755 }
2756 }
2757
2758 if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
2759 return;
2760 }
2761 if (isFastMathModeAvailable(ST)) {
2762 Register DstReg = I.getOperand(0).getReg();
2763 buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode,
2764 {FMFlags});
2765 }
2766 }
2767
2768 // Walk all functions and add decorations related to MI flags.
// NOTE(review): this rendering drops the parameter line between 2770 and
// 2772 (presumably declaring MAI), the definition of MF after 2773, and
// the final argument line of the handleMIFlagDecoration call (2781);
// confirm against the original source.
2769 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2770 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2772 const SPIRVGlobalRegistry *GR) {
2773 for (const Function &F : M) {
// Skip IR functions that have no machine code.
2775 if (!MF)
2776 continue;
2777
2778 for (auto &MBB : *MF)
2779 for (auto &MI : MBB)
2780 handleMIFlagDecoration(MI, ST, TII, MAI.Reqs, GR,
2782 }
2783 }
2784
// Emit OpName instructions for named, non-empty machine basic blocks and
// register a module-level alias for each block's label register.
// NOTE(review): this rendering drops the final parameter line (2787,
// presumably MAI), the definition of MF (2789), the middle of the
// getFnAttribute condition (2793), and the creation of Reg (2801);
// confirm against the original source.
2785 static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2786 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2788 for (const Function &F : M) {
2790 if (!MF)
2791 continue;
// Skip functions matching some attribute-based condition (the attribute
// name on line 2793 is not visible here).
2792 if (MF->getFunction()
2794 .isValid())
2795 continue;
2796 MachineRegisterInfo &MRI = MF->getRegInfo();
2797 for (auto &MBB : *MF) {
2798 if (!MBB.hasName() || MBB.empty())
2799 continue;
2800 // Emit basic block names.
2802 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
2803 buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
// Map the function-local register to a module-global one so AsmPrinter
// can resolve it during global renumbering.
2804 MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2805 MAI.setRegisterAlias(MF, Reg, GlobalReg);
2806 }
2807 }
2808 }
2809
2810 // patching Instruction::PHI to SPIRV::OpPhi
// Rewrites every PHI into SPIR-V form: swaps the opcode to OpPhi and
// inserts the result-type register as operand 1, then marks the function
// as PHI-free.
// NOTE(review): the definition of MF (source line 2814, presumably
// MMI->getMachineFunction(F)) is not visible in this rendering.
2811 static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2812 const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2813 for (const Function &F : M) {
2815 if (!MF)
2816 continue;
2817 for (auto &MBB : *MF) {
2818 for (MachineInstr &MI : MBB.phis()) {
2819 MI.setDesc(TII.get(SPIRV::OpPhi));
// OpPhi requires the result type as its second operand; fetch the SPIR-V
// type of the PHI's destination register and splice its ID in.
2820 Register ResTypeReg = GR->getSPIRVTypeID(
2821 GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF));
2822 MI.insert(MI.operands_begin() + 1,
2823 {MachineOperand::CreateReg(ResTypeReg, false)});
2824 }
2825 }
2826
2827 MF->getProperties().setNoPHIs();
2828 }
2829 }
2830
// Look up (or lazily create) the per-function FPFastMathDefault info
// vector. A fresh vector holds one entry each for half, float and double,
// all initialized to FPFastMathMode::None.
// NOTE(review): the signature head (source line 2831,
// getOrCreateFPFastMathDefaultInfoVec) is not visible in this rendering.
2832 const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
2833 auto it = MAI.FPFastMathDefaultInfoMap.find(F);
2834 if (it != MAI.FPFastMathDefaultInfoMap.end())
2835 return it->second;
2836
2837 // If the map does not contain the entry, create a new one. Initialize it to
2838 // contain all 3 elements sorted by bit width of target type: {half, float,
2839 // double}.
2840 SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
2841 FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()),
2842 SPIRV::FPFastMathMode::None);
2843 FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()),
2844 SPIRV::FPFastMathMode::None);
2845 FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()),
2846 SPIRV::FPFastMathMode::None);
// Insert into the map and return a reference to the stored vector.
2847 return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
2848 }
2849
// Return the default-info entry matching Ty's scalar bit width (16/32/64
// -> half/float/double) from the 3-element vector.
// NOTE(review): the signature head (source line 2850,
// getFPFastMathDefaultInfo) and the call head on line 2855 (presumably
// FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex) are
// not visible in this rendering.
2851 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
2852 const Type *Ty) {
2853 size_t BitWidth = Ty->getScalarSizeInBits();
2854 int Index =
2856 BitWidth);
2857 assert(Index >= 0 && Index < 3 &&
2858 "Expected FPFastMathDefaultInfo for half, float, or double");
2859 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2860 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2861 return FPFastMathDefaultInfoVec[Index];
2862 }
2863
// Populate MAI's per-function FPFastMathDefault information from the
// "spirv.ExecutionMode" named metadata when SPV_KHR_float_controls2 is
// available. Handles FPFastMathDefault plus the deprecated ContractionOff
// and SignedZeroInfNanPreserve modes.
// NOTE(review): this rendering drops several lines (the MAI parameter on
// 2865, cast heads on 2886/2895/2919, and the
// getOrCreateFPFastMathDefaultInfoVec / index-computation calls on
// 2899-2900, 2911 and 2924-2926); confirm against the original source.
2864 static void collectFPFastMathDefaults(const Module &M,
2866 const SPIRVSubtarget &ST) {
2867 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2))
2868 return;
2869
2870 // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
2871 // We need the entry point (function) as the key, and the target
2872 // type and flags as the value.
2873 // We also need to check ContractionOff and SignedZeroInfNanPreserve
2874 // execution modes, as they are now deprecated and must be replaced
2875 // with FPFastMathDefaultInfo.
2876 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2877 if (!Node)
2878 return;
2879
2880 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
// Each node is (function, execution-mode-id, mode-specific operands...).
2881 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2882 assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
2883 const Function *F = cast<Function>(
2884 cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue());
2885 const auto EM =
2887 cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue())
2888 ->getZExtValue();
2889 if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
2890 assert(MDN->getNumOperands() == 4 &&
2891 "Expected 4 operands for FPFastMathDefault");
2892
// Operand 2 carries the target FP type; operand 3 the fast-math mask.
2893 const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType();
2894 unsigned Flags =
2896 cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue())
2897 ->getZExtValue();
2898 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2901 getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T);
2902 Info.FastMathFlags = Flags;
2903 Info.FPFastMathDefault = true;
2904 } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
2905 assert(MDN->getNumOperands() == 2 &&
2906 "Expected no operands for ContractionOff");
2907
2908 // We need to save this info for every possible FP type, i.e. {half,
2909 // float, double, fp128}.
2910 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2912 for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
2913 Info.ContractionOff = true;
2914 }
2915 } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
2916 assert(MDN->getNumOperands() == 3 &&
2917 "Expected 1 operand for SignedZeroInfNanPreserve");
// Operand 2 is the bit width selecting which FP type the mode targets.
2918 unsigned TargetWidth =
2920 cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue())
2921 ->getZExtValue();
2922 // We need to save this info only for the FP type with TargetWidth.
2923 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2927 assert(Index >= 0 && Index < 3 &&
2928 "Expected FPFastMathDefaultInfo for half, float, or double");
2929 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2930 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2931 FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
2932 }
2933 }
2934 }
2935
// Declare the analyses this module pass depends on: the target pass
// configuration and machine module info.
// NOTE(review): the enclosing signature line (source line 2936,
// SPIRVModuleAnalysis::getAnalysisUsage) is not visible in this rendering.
2937 AU.addRequired<TargetPassConfig>();
2938 AU.addRequired<MachineModuleInfoWrapperPass>();
2939 }
2940
// Entry point of the analysis: caches subtarget/registry/instr-info,
// rewrites PHIs, emits names/decorations, collects extension and
// capability requirements, and performs the global register numbering.
// Returns false because the pass does not modify the IR-level module.
// NOTE(review): the signature line (2941, SPIRVModuleAnalysis::runOnModule)
// and the MMI initialization (2948) are not visible in this rendering.
2942 SPIRVTargetMachine &TM =
2943 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
2944 ST = TM.getSubtargetImpl();
2945 GR = ST->getSPIRVGlobalRegistry();
2946 TII = ST->getInstrInfo();
2947
2949
2950 setBaseInfo(M);
2951
2952 patchPhis(M, GR, *TII, MMI);
2953
2954 addMBBNames(M, *TII, MMI, *ST, MAI);
2955 collectFPFastMathDefaults(M, MAI, *ST);
2956 addDecorations(M, *TII, MMI, *ST, MAI, GR);
2957
2958 collectReqs(M, MAI, MMI, *ST);
2959
2960 // Process type/const/global var/func decl instructions, number their
2961 // destination registers from 0 to N, collect Extensions and Capabilities.
// NOTE(review): collectReqs appears to be invoked twice (source lines 2958
// and 2962); the second call looks redundant unless requirements depend on
// state produced in between -- confirm intent against the original source.
2962 collectReqs(M, MAI, MMI, *ST);
2963 collectDeclarations(M);
2964
2965 // Number rest of registers from N+1 onwards.
2966 numberRegistersGlobally(M);
2967
2968 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
2969 processOtherInstrs(M);
2970
2971 // If there are no entry points, we need the Linkage capability.
2972 if (MAI.MS[SPIRV::MB_EntryPoints].empty())
2973 MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
2974
2975 // Set maximum ID used.
2976 GR->setBound(MAI.MaxID);
2977
2978 return false;
2979 }
MachineInstrBuilder & UseMI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
ReachingDefInfo InstSet & ToRemove
MachineBasicBlock & MBB
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
#define DEBUG_TYPE
static Register UseReg(const MachineOperand &MO)
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Register Reg
Promote Memory to Register
Definition Mem2Reg.cpp:110
#define T
MachineInstr unsigned OpIdx
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
static SPIRV::FPFastMathDefaultInfoVector & getOrCreateFPFastMathDefaultInfoVec(const Module &M, DenseMap< Function *, SPIRV::FPFastMathDefaultInfoVector > &FPFastMathDefaultInfoMap, Function *F)
static SPIRV::FPFastMathDefaultInfo & getFPFastMathDefaultInfo(SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec, const Type *Ty)
#define ATOM_FLT_REQ_EXT_MSG(ExtName)
static cl::opt< bool > SPVDumpDeps("spv-dump-deps", cl::desc("Dump MIR with SPIR-V dependencies info"), cl::Optional, cl::init(false))
static cl::list< SPIRV::Capability::Capability > AvoidCapabilities("avoid-spirv-capabilities", cl::desc("SPIR-V capabilities to avoid if there are " "other options enabling a feature"), cl::Hidden, cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader", "SPIR-V Shader capability")))
unsigned OpIndex
#define SPIRV_BACKEND_SERVICE_FUN_NAME
Definition SPIRVUtils.h:536
This file contains some templates that are useful if you are working with the STL at all.
#define LLVM_DEBUG(...)
Definition Debug.h:114
Target-Independent Code Generator Pass Configuration Options pass.
The Input class is used to parse a yaml document into in-memory structs and vectors.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
bool isValid() const
Return true if the attribute is any kind of attribute.
Definition Attributes.h:261
This is the shared class of boolean and integer constants.
Definition Constants.h:87
This is an important base class in LLVM.
Definition Constant.h:43
Attribute getFnAttribute(Attribute::AttrKind Kind) const
Return the attribute for the given attribute kind.
Definition Function.cpp:763
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Wrapper class representing physical registers. Should be passed by value.
Definition MCRegister.h:41
constexpr bool isValid() const
Definition MCRegister.h:84
Metadata node.
Definition Metadata.h:1080
const MDOperand & getOperand(unsigned I) const
Definition Metadata.h:1444
unsigned getNumOperands() const
Return number of MDNode operands.
Definition Metadata.h:1450
Tracking metadata reference owned by Metadata.
Definition Metadata.h:902
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
const MachineFunctionProperties & getProperties() const
Get the function properties.
Register getReg(unsigned Idx) const
Get the register for the operand index.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Retuns the total number of operands.
LLVM_ABI const MachineFunction * getMF() const
Return the function that contains the basic block that this instruction belongs to.
const MachineOperand & getOperand(unsigned i) const
This class contains meta information specific to a module.
LLVM_ABI MachineFunction * getMachineFunction(const Function &F) const
Returns the MachineFunction associated to IR function F if there is one, otherwise nullptr.
MachineOperand class - Representation of each machine instruction operand.
unsigned getSubReg() const
int64_t getImm() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
bool isImm() const
isImm - Tests if this is a MO_Immediate operand.
LLVM_ABI void print(raw_ostream &os, const TargetRegisterInfo *TRI=nullptr) const
Print the MachineOperand to os.
MachineInstr * getParent()
getParent - Return the instruction that this operand belongs to.
static MachineOperand CreateImm(int64_t Val)
MachineOperandType getType() const
getType - Returns the MachineOperandType for this operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
const TargetRegisterClass * getRegClass(Register Reg) const
Return the register class of the specified virtual register.
LLVM_ABI MachineInstr * getVRegDef(Register Reg) const
getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...
LLVM_ABI void setRegClass(Register Reg, const TargetRegisterClass *RC)
setRegClass - Set the register class of the specified virtual register.
LLVM_ABI Register createGenericVirtualRegister(LLT Ty, StringRef Name="")
Create and return a new generic virtual register with low-level type Ty.
iterator_range< reg_instr_iterator > reg_instructions(Register Reg) const
iterator_range< use_instr_iterator > use_instructions(Register Reg) const
LLVM_ABI MachineInstr * getUniqueVRegDef(Register Reg) const
getUniqueVRegDef - Return the unique machine instr that defines the specified virtual register or nul...
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Definition Pass.cpp:140
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
constexpr bool isValid() const
Definition Register.h:112
unsigned getScalarOrVectorBitWidth(SPIRVTypeInst Type) const
const Type * getTypeForSPIRVType(SPIRVTypeInst Ty) const
Register getSPIRVTypeID(SPIRVTypeInst SpirvType) const
SPIRVTypeInst getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
bool isConstantInstr(const MachineInstr &MI) const
const SPIRVInstrInfo * getInstrInfo() const override
SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const
const SPIRVSubtarget * getSubtargetImpl() const
bool isTypeIntN(unsigned N=0) const
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition SmallSet.h:134
bool contains(const T &V) const
Check if the SmallSet contains the given element.
Definition SmallSet.h:229
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition SmallSet.h:184
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:290
static LLVM_ABI Type * getDoubleTy(LLVMContext &C)
Definition Type.cpp:291
static LLVM_ABI Type * getFloatTy(LLVMContext &C)
Definition Type.cpp:290
static LLVM_ABI Type * getHalfTy(LLVMContext &C)
Definition Type.cpp:288
Represents a version number in the form major[.minor[.subminor[.build]]].
bool empty() const
Determine whether this version information is empty (e.g., all version components are zero).
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
SmallVector< const MachineInstr * > InstrList
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
Definition Metadata.h:668
NodeAddr< InstrNode * > Instr
Definition RDFGraph.h:389
This is an optimization pass for GlobalISel generic memory operations.
void buildOpName(Register Target, const StringRef &Name, MachineIRBuilder &MIRBuilder)
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
std::string getStringImm(const MachineInstr &MI, unsigned StartIndex)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
hash_code hash_value(const FixedPointSemantics &Val)
ExtensionList getSymbolicOperandExtensions(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
CapabilityList getSymbolicOperandCapabilities(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
SmallVector< SPIRV::Extension::Extension, 8 > ExtensionList
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
SmallVector< size_t > InstrSignature
VersionTuple getSymbolicOperandMaxVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder, SPIRV::Decoration::Decoration Dec, const std::vector< uint32_t > &DecArgs, StringRef StrImm)
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
CapabilityList getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:163
std::string getSymbolicOperandMnemonic(SPIRV::OperandCategory::OperandCategory Category, int32_t Value)
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
DWARFExpression::Operation Op
VersionTuple getSymbolicOperandMinVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1947
SmallVector< SPIRV::Capability::Capability, 8 > CapabilityList
std::set< InstrSignature > InstrTraces
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:592
std::map< SmallVector< size_t >, unsigned > InstrGRegsMap
LLVM_ABI void reportFatalUsageError(Error Err)
Report a fatal error that does not indicate a bug in LLVM.
Definition Error.cpp:177
#define N
SmallSet< SPIRV::Capability::Capability, 4 > S
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
SPIRV::ModuleAnalysisInfo MAI
bool runOnModule(Module &M) override
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
static size_t computeFPFastMathDefaultInfoVecIndex(size_t BitWidth)
Definition SPIRVUtils.h:150
void setSkipEmission(const MachineInstr *MI)
MCRegister getRegisterAlias(const MachineFunction *MF, Register Reg)
MCRegister getOrCreateMBBRegister(const MachineBasicBlock &MBB)
InstrList MS[NUM_MODULE_SECTIONS]
AddressingModel::AddressingModel Addr
void setRegisterAlias(const MachineFunction *MF, Register Reg, MCRegister AliasReg)
DenseMap< const Function *, SPIRV::FPFastMathDefaultInfoVector > FPFastMathDefaultInfoMap
void addCapabilities(const CapabilityList &ToAdd)
bool isCapabilityAvailable(Capability::Capability Cap) const
void checkSatisfiable(const SPIRVSubtarget &ST) const
void getAndAddRequirements(SPIRV::OperandCategory::OperandCategory Category, uint32_t i, const SPIRVSubtarget &ST)
void addExtension(Extension::Extension ToAdd)
void initAvailableCapabilities(const SPIRVSubtarget &ST)
void removeCapabilityIf(const Capability::Capability ToRemove, const Capability::Capability IfPresent)
void addCapability(Capability::Capability ToAdd)
void addAvailableCaps(const CapabilityList &ToAdd)
void addRequirements(const Requirements &Req)
const std::optional< Capability::Capability > Cap