LLVM 22.0.0git
SPIRVModuleAnalysis.cpp
Go to the documentation of this file.
1//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis collects instructions that should be output at the module level
10// and performs the global register numbering.
11//
12// The results of this analysis are used in AsmPrinter to rename registers
13// globally and to output required instructions at the module level.
14//
15//===----------------------------------------------------------------------===//
16
17#include "SPIRVModuleAnalysis.h"
20#include "SPIRV.h"
21#include "SPIRVSubtarget.h"
22#include "SPIRVTargetMachine.h"
23#include "SPIRVUtils.h"
24#include "llvm/ADT/STLExtras.h"
27
28using namespace llvm;
29
30#define DEBUG_TYPE "spirv-module-analysis"
31
32static cl::opt<bool>
33 SPVDumpDeps("spv-dump-deps",
34 cl::desc("Dump MIR with SPIR-V dependencies info"),
35 cl::Optional, cl::init(false));
36
38 AvoidCapabilities("avoid-spirv-capabilities",
39 cl::desc("SPIR-V capabilities to avoid if there are "
40 "other options enabling a feature"),
42 cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
43 "SPIR-V Shader capability")));
44// Use sets instead of cl::list to check "if contains" condition
49
51
// Register the analysis with LLVM's pass registry; the two trailing booleans
// are the macro's CFGOnly and IsAnalysis flags.
INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
54
55// Retrieve an unsigned from an MDNode with a list of them as operands.
56static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
57 unsigned DefaultVal = 0) {
58 if (MdNode && OpIndex < MdNode->getNumOperands()) {
59 const auto &Op = MdNode->getOperand(OpIndex);
60 return mdconst::extract<ConstantInt>(Op)->getZExtValue();
61 }
62 return DefaultVal;
63}
64
66getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
67 unsigned i, const SPIRVSubtarget &ST,
69 // A set of capabilities to avoid if there is another option.
70 AvoidCapabilitiesSet AvoidCaps;
71 if (!ST.isShader())
72 AvoidCaps.S.insert(SPIRV::Capability::Shader);
73 else
74 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
75
76 VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
77 VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
78 VersionTuple SPIRVVersion = ST.getSPIRVVersion();
79 bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
80 bool MaxVerOK =
81 ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
83 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
84 if (ReqCaps.empty()) {
85 if (ReqExts.empty()) {
86 if (MinVerOK && MaxVerOK)
87 return {true, {}, {}, ReqMinVer, ReqMaxVer};
88 return {false, {}, {}, VersionTuple(), VersionTuple()};
89 }
90 } else if (MinVerOK && MaxVerOK) {
91 if (ReqCaps.size() == 1) {
92 auto Cap = ReqCaps[0];
93 if (Reqs.isCapabilityAvailable(Cap)) {
95 SPIRV::OperandCategory::CapabilityOperand, Cap));
96 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
97 }
98 } else {
99 // By SPIR-V specification: "If an instruction, enumerant, or other
100 // feature specifies multiple enabling capabilities, only one such
101 // capability needs to be declared to use the feature." However, one
102 // capability may be preferred over another. We use command line
103 // argument(s) and AvoidCapabilities to avoid selection of certain
104 // capabilities if there are other options.
105 CapabilityList UseCaps;
106 for (auto Cap : ReqCaps)
107 if (Reqs.isCapabilityAvailable(Cap))
108 UseCaps.push_back(Cap);
109 for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
110 auto Cap = UseCaps[i];
111 if (i == Sz - 1 || !AvoidCaps.S.contains(Cap)) {
113 SPIRV::OperandCategory::CapabilityOperand, Cap));
114 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
115 }
116 }
117 }
118 }
119 // If there are no capabilities, or we can't satisfy the version or
120 // capability requirements, use the list of extensions (if the subtarget
121 // can handle them all).
122 if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
123 return ST.canUseExtension(Ext);
124 })) {
125 return {true,
126 {},
127 std::move(ReqExts),
128 VersionTuple(),
129 VersionTuple()}; // TODO: add versions to extensions.
130 }
131 return {false, {}, {}, VersionTuple(), VersionTuple()};
132}
133
// Reset the analysis state and derive module-wide base info (addressing
// model, memory model, source language and its version) from module
// metadata, falling back to subtarget-based defaults.
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  // Clear all state accumulated by a previous run of the analysis.
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    // Explicit metadata wins: operand 0 is the addressing model, operand 1
    // the memory model.
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
                             : SPIRV::MemoryModel::OpenCL;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      // Match the physical addressing model to the target pointer width.
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    // Prevent Major part of OpenCL version to be 0
    MAI.SrcLangVersion =
        (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
  } else {
    // If there is no information about OpenCL version we are forced to generate
    // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
    // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
    // Translator avoids potential issues with run-times in a similar manner.
    if (!ST->isShader()) {
      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
      MAI.SrcLangVersion = 100000;
    } else {
      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
      MAI.SrcLangVersion = 0;
    }
  }

  // Record every extension string listed in "opencl.used.extensions".
  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  if (!ST->isShader()) {
    // TODO: check if it's required by default.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
  }
}
220
221// Appends the signature of the decoration instructions that decorate R to
222// Signature.
223static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
224 InstrSignature &Signature) {
225 for (MachineInstr &UseMI : MRI.use_instructions(R)) {
226 // We don't handle OpDecorateId because getting the register alias for the
227 // ID can cause problems, and we do not need it for now.
228 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
229 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
230 continue;
231
232 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
233 const MachineOperand &MO = UseMI.getOperand(I);
234 if (MO.isReg())
235 continue;
236 Signature.push_back(hash_value(MO));
237 }
238 }
239}
240
241// Returns a representation of an instruction as a vector of MachineOperand
242// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
243// This creates a signature of the instruction with the same content
244// that MachineOperand::isIdenticalTo uses for comparison.
245static InstrSignature instrToSignature(const MachineInstr &MI,
247 bool UseDefReg) {
248 Register DefReg;
249 InstrSignature Signature{MI.getOpcode()};
250 for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
251 // The only decorations that can be applied more than once to a given <id>
252 // or structure member are UserSemantic(5635), CacheControlLoadINTEL (6442),
253 // and CacheControlStoreINTEL (6443). For all the rest of decorations, we
254 // will only add to the signature the Opcode, the id to which it applies,
255 // and the decoration id, disregarding any decoration flags. This will
256 // ensure that any subsequent decoration with the same id will be deemed as
257 // a duplicate. Then, at the call site, we will be able to handle duplicates
258 // in the best way.
259 unsigned Opcode = MI.getOpcode();
260 if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
261 unsigned DecorationID = MI.getOperand(1).getImm();
262 if (DecorationID != SPIRV::Decoration::UserSemantic &&
263 DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
264 DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
265 continue;
266 }
267 const MachineOperand &MO = MI.getOperand(i);
268 size_t h;
269 if (MO.isReg()) {
270 if (!UseDefReg && MO.isDef()) {
271 assert(!DefReg.isValid() && "Multiple def registers.");
272 DefReg = MO.getReg();
273 continue;
274 }
275 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
276 if (!RegAlias.isValid()) {
277 LLVM_DEBUG({
278 dbgs() << "Unexpectedly, no global id found for the operand ";
279 MO.print(dbgs());
280 dbgs() << "\nInstruction: ";
281 MI.print(dbgs());
282 dbgs() << "\n";
283 });
284 report_fatal_error("All v-regs must have been mapped to global id's");
285 }
286 // mimic llvm::hash_value(const MachineOperand &MO)
287 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
288 MO.isDef());
289 } else {
290 h = hash_value(MO);
291 }
292 Signature.push_back(h);
293 }
294
295 if (DefReg.isValid()) {
296 // Decorations change the semantics of the current instruction. So two
297 // identical instruction with different decorations cannot be merged. That
298 // is why we add the decorations to the signature.
299 appendDecorationsForReg(MI.getMF()->getRegInfo(), DefReg, Signature);
300 }
301 return Signature;
302}
303
304bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
305 const MachineInstr &MI) {
306 unsigned Opcode = MI.getOpcode();
307 switch (Opcode) {
308 case SPIRV::OpTypeForwardPointer:
309 // omit now, collect later
310 return false;
311 case SPIRV::OpVariable:
312 return static_cast<SPIRV::StorageClass::StorageClass>(
313 MI.getOperand(2).getImm()) != SPIRV::StorageClass::Function;
314 case SPIRV::OpFunction:
315 case SPIRV::OpFunctionParameter:
316 return true;
317 }
318 if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
319 Register DefReg = MI.getOperand(0).getReg();
320 for (MachineInstr &UseMI : MRI.use_instructions(DefReg)) {
321 if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
322 continue;
323 // it's a dummy definition, FP constant refers to a function,
324 // and this is resolved in another way; let's skip this definition
325 assert(UseMI.getOperand(2).isReg() &&
326 UseMI.getOperand(2).getReg() == DefReg);
327 MAI.setSkipEmission(&MI);
328 return false;
329 }
330 }
331 return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
332 TII->isInlineAsmDefInstr(MI);
333}
334
335// This is a special case of a function pointer refering to a possibly
336// forward function declaration. The operand is a dummy OpUndef that
337// requires a special treatment.
338void SPIRVModuleAnalysis::visitFunPtrUse(
339 Register OpReg, InstrGRegsMap &SignatureToGReg,
340 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
341 const MachineInstr &MI) {
342 const MachineOperand *OpFunDef =
343 GR->getFunctionDefinitionByUse(&MI.getOperand(2));
344 assert(OpFunDef && OpFunDef->isReg());
345 // find the actual function definition and number it globally in advance
346 const MachineInstr *OpDefMI = OpFunDef->getParent();
347 assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
348 const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
349 const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
350 do {
351 visitDecl(FunDefMRI, SignatureToGReg, GlobalToGReg, FunDefMF, *OpDefMI);
352 OpDefMI = OpDefMI->getNextNode();
353 } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
354 OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
355 // associate the function pointer with the newly assigned global number
356 MCRegister GlobalFunDefReg =
357 MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
358 assert(GlobalFunDefReg.isValid() &&
359 "Function definition must refer to a global register");
360 MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg);
361}
362
363// Depth first recursive traversal of dependencies. Repeated visits are guarded
364// by MAI.hasRegisterAlias().
365void SPIRVModuleAnalysis::visitDecl(
366 const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
367 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
368 const MachineInstr &MI) {
369 unsigned Opcode = MI.getOpcode();
370
371 // Process each operand of the instruction to resolve dependencies
372 for (const MachineOperand &MO : MI.operands()) {
373 if (!MO.isReg() || MO.isDef())
374 continue;
375 Register OpReg = MO.getReg();
376 // Handle function pointers special case
377 if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
378 MRI.getRegClass(OpReg) == &SPIRV::pIDRegClass) {
379 visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI);
380 continue;
381 }
382 // Skip already processed instructions
383 if (MAI.hasRegisterAlias(MF, MO.getReg()))
384 continue;
385 // Recursively visit dependencies
386 if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(OpReg)) {
387 if (isDeclSection(MRI, *OpDefMI))
388 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, *OpDefMI);
389 continue;
390 }
391 // Handle the unexpected case of no unique definition for the SPIR-V
392 // instruction
393 LLVM_DEBUG({
394 dbgs() << "Unexpectedly, no unique definition for the operand ";
395 MO.print(dbgs());
396 dbgs() << "\nInstruction: ";
397 MI.print(dbgs());
398 dbgs() << "\n";
399 });
401 "No unique definition is found for the virtual register");
402 }
403
404 MCRegister GReg;
405 bool IsFunDef = false;
406 if (TII->isSpecConstantInstr(MI)) {
407 GReg = MAI.getNextIDRegister();
408 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
409 } else if (Opcode == SPIRV::OpFunction ||
410 Opcode == SPIRV::OpFunctionParameter) {
411 GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
412 } else if (Opcode == SPIRV::OpTypeStruct ||
413 Opcode == SPIRV::OpConstantComposite) {
414 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
415 const MachineInstr *NextInstr = MI.getNextNode();
416 while (NextInstr &&
417 ((Opcode == SPIRV::OpTypeStruct &&
418 NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
419 (Opcode == SPIRV::OpConstantComposite &&
420 NextInstr->getOpcode() ==
421 SPIRV::OpConstantCompositeContinuedINTEL))) {
422 MCRegister Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
423 MAI.setRegisterAlias(MF, NextInstr->getOperand(0).getReg(), Tmp);
424 MAI.setSkipEmission(NextInstr);
425 NextInstr = NextInstr->getNextNode();
426 }
427 } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
428 TII->isInlineAsmDefInstr(MI)) {
429 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
430 } else if (Opcode == SPIRV::OpVariable) {
431 GReg = handleVariable(MF, MI, GlobalToGReg);
432 } else {
433 LLVM_DEBUG({
434 dbgs() << "\nInstruction: ";
435 MI.print(dbgs());
436 dbgs() << "\n";
437 });
438 llvm_unreachable("Unexpected instruction is visited");
439 }
440 MAI.setRegisterAlias(MF, MI.getOperand(0).getReg(), GReg);
441 if (!IsFunDef)
442 MAI.setSkipEmission(&MI);
443}
444
445MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
446 const MachineFunction *MF, const MachineInstr &MI,
447 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
448 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
449 assert(GObj && "Unregistered global definition");
450 const Function *F = dyn_cast<Function>(GObj);
451 if (!F)
452 F = dyn_cast<Argument>(GObj)->getParent();
453 assert(F && "Expected a reference to a function or an argument");
454 IsFunDef = !F->isDeclaration();
455 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
456 if (!Inserted)
457 return It->second;
458 MCRegister GReg = MAI.getNextIDRegister();
459 It->second = GReg;
460 if (!IsFunDef)
461 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI);
462 return GReg;
463}
464
466SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
467 InstrGRegsMap &SignatureToGReg) {
468 InstrSignature MISign = instrToSignature(MI, MAI, false);
469 auto [It, Inserted] = SignatureToGReg.try_emplace(MISign);
470 if (!Inserted)
471 return It->second;
472 MCRegister GReg = MAI.getNextIDRegister();
473 It->second = GReg;
474 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
475 return GReg;
476}
477
478MCRegister SPIRVModuleAnalysis::handleVariable(
479 const MachineFunction *MF, const MachineInstr &MI,
480 std::map<const Value *, unsigned> &GlobalToGReg) {
481 MAI.GlobalVarList.push_back(&MI);
482 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
483 assert(GObj && "Unregistered global definition");
484 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
485 if (!Inserted)
486 return It->second;
487 MCRegister GReg = MAI.getNextIDRegister();
488 It->second = GReg;
489 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
490 return GReg;
491}
492
// Walk every machine function and collect module-level declarations
// (capabilities, extensions, types, constants, global variables, functions)
// into MAI, numbering them globally via visitDecl.
void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
  InstrGRegsMap SignatureToGReg;
  std::map<const Value *, unsigned> GlobalToGReg;
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    const MachineRegisterInfo &MRI = MF->getRegInfo();
    // Header-tracking state: 0 = before the OpFunction header, 1 = inside
    // the header (OpFunction seen, parameters follow), 2 = past the header.
    unsigned PastHeader = 0;
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getNumOperands() == 0)
          continue;
        unsigned Opcode = MI.getOpcode();
        if (Opcode == SPIRV::OpFunction) {
          // The first OpFunction opens the header and is skipped here.
          if (PastHeader == 0) {
            PastHeader = 1;
            continue;
          }
        } else if (Opcode == SPIRV::OpFunctionParameter) {
          // Parameters belonging to the header are skipped.
          if (PastHeader < 2)
            continue;
        } else if (PastHeader > 0) {
          // Any other instruction ends the header.
          PastHeader = 2;
        }

        const MachineOperand &DefMO = MI.getOperand(0);
        switch (Opcode) {
        case SPIRV::OpExtension:
          MAI.Reqs.addExtension(SPIRV::Extension::Extension(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          break;
        case SPIRV::OpCapability:
          MAI.Reqs.addCapability(SPIRV::Capability::Capability(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          if (PastHeader > 0)
            PastHeader = 2;
          break;
        default:
          // Collect any not-yet-numbered declaration-section instruction.
          if (DefMO.isReg() && isDeclSection(MRI, MI) &&
              !MAI.hasRegisterAlias(MF, DefMO.getReg()))
            visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
        }
      }
    }
  }
}
540
541// Look for IDs declared with Import linkage, and map the corresponding function
542// to the register defining that variable (which will usually be the result of
543// an OpFunction). This lets us call externally imported functions using
544// the correct ID registers.
545void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
546 const Function *F) {
547 if (MI.getOpcode() == SPIRV::OpDecorate) {
548 // If it's got Import linkage.
549 auto Dec = MI.getOperand(1).getImm();
550 if (Dec == SPIRV::Decoration::LinkageAttributes) {
551 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
552 if (Lnk == SPIRV::LinkageType::Import) {
553 // Map imported function name to function ID register.
554 const Function *ImportedFunc =
555 F->getParent()->getFunction(getStringImm(MI, 2));
556 Register Target = MI.getOperand(0).getReg();
557 MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
558 }
559 }
560 } else if (MI.getOpcode() == SPIRV::OpFunction) {
561 // Record all internal OpFunction declarations.
562 Register Reg = MI.defs().begin()->getReg();
563 MCRegister GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
564 assert(GlobalReg.isValid());
565 MAI.FuncMap[F] = GlobalReg;
566 }
567}
568
569// Collect the given instruction in the specified MS. We assume global register
570// numbering has already occurred by this point. We can directly compare reg
571// arguments when detecting duplicates.
572static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
574 bool Append = true) {
575 MAI.setSkipEmission(&MI);
576 InstrSignature MISign = instrToSignature(MI, MAI, true);
577 auto FoundMI = IS.insert(std::move(MISign));
578 if (!FoundMI.second) {
579 if (MI.getOpcode() == SPIRV::OpDecorate) {
580 assert(MI.getNumOperands() >= 2 &&
581 "Decoration instructions must have at least 2 operands");
582 assert(MSType == SPIRV::MB_Annotations &&
583 "Only OpDecorate instructions can be duplicates");
584 // For FPFastMathMode decoration, we need to merge the flags of the
585 // duplicate decoration with the original one, so we need to find the
586 // original instruction that has the same signature. For the rest of
587 // instructions, we will simply skip the duplicate.
588 if (MI.getOperand(1).getImm() != SPIRV::Decoration::FPFastMathMode)
589 return; // Skip duplicates of other decorations.
590
591 const SPIRV::InstrList &Decorations = MAI.MS[MSType];
592 for (const MachineInstr *OrigMI : Decorations) {
593 if (instrToSignature(*OrigMI, MAI, true) == MISign) {
594 assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
595 "Original instruction must have the same number of operands");
596 assert(
597 OrigMI->getNumOperands() == 3 &&
598 "FPFastMathMode decoration must have 3 operands for OpDecorate");
599 unsigned OrigFlags = OrigMI->getOperand(2).getImm();
600 unsigned NewFlags = MI.getOperand(2).getImm();
601 if (OrigFlags == NewFlags)
602 return; // No need to merge, the flags are the same.
603
604 // Emit warning about possible conflict between flags.
605 unsigned FinalFlags = OrigFlags | NewFlags;
606 llvm::errs()
607 << "Warning: Conflicting FPFastMathMode decoration flags "
608 "in instruction: "
609 << *OrigMI << "Original flags: " << OrigFlags
610 << ", new flags: " << NewFlags
611 << ". They will be merged on a best effort basis, but not "
612 "validated. Final flags: "
613 << FinalFlags << "\n";
614 MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
615 MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2);
616 OrigFlagsOp = MachineOperand::CreateImm(FinalFlags);
617 return; // Merge done, so we found a duplicate; don't add it to MAI.MS
618 }
619 }
620 assert(false && "No original instruction found for the duplicate "
621 "OpDecorate, but we found one in IS.");
622 }
623 return; // insert failed, so we found a duplicate; don't add it to MAI.MS
624 }
625 // No duplicates, so add it.
626 if (Append)
627 MAI.MS[MSType].push_back(&MI);
628 else
629 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
630}
631
// Some global instructions make reference to function-local ID regs, so cannot
// be correctly collected until these registers are globally numbered.
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  // Signature set shared across all functions so duplicates collapse to one
  // module-level instruction.
  InstrTraces IS;
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if (F->isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);

    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        // Instructions already collected elsewhere are skipped.
        if (MAI.getSkipEmission(&MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        // Dispatch each module-level instruction to its module section.
        if (OpCode == SPIRV::OpString) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
        } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
                   MI.getOperand(2).getImm() ==
                       SPIRV::InstructionSet::
                           NonSemantic_Shader_DebugInfo_100) {
          // Only globally-scoped non-semantic debug-info instructions go to
          // the module-level DI section.
          MachineOperand Ins = MI.getOperand(3);
          namespace NS = SPIRV::NonSemanticExtInst;
          static constexpr int64_t GlobalNonSemanticDITy[] = {
              NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
              NS::DebugTypeBasic, NS::DebugTypePointer};
          bool IsGlobalDI = false;
          for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
            IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
          if (IsGlobalDI)
            collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
        } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
        } else if (TII->isAliasingInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_AliasingInsts, IS);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
          collectFuncNames(MI, &*F);
        } else if (TII->isConstantInstr(MI)) {
          // Now OpSpecConstant*s are not in DT,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, &*F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          // Forward pointers are prepended (Append=false) to their section.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
        }
      }
  }
}
684
685// Number registers in all functions globally from 0 onwards and store
686// the result in global register alias table. Some registers are already
687// numbered.
688void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
689 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
690 if ((*F).isDeclaration())
691 continue;
692 MachineFunction *MF = MMI->getMachineFunction(*F);
693 assert(MF);
694 for (MachineBasicBlock &MBB : *MF) {
695 for (MachineInstr &MI : MBB) {
696 for (MachineOperand &Op : MI.operands()) {
697 if (!Op.isReg())
698 continue;
699 Register Reg = Op.getReg();
700 if (MAI.hasRegisterAlias(MF, Reg))
701 continue;
702 MCRegister NewReg = MAI.getNextIDRegister();
703 MAI.setRegisterAlias(MF, Reg, NewReg);
704 }
705 if (MI.getOpcode() != SPIRV::OpExtInst)
706 continue;
707 auto Set = MI.getOperand(2).getImm();
708 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Set);
709 if (Inserted)
710 It->second = MAI.getNextIDRegister();
711 }
712 }
713 }
714}
715
716// RequirementHandler implementations.
718 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
719 const SPIRVSubtarget &ST) {
720 addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
721}
722
723void SPIRV::RequirementHandler::recursiveAddCapabilities(
724 const CapabilityList &ToPrune) {
725 for (const auto &Cap : ToPrune) {
726 AllCaps.insert(Cap);
727 CapabilityList ImplicitDecls =
728 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
729 recursiveAddCapabilities(ImplicitDecls);
730 }
731}
732
734 for (const auto &Cap : ToAdd) {
735 bool IsNewlyInserted = AllCaps.insert(Cap).second;
736 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
737 continue;
738 CapabilityList ImplicitDecls =
739 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
740 recursiveAddCapabilities(ImplicitDecls);
741 MinimalCaps.push_back(Cap);
742 }
743}
744
746 const SPIRV::Requirements &Req) {
747 if (!Req.IsSatisfiable)
748 report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
749
750 if (Req.Cap.has_value())
751 addCapabilities({Req.Cap.value()});
752
753 addExtensions(Req.Exts);
754
755 if (!Req.MinVer.empty()) {
756 if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
757 LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
758 << " and <= " << MaxVersion << "\n");
759 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
760 }
761
762 if (MinVersion.empty() || Req.MinVer > MinVersion)
763 MinVersion = Req.MinVer;
764 }
765
766 if (!Req.MaxVer.empty()) {
767 if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
768 LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
769 << " and >= " << MinVersion << "\n");
770 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
771 }
772
773 if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
774 MaxVersion = Req.MaxVer;
775 }
776}
777
779 const SPIRVSubtarget &ST) const {
780 // Report as many errors as possible before aborting the compilation.
781 bool IsSatisfiable = true;
782 auto TargetVer = ST.getSPIRVVersion();
783
784 if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
786 dbgs() << "Target SPIR-V version too high for required features\n"
787 << "Required max version: " << MaxVersion << " target version "
788 << TargetVer << "\n");
789 IsSatisfiable = false;
790 }
791
792 if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
793 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
794 << "Required min version: " << MinVersion
795 << " target version " << TargetVer << "\n");
796 IsSatisfiable = false;
797 }
798
799 if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
801 dbgs()
802 << "Version is too low for some features and too high for others.\n"
803 << "Required SPIR-V min version: " << MinVersion
804 << " required SPIR-V max version " << MaxVersion << "\n");
805 IsSatisfiable = false;
806 }
807
808 AvoidCapabilitiesSet AvoidCaps;
809 if (!ST.isShader())
810 AvoidCaps.S.insert(SPIRV::Capability::Shader);
811 else
812 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
813
814 for (auto Cap : MinimalCaps) {
815 if (AvailableCaps.contains(Cap) && !AvoidCaps.S.contains(Cap))
816 continue;
817 LLVM_DEBUG(dbgs() << "Capability not supported: "
819 OperandCategory::CapabilityOperand, Cap)
820 << "\n");
821 IsSatisfiable = false;
822 }
823
824 for (auto Ext : AllExtensions) {
825 if (ST.canUseExtension(Ext))
826 continue;
827 LLVM_DEBUG(dbgs() << "Extension not supported: "
829 OperandCategory::ExtensionOperand, Ext)
830 << "\n");
831 IsSatisfiable = false;
832 }
833
834 if (!IsSatisfiable)
835 report_fatal_error("Unable to meet SPIR-V requirements for this target.");
836}
837
838// Add the given capabilities and all their implicitly defined capabilities too.
840 for (const auto Cap : ToAdd)
841 if (AvailableCaps.insert(Cap).second)
842 addAvailableCaps(getSymbolicOperandCapabilities(
843 SPIRV::OperandCategory::CapabilityOperand, Cap));
844}
845
847 const Capability::Capability ToRemove,
848 const Capability::Capability IfPresent) {
849 if (AllCaps.contains(IfPresent))
850 AllCaps.erase(ToRemove);
851}
852
853namespace llvm {
854namespace SPIRV {
855void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
856 // Provided by both all supported Vulkan versions and OpenCl.
857 addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
858 Capability::Int16});
859
860 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
861 addAvailableCaps({Capability::GroupNonUniform,
862 Capability::GroupNonUniformVote,
863 Capability::GroupNonUniformArithmetic,
864 Capability::GroupNonUniformBallot,
865 Capability::GroupNonUniformClustered,
866 Capability::GroupNonUniformShuffle,
867 Capability::GroupNonUniformShuffleRelative});
868
869 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
870 addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
871 Capability::DotProductInput4x8Bit,
872 Capability::DotProductInput4x8BitPacked,
873 Capability::DemoteToHelperInvocation});
874
875 // Add capabilities enabled by extensions.
876 for (auto Extension : ST.getAllAvailableExtensions()) {
877 CapabilityList EnabledCapabilities =
879 addAvailableCaps(EnabledCapabilities);
880 }
881
882 if (!ST.isShader()) {
883 initAvailableCapabilitiesForOpenCL(ST);
884 return;
885 }
886
887 if (ST.isShader()) {
888 initAvailableCapabilitiesForVulkan(ST);
889 return;
890 }
891
892 report_fatal_error("Unimplemented environment for SPIR-V generation.");
893}
894
895void RequirementHandler::initAvailableCapabilitiesForOpenCL(
896 const SPIRVSubtarget &ST) {
897 // Add the min requirements for different OpenCL and SPIR-V versions.
898 addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
899 Capability::Kernel, Capability::Vector16,
900 Capability::Groups, Capability::GenericPointer,
901 Capability::StorageImageWriteWithoutFormat,
902 Capability::StorageImageReadWithoutFormat});
903 if (ST.hasOpenCLFullProfile())
904 addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
905 if (ST.hasOpenCLImageSupport()) {
906 addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
907 Capability::Image1D, Capability::SampledBuffer,
908 Capability::ImageBuffer});
909 if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
910 addAvailableCaps({Capability::ImageReadWrite});
911 }
912 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
913 ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
914 addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
915 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
916 addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
917 Capability::SignedZeroInfNanPreserve,
918 Capability::RoundingModeRTE,
919 Capability::RoundingModeRTZ});
920 // TODO: verify if this needs some checks.
921 addAvailableCaps({Capability::Float16, Capability::Float64});
922
923 // TODO: add OpenCL extensions.
924}
925
926void RequirementHandler::initAvailableCapabilitiesForVulkan(
927 const SPIRVSubtarget &ST) {
928
929 // Core in Vulkan 1.1 and earlier.
930 addAvailableCaps({Capability::Int64, Capability::Float16, Capability::Float64,
931 Capability::GroupNonUniform, Capability::Image1D,
932 Capability::SampledBuffer, Capability::ImageBuffer,
933 Capability::UniformBufferArrayDynamicIndexing,
934 Capability::SampledImageArrayDynamicIndexing,
935 Capability::StorageBufferArrayDynamicIndexing,
936 Capability::StorageImageArrayDynamicIndexing});
937
938 // Became core in Vulkan 1.2
939 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
941 {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
942 Capability::InputAttachmentArrayDynamicIndexingEXT,
943 Capability::UniformTexelBufferArrayDynamicIndexingEXT,
944 Capability::StorageTexelBufferArrayDynamicIndexingEXT,
945 Capability::UniformBufferArrayNonUniformIndexingEXT,
946 Capability::SampledImageArrayNonUniformIndexingEXT,
947 Capability::StorageBufferArrayNonUniformIndexingEXT,
948 Capability::StorageImageArrayNonUniformIndexingEXT,
949 Capability::InputAttachmentArrayNonUniformIndexingEXT,
950 Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
951 Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
952 }
953
954 // Became core in Vulkan 1.3
955 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
956 addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
957 Capability::StorageImageReadWithoutFormat});
958}
959
960} // namespace SPIRV
961} // namespace llvm
962
963// Add the required capabilities from a decoration instruction (including
964// BuiltIns).
965static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
967 const SPIRVSubtarget &ST) {
968 int64_t DecOp = MI.getOperand(DecIndex).getImm();
969 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
970 Reqs.addRequirements(getSymbolicOperandRequirements(
971 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
972
973 if (Dec == SPIRV::Decoration::BuiltIn) {
974 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
975 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
976 Reqs.addRequirements(getSymbolicOperandRequirements(
977 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
978 } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
979 int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
980 SPIRV::LinkageType::LinkageType LnkType =
981 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
982 if (LnkType == SPIRV::LinkageType::LinkOnceODR)
983 Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
984 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
985 Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
986 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
987 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
988 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
989 } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
990 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
991 Reqs.addExtension(
992 SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
993 } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
994 Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT);
995 } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
996 Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL);
997 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error);
998 } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
999 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
1000 Reqs.addRequirements(SPIRV::Capability::FloatControls2);
1001 Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
1002 }
1003 }
1004}
1005
1006// Add requirements for image handling.
1007static void addOpTypeImageReqs(const MachineInstr &MI,
1009 const SPIRVSubtarget &ST) {
1010 assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
1011 // The operand indices used here are based on the OpTypeImage layout, which
1012 // the MachineInstr follows as well.
1013 int64_t ImgFormatOp = MI.getOperand(7).getImm();
1014 auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
1015 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
1016 ImgFormat, ST);
1017
1018 bool IsArrayed = MI.getOperand(4).getImm() == 1;
1019 bool IsMultisampled = MI.getOperand(5).getImm() == 1;
1020 bool NoSampler = MI.getOperand(6).getImm() == 2;
1021 // Add dimension requirements.
1022 assert(MI.getOperand(2).isImm());
1023 switch (MI.getOperand(2).getImm()) {
1024 case SPIRV::Dim::DIM_1D:
1025 Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
1026 : SPIRV::Capability::Sampled1D);
1027 break;
1028 case SPIRV::Dim::DIM_2D:
1029 if (IsMultisampled && NoSampler)
1030 Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
1031 break;
1032 case SPIRV::Dim::DIM_Cube:
1033 Reqs.addRequirements(SPIRV::Capability::Shader);
1034 if (IsArrayed)
1035 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
1036 : SPIRV::Capability::SampledCubeArray);
1037 break;
1038 case SPIRV::Dim::DIM_Rect:
1039 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
1040 : SPIRV::Capability::SampledRect);
1041 break;
1042 case SPIRV::Dim::DIM_Buffer:
1043 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
1044 : SPIRV::Capability::SampledBuffer);
1045 break;
1046 case SPIRV::Dim::DIM_SubpassData:
1047 Reqs.addRequirements(SPIRV::Capability::InputAttachment);
1048 break;
1049 }
1050
1051 // Has optional access qualifier.
1052 if (!ST.isShader()) {
1053 if (MI.getNumOperands() > 8 &&
1054 MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
1055 Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
1056 else
1057 Reqs.addRequirements(SPIRV::Capability::ImageBasic);
1058 }
1059}
1060
1061// Add requirements for handling atomic float instructions
1062#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
1063 "The atomic float instruction requires the following SPIR-V " \
1064 "extension: SPV_EXT_shader_atomic_float" ExtName
1065static void AddAtomicFloatRequirements(const MachineInstr &MI,
1067 const SPIRVSubtarget &ST) {
1068 assert(MI.getOperand(1).isReg() &&
1069 "Expect register operand in atomic float instruction");
1070 Register TypeReg = MI.getOperand(1).getReg();
1071 SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
1072 if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
1073 report_fatal_error("Result type of an atomic float instruction must be a "
1074 "floating-point type scalar");
1075
1076 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1077 unsigned Op = MI.getOpcode();
1078 if (Op == SPIRV::OpAtomicFAddEXT) {
1079 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
1081 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
1082 switch (BitWidth) {
1083 case 16:
1084 if (!ST.canUseExtension(
1085 SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1086 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
1087 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1088 Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
1089 break;
1090 case 32:
1091 Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
1092 break;
1093 case 64:
1094 Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
1095 break;
1096 default:
1098 "Unexpected floating-point type width in atomic float instruction");
1099 }
1100 } else {
1101 if (!ST.canUseExtension(
1102 SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
1103 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
1104 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
1105 switch (BitWidth) {
1106 case 16:
1107 Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
1108 break;
1109 case 32:
1110 Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
1111 break;
1112 case 64:
1113 Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
1114 break;
1115 default:
1117 "Unexpected floating-point type width in atomic float instruction");
1118 }
1119 }
1120}
1121
1122bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1123 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1124 return false;
1125 uint32_t Dim = ImageInst->getOperand(2).getImm();
1126 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1127 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1128}
1129
1130bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1131 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1132 return false;
1133 uint32_t Dim = ImageInst->getOperand(2).getImm();
1134 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1135 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1136}
1137
1138bool isSampledImage(MachineInstr *ImageInst) {
1139 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1140 return false;
1141 uint32_t Dim = ImageInst->getOperand(2).getImm();
1142 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1143 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1144}
1145
1146bool isInputAttachment(MachineInstr *ImageInst) {
1147 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1148 return false;
1149 uint32_t Dim = ImageInst->getOperand(2).getImm();
1150 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1151 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1152}
1153
1154bool isStorageImage(MachineInstr *ImageInst) {
1155 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1156 return false;
1157 uint32_t Dim = ImageInst->getOperand(2).getImm();
1158 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1159 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1160}
1161
1162bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1163 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1164 return false;
1165
1166 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1167 Register ImageReg = SampledImageInst->getOperand(1).getReg();
1168 auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
1169 return isSampledImage(ImageInst);
1170}
1171
1172bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1173 for (const auto &MI : MRI.reg_instructions(Reg)) {
1174 if (MI.getOpcode() != SPIRV::OpDecorate)
1175 continue;
1176
1177 uint32_t Dec = MI.getOperand(1).getImm();
1178 if (Dec == SPIRV::Decoration::NonUniformEXT)
1179 return true;
1180 }
1181 return false;
1182}
1183
1184void addOpAccessChainReqs(const MachineInstr &Instr,
1186 const SPIRVSubtarget &Subtarget) {
1187 const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
1188 // Get the result type. If it is an image type, then the shader uses
1189 // descriptor indexing. The appropriate capabilities will be added based
1190 // on the specifics of the image.
1191 Register ResTypeReg = Instr.getOperand(1).getReg();
1192 MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);
1193
1194 assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
1195 uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
1196 if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
1197 StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
1198 StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
1199 return;
1200 }
1201
1202 bool IsNonUniform =
1203 hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
1204
1205 auto FirstIndexReg = Instr.getOperand(3).getReg();
1206 bool FirstIndexIsConstant =
1207 Subtarget.getInstrInfo()->isConstantInstr(*MRI.getVRegDef(FirstIndexReg));
1208
1209 if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
1210 if (IsNonUniform)
1211 Handler.addRequirements(
1212 SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
1213 else if (!FirstIndexIsConstant)
1214 Handler.addRequirements(
1215 SPIRV::Capability::StorageBufferArrayDynamicIndexing);
1216 return;
1217 }
1218
1219 Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
1220 MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
1221 if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
1222 PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
1223 PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
1224 return;
1225 }
1226
1227 if (isUniformTexelBuffer(PointeeType)) {
1228 if (IsNonUniform)
1229 Handler.addRequirements(
1230 SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
1231 else if (!FirstIndexIsConstant)
1232 Handler.addRequirements(
1233 SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
1234 } else if (isInputAttachment(PointeeType)) {
1235 if (IsNonUniform)
1236 Handler.addRequirements(
1237 SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
1238 else if (!FirstIndexIsConstant)
1239 Handler.addRequirements(
1240 SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
1241 } else if (isStorageTexelBuffer(PointeeType)) {
1242 if (IsNonUniform)
1243 Handler.addRequirements(
1244 SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
1245 else if (!FirstIndexIsConstant)
1246 Handler.addRequirements(
1247 SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
1248 } else if (isSampledImage(PointeeType) ||
1249 isCombinedImageSampler(PointeeType) ||
1250 PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
1251 if (IsNonUniform)
1252 Handler.addRequirements(
1253 SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
1254 else if (!FirstIndexIsConstant)
1255 Handler.addRequirements(
1256 SPIRV::Capability::SampledImageArrayDynamicIndexing);
1257 } else if (isStorageImage(PointeeType)) {
1258 if (IsNonUniform)
1259 Handler.addRequirements(
1260 SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
1261 else if (!FirstIndexIsConstant)
1262 Handler.addRequirements(
1263 SPIRV::Capability::StorageImageArrayDynamicIndexing);
1264 }
1265}
1266
1267static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
1268 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1269 return false;
1270 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1271 return TypeInst->getOperand(7).getImm() == 0;
1272}
1273
1274static void AddDotProductRequirements(const MachineInstr &MI,
1276 const SPIRVSubtarget &ST) {
1277 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1278 Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1279 Reqs.addCapability(SPIRV::Capability::DotProduct);
1280
1281 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1282 assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1283 // We do not consider what the previous instruction is. This is just used
1284 // to get the input register and to check the type.
1285 const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
1286 assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
1287 Register InputReg = Input->getOperand(1).getReg();
1288
1289 SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
1290 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1291 assert(TypeDef->getOperand(1).getImm() == 32);
1292 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
1293 } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1294 SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1295 assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1296 if (ScalarTypeDef->getOperand(1).getImm() == 8) {
1297 assert(TypeDef->getOperand(2).getImm() == 4 &&
1298 "Dot operand of 8-bit integer type requires 4 components");
1299 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
1300 } else {
1301 Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
1302 }
1303 }
1304}
1305
1306void addPrintfRequirements(const MachineInstr &MI,
1308 const SPIRVSubtarget &ST) {
1309 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1310 const SPIRVType *PtrType = GR->getSPIRVTypeForVReg(MI.getOperand(4).getReg());
1311 if (PtrType) {
1312 MachineOperand ASOp = PtrType->getOperand(1);
1313 if (ASOp.isImm()) {
1314 unsigned AddrSpace = ASOp.getImm();
1315 if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
1316 if (!ST.canUseExtension(
1318 SPV_EXT_relaxed_printf_string_address_space)) {
1319 report_fatal_error("SPV_EXT_relaxed_printf_string_address_space is "
1320 "required because printf uses a format string not "
1321 "in constant address space.",
1322 false);
1323 }
1324 Reqs.addExtension(
1325 SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
1326 }
1327 }
1328 }
1329}
1330
1331static bool isBFloat16Type(const SPIRVType *TypeDef) {
1332 return TypeDef && TypeDef->getNumOperands() == 3 &&
1333 TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1334 TypeDef->getOperand(1).getImm() == 16 &&
1335 TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1336}
1337
1338void addInstrRequirements(const MachineInstr &MI,
1340 const SPIRVSubtarget &ST) {
1341 SPIRV::RequirementHandler &Reqs = MAI.Reqs;
1342 switch (MI.getOpcode()) {
1343 case SPIRV::OpMemoryModel: {
1344 int64_t Addr = MI.getOperand(0).getImm();
1345 Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
1346 Addr, ST);
1347 int64_t Mem = MI.getOperand(1).getImm();
1348 Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
1349 ST);
1350 break;
1351 }
1352 case SPIRV::OpEntryPoint: {
1353 int64_t Exe = MI.getOperand(0).getImm();
1354 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
1355 Exe, ST);
1356 break;
1357 }
1358 case SPIRV::OpExecutionMode:
1359 case SPIRV::OpExecutionModeId: {
1360 int64_t Exe = MI.getOperand(1).getImm();
1361 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
1362 Exe, ST);
1363 break;
1364 }
1365 case SPIRV::OpTypeMatrix:
1366 Reqs.addCapability(SPIRV::Capability::Matrix);
1367 break;
1368 case SPIRV::OpTypeInt: {
1369 unsigned BitWidth = MI.getOperand(1).getImm();
1370 if (BitWidth == 64)
1371 Reqs.addCapability(SPIRV::Capability::Int64);
1372 else if (BitWidth == 16)
1373 Reqs.addCapability(SPIRV::Capability::Int16);
1374 else if (BitWidth == 8)
1375 Reqs.addCapability(SPIRV::Capability::Int8);
1376 break;
1377 }
1378 case SPIRV::OpDot: {
1379 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1380 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1381 if (isBFloat16Type(TypeDef))
1382 Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
1383 break;
1384 }
1385 case SPIRV::OpTypeFloat: {
1386 unsigned BitWidth = MI.getOperand(1).getImm();
1387 if (BitWidth == 64)
1388 Reqs.addCapability(SPIRV::Capability::Float64);
1389 else if (BitWidth == 16) {
1390 if (isBFloat16Type(&MI)) {
1391 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
1392 report_fatal_error("OpTypeFloat type with bfloat requires the "
1393 "following SPIR-V extension: SPV_KHR_bfloat16",
1394 false);
1395 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
1396 Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
1397 } else {
1398 Reqs.addCapability(SPIRV::Capability::Float16);
1399 }
1400 }
1401 break;
1402 }
1403 case SPIRV::OpTypeVector: {
1404 unsigned NumComponents = MI.getOperand(2).getImm();
1405 if (NumComponents == 8 || NumComponents == 16)
1406 Reqs.addCapability(SPIRV::Capability::Vector16);
1407 break;
1408 }
1409 case SPIRV::OpTypePointer: {
1410 auto SC = MI.getOperand(1).getImm();
1411 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
1412 ST);
1413 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1414 // capability.
1415 if (ST.isShader())
1416 break;
1417 assert(MI.getOperand(2).isReg());
1418 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1419 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1420 if ((TypeDef->getNumOperands() == 2) &&
1421 (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1422 (TypeDef->getOperand(1).getImm() == 16))
1423 Reqs.addCapability(SPIRV::Capability::Float16Buffer);
1424 break;
1425 }
1426 case SPIRV::OpExtInst: {
1427 if (MI.getOperand(2).getImm() ==
1428 static_cast<int64_t>(
1429 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1430 Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
1431 break;
1432 }
1433 if (MI.getOperand(3).getImm() ==
1434 static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
1435 addPrintfRequirements(MI, Reqs, ST);
1436 break;
1437 }
1438 // TODO: handle bfloat16 extended instructions when
1439 // SPV_INTEL_bfloat16_arithmetic is enabled.
1440 break;
1441 }
1442 case SPIRV::OpAliasDomainDeclINTEL:
1443 case SPIRV::OpAliasScopeDeclINTEL:
1444 case SPIRV::OpAliasScopeListDeclINTEL: {
1445 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1446 Reqs.addCapability(SPIRV::Capability::MemoryAccessAliasingINTEL);
1447 break;
1448 }
1449 case SPIRV::OpBitReverse:
1450 case SPIRV::OpBitFieldInsert:
1451 case SPIRV::OpBitFieldSExtract:
1452 case SPIRV::OpBitFieldUExtract:
1453 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
1454 Reqs.addCapability(SPIRV::Capability::Shader);
1455 break;
1456 }
1457 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
1458 Reqs.addCapability(SPIRV::Capability::BitInstructions);
1459 break;
1460 case SPIRV::OpTypeRuntimeArray:
1461 Reqs.addCapability(SPIRV::Capability::Shader);
1462 break;
1463 case SPIRV::OpTypeOpaque:
1464 case SPIRV::OpTypeEvent:
1465 Reqs.addCapability(SPIRV::Capability::Kernel);
1466 break;
1467 case SPIRV::OpTypePipe:
1468 case SPIRV::OpTypeReserveId:
1469 Reqs.addCapability(SPIRV::Capability::Pipes);
1470 break;
1471 case SPIRV::OpTypeDeviceEvent:
1472 case SPIRV::OpTypeQueue:
1473 case SPIRV::OpBuildNDRange:
1474 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
1475 break;
1476 case SPIRV::OpDecorate:
1477 case SPIRV::OpDecorateId:
1478 case SPIRV::OpDecorateString:
1479 addOpDecorateReqs(MI, 1, Reqs, ST);
1480 break;
1481 case SPIRV::OpMemberDecorate:
1482 case SPIRV::OpMemberDecorateString:
1483 addOpDecorateReqs(MI, 2, Reqs, ST);
1484 break;
1485 case SPIRV::OpInBoundsPtrAccessChain:
1486 Reqs.addCapability(SPIRV::Capability::Addresses);
1487 break;
1488 case SPIRV::OpConstantSampler:
1489 Reqs.addCapability(SPIRV::Capability::LiteralSampler);
1490 break;
1491 case SPIRV::OpInBoundsAccessChain:
1492 case SPIRV::OpAccessChain:
1493 addOpAccessChainReqs(MI, Reqs, ST);
1494 break;
1495 case SPIRV::OpTypeImage:
1496 addOpTypeImageReqs(MI, Reqs, ST);
1497 break;
1498 case SPIRV::OpTypeSampler:
1499 if (!ST.isShader()) {
1500 Reqs.addCapability(SPIRV::Capability::ImageBasic);
1501 }
1502 break;
1503 case SPIRV::OpTypeForwardPointer:
1504 // TODO: check if it's OpenCL's kernel.
1505 Reqs.addCapability(SPIRV::Capability::Addresses);
1506 break;
1507 case SPIRV::OpAtomicFlagTestAndSet:
1508 case SPIRV::OpAtomicLoad:
1509 case SPIRV::OpAtomicStore:
1510 case SPIRV::OpAtomicExchange:
1511 case SPIRV::OpAtomicCompareExchange:
1512 case SPIRV::OpAtomicIIncrement:
1513 case SPIRV::OpAtomicIDecrement:
1514 case SPIRV::OpAtomicIAdd:
1515 case SPIRV::OpAtomicISub:
1516 case SPIRV::OpAtomicUMin:
1517 case SPIRV::OpAtomicUMax:
1518 case SPIRV::OpAtomicSMin:
1519 case SPIRV::OpAtomicSMax:
1520 case SPIRV::OpAtomicAnd:
1521 case SPIRV::OpAtomicOr:
1522 case SPIRV::OpAtomicXor: {
1523 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1524 const MachineInstr *InstrPtr = &MI;
1525 if (MI.getOpcode() == SPIRV::OpAtomicStore) {
1526 assert(MI.getOperand(3).isReg());
1527 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
1528 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1529 }
1530 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1531 Register TypeReg = InstrPtr->getOperand(1).getReg();
1532 SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
1533 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1534 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1535 if (BitWidth == 64)
1536 Reqs.addCapability(SPIRV::Capability::Int64Atomics);
1537 }
1538 break;
1539 }
1540 case SPIRV::OpGroupNonUniformIAdd:
1541 case SPIRV::OpGroupNonUniformFAdd:
1542 case SPIRV::OpGroupNonUniformIMul:
1543 case SPIRV::OpGroupNonUniformFMul:
1544 case SPIRV::OpGroupNonUniformSMin:
1545 case SPIRV::OpGroupNonUniformUMin:
1546 case SPIRV::OpGroupNonUniformFMin:
1547 case SPIRV::OpGroupNonUniformSMax:
1548 case SPIRV::OpGroupNonUniformUMax:
1549 case SPIRV::OpGroupNonUniformFMax:
1550 case SPIRV::OpGroupNonUniformBitwiseAnd:
1551 case SPIRV::OpGroupNonUniformBitwiseOr:
1552 case SPIRV::OpGroupNonUniformBitwiseXor:
1553 case SPIRV::OpGroupNonUniformLogicalAnd:
1554 case SPIRV::OpGroupNonUniformLogicalOr:
1555 case SPIRV::OpGroupNonUniformLogicalXor: {
1556 assert(MI.getOperand(3).isImm());
1557 int64_t GroupOp = MI.getOperand(3).getImm();
1558 switch (GroupOp) {
1559 case SPIRV::GroupOperation::Reduce:
1560 case SPIRV::GroupOperation::InclusiveScan:
1561 case SPIRV::GroupOperation::ExclusiveScan:
1562 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1563 break;
1564 case SPIRV::GroupOperation::ClusteredReduce:
1565 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1566 break;
1567 case SPIRV::GroupOperation::PartitionedReduceNV:
1568 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1569 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1570 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1571 break;
1572 }
1573 break;
1574 }
1575 case SPIRV::OpGroupNonUniformShuffle:
1576 case SPIRV::OpGroupNonUniformShuffleXor:
1577 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1578 break;
1579 case SPIRV::OpGroupNonUniformShuffleUp:
1580 case SPIRV::OpGroupNonUniformShuffleDown:
1581 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1582 break;
1583 case SPIRV::OpGroupAll:
1584 case SPIRV::OpGroupAny:
1585 case SPIRV::OpGroupBroadcast:
1586 case SPIRV::OpGroupIAdd:
1587 case SPIRV::OpGroupFAdd:
1588 case SPIRV::OpGroupFMin:
1589 case SPIRV::OpGroupUMin:
1590 case SPIRV::OpGroupSMin:
1591 case SPIRV::OpGroupFMax:
1592 case SPIRV::OpGroupUMax:
1593 case SPIRV::OpGroupSMax:
1594 Reqs.addCapability(SPIRV::Capability::Groups);
1595 break;
1596 case SPIRV::OpGroupNonUniformElect:
1597 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1598 break;
1599 case SPIRV::OpGroupNonUniformAll:
1600 case SPIRV::OpGroupNonUniformAny:
1601 case SPIRV::OpGroupNonUniformAllEqual:
1602 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1603 break;
1604 case SPIRV::OpGroupNonUniformBroadcast:
1605 case SPIRV::OpGroupNonUniformBroadcastFirst:
1606 case SPIRV::OpGroupNonUniformBallot:
1607 case SPIRV::OpGroupNonUniformInverseBallot:
1608 case SPIRV::OpGroupNonUniformBallotBitExtract:
1609 case SPIRV::OpGroupNonUniformBallotBitCount:
1610 case SPIRV::OpGroupNonUniformBallotFindLSB:
1611 case SPIRV::OpGroupNonUniformBallotFindMSB:
1612 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1613 break;
1614 case SPIRV::OpSubgroupShuffleINTEL:
1615 case SPIRV::OpSubgroupShuffleDownINTEL:
1616 case SPIRV::OpSubgroupShuffleUpINTEL:
1617 case SPIRV::OpSubgroupShuffleXorINTEL:
1618 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1619 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1620 Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1621 }
1622 break;
1623 case SPIRV::OpSubgroupBlockReadINTEL:
1624 case SPIRV::OpSubgroupBlockWriteINTEL:
1625 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1626 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1627 Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1628 }
1629 break;
1630 case SPIRV::OpSubgroupImageBlockReadINTEL:
1631 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1632 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1633 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1634 Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1635 }
1636 break;
1637 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1638 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1639 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
1640 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
1641 Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1642 }
1643 break;
1644 case SPIRV::OpAssumeTrueKHR:
1645 case SPIRV::OpExpectKHR:
1646 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1647 Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1648 Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1649 }
1650 break;
1651 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1652 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1653 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1654 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1655 Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1656 }
1657 break;
1658 case SPIRV::OpConstantFunctionPointerINTEL:
1659 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1660 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1661 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1662 }
1663 break;
1664 case SPIRV::OpGroupNonUniformRotateKHR:
1665 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1666 report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1667 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1668 false);
1669 Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1670 Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1671 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1672 break;
1673 case SPIRV::OpGroupIMulKHR:
1674 case SPIRV::OpGroupFMulKHR:
1675 case SPIRV::OpGroupBitwiseAndKHR:
1676 case SPIRV::OpGroupBitwiseOrKHR:
1677 case SPIRV::OpGroupBitwiseXorKHR:
1678 case SPIRV::OpGroupLogicalAndKHR:
1679 case SPIRV::OpGroupLogicalOrKHR:
1680 case SPIRV::OpGroupLogicalXorKHR:
1681 if (ST.canUseExtension(
1682 SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1683 Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1684 Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1685 }
1686 break;
1687 case SPIRV::OpReadClockKHR:
1688 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1689 report_fatal_error("OpReadClockKHR instruction requires the "
1690 "following SPIR-V extension: SPV_KHR_shader_clock",
1691 false);
1692 Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1693 Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1694 break;
1695 case SPIRV::OpFunctionPointerCallINTEL:
1696 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1697 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1698 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1699 }
1700 break;
1701 case SPIRV::OpAtomicFAddEXT:
1702 case SPIRV::OpAtomicFMinEXT:
1703 case SPIRV::OpAtomicFMaxEXT:
1704 AddAtomicFloatRequirements(MI, Reqs, ST);
1705 break;
1706 case SPIRV::OpConvertBF16ToFINTEL:
1707 case SPIRV::OpConvertFToBF16INTEL:
1708 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1709 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1710 Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1711 }
1712 break;
1713 case SPIRV::OpRoundFToTF32INTEL:
1714 if (ST.canUseExtension(
1715 SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
1716 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1717 Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
1718 }
1719 break;
1720 case SPIRV::OpVariableLengthArrayINTEL:
1721 case SPIRV::OpSaveMemoryINTEL:
1722 case SPIRV::OpRestoreMemoryINTEL:
1723 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1724 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
1725 Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
1726 }
1727 break;
1728 case SPIRV::OpAsmTargetINTEL:
1729 case SPIRV::OpAsmINTEL:
1730 case SPIRV::OpAsmCallINTEL:
1731 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1732 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
1733 Reqs.addCapability(SPIRV::Capability::AsmINTEL);
1734 }
1735 break;
1736 case SPIRV::OpTypeCooperativeMatrixKHR: {
1737 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1739 "OpTypeCooperativeMatrixKHR type requires the "
1740 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1741 false);
1742 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1743 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1744 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1745 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1746 if (isBFloat16Type(TypeDef))
1747 Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
1748 break;
1749 }
1750 case SPIRV::OpArithmeticFenceEXT:
1751 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
1752 report_fatal_error("OpArithmeticFenceEXT requires the "
1753 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1754 false);
1755 Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
1756 Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
1757 break;
1758 case SPIRV::OpControlBarrierArriveINTEL:
1759 case SPIRV::OpControlBarrierWaitINTEL:
1760 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
1761 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
1762 Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
1763 }
1764 break;
1765 case SPIRV::OpCooperativeMatrixMulAddKHR: {
1766 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1767 report_fatal_error("Cooperative matrix instructions require the "
1768 "following SPIR-V extension: "
1769 "SPV_KHR_cooperative_matrix",
1770 false);
1771 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1772 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1773 constexpr unsigned MulAddMaxSize = 6;
1774 if (MI.getNumOperands() != MulAddMaxSize)
1775 break;
1776 const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
1777 if (CoopOperands &
1778 SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
1779 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1780 report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
1781 "require the following SPIR-V extension: "
1782 "SPV_INTEL_joint_matrix",
1783 false);
1784 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1785 Reqs.addCapability(
1786 SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
1787 }
1788 if (CoopOperands & SPIRV::CooperativeMatrixOperands::
1789 MatrixAAndBBFloat16ComponentsINTEL ||
1790 CoopOperands &
1791 SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
1792 CoopOperands & SPIRV::CooperativeMatrixOperands::
1793 MatrixResultBFloat16ComponentsINTEL) {
1794 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1795 report_fatal_error("***BF16ComponentsINTEL type interpretations "
1796 "require the following SPIR-V extension: "
1797 "SPV_INTEL_joint_matrix",
1798 false);
1799 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1800 Reqs.addCapability(
1801 SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
1802 }
1803 break;
1804 }
1805 case SPIRV::OpCooperativeMatrixLoadKHR:
1806 case SPIRV::OpCooperativeMatrixStoreKHR:
1807 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1808 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1809 case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
1810 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1811 report_fatal_error("Cooperative matrix instructions require the "
1812 "following SPIR-V extension: "
1813 "SPV_KHR_cooperative_matrix",
1814 false);
1815 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1816 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1817
1818 // Check Layout operand in case if it's not a standard one and add the
1819 // appropriate capability.
1820 std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
1821 {SPIRV::OpCooperativeMatrixLoadKHR, 3},
1822 {SPIRV::OpCooperativeMatrixStoreKHR, 2},
1823 {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
1824 {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
1825 {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
1826
1827 const auto OpCode = MI.getOpcode();
1828 const unsigned LayoutNum = LayoutToInstMap[OpCode];
1829 Register RegLayout = MI.getOperand(LayoutNum).getReg();
1830 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1831 MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
1832 if (MILayout->getOpcode() == SPIRV::OpConstantI) {
1833 const unsigned LayoutVal = MILayout->getOperand(2).getImm();
1834 if (LayoutVal ==
1835 static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
1836 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1837 report_fatal_error("PackedINTEL layout require the following SPIR-V "
1838 "extension: SPV_INTEL_joint_matrix",
1839 false);
1840 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1841 Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
1842 }
1843 }
1844
1845 // Nothing to do.
1846 if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
1847 OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
1848 break;
1849
1850 std::string InstName;
1851 switch (OpCode) {
1852 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
1853 InstName = "OpCooperativeMatrixPrefetchINTEL";
1854 break;
1855 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1856 InstName = "OpCooperativeMatrixLoadCheckedINTEL";
1857 break;
1858 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1859 InstName = "OpCooperativeMatrixStoreCheckedINTEL";
1860 break;
1861 }
1862
1863 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
1864 const std::string ErrorMsg =
1865 InstName + " instruction requires the "
1866 "following SPIR-V extension: SPV_INTEL_joint_matrix";
1867 report_fatal_error(ErrorMsg.c_str(), false);
1868 }
1869 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1870 if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
1871 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
1872 break;
1873 }
1874 Reqs.addCapability(
1875 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1876 break;
1877 }
1878 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
1879 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1880 report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
1881 "instructions require the following SPIR-V extension: "
1882 "SPV_INTEL_joint_matrix",
1883 false);
1884 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1885 Reqs.addCapability(
1886 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1887 break;
1888 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
1889 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1890 report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
1891 "following SPIR-V extension: SPV_INTEL_joint_matrix",
1892 false);
1893 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1894 Reqs.addCapability(
1895 SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
1896 break;
1897 case SPIRV::OpConvertHandleToImageINTEL:
1898 case SPIRV::OpConvertHandleToSamplerINTEL:
1899 case SPIRV::OpConvertHandleToSampledImageINTEL: {
1900 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bindless_images))
1901 report_fatal_error("OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
1902 "instructions require the following SPIR-V extension: "
1903 "SPV_INTEL_bindless_images",
1904 false);
1905 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1906 SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
1907 SPIRVType *TyDef = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
1908 if (MI.getOpcode() == SPIRV::OpConvertHandleToImageINTEL &&
1909 TyDef->getOpcode() != SPIRV::OpTypeImage) {
1910 report_fatal_error("Incorrect return type for the instruction "
1911 "OpConvertHandleToImageINTEL",
1912 false);
1913 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSamplerINTEL &&
1914 TyDef->getOpcode() != SPIRV::OpTypeSampler) {
1915 report_fatal_error("Incorrect return type for the instruction "
1916 "OpConvertHandleToSamplerINTEL",
1917 false);
1918 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSampledImageINTEL &&
1919 TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
1920 report_fatal_error("Incorrect return type for the instruction "
1921 "OpConvertHandleToSampledImageINTEL",
1922 false);
1923 }
1924 SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg());
1925 unsigned Bitwidth = GR->getScalarOrVectorBitWidth(SpvTy);
1926 if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
1927 !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
1929 "Parameter value must be a 32-bit scalar in case of "
1930 "Physical32 addressing model or a 64-bit scalar in case of "
1931 "Physical64 addressing model",
1932 false);
1933 }
1934 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bindless_images);
1935 Reqs.addCapability(SPIRV::Capability::BindlessImagesINTEL);
1936 break;
1937 }
1938 case SPIRV::OpSubgroup2DBlockLoadINTEL:
1939 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
1940 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
1941 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
1942 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
1943 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_2d_block_io))
1944 report_fatal_error("OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
1945 "Prefetch/Store]INTEL instructions require the "
1946 "following SPIR-V extension: SPV_INTEL_2d_block_io",
1947 false);
1948 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_2d_block_io);
1949 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockIOINTEL);
1950
1951 const auto OpCode = MI.getOpcode();
1952 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
1953 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
1954 break;
1955 }
1956 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
1957 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransformINTEL);
1958 break;
1959 }
1960 break;
1961 }
1962 case SPIRV::OpKill: {
1963 Reqs.addCapability(SPIRV::Capability::Shader);
1964 } break;
1965 case SPIRV::OpDemoteToHelperInvocation:
1966 Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);
1967
1968 if (ST.canUseExtension(
1969 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
1970 if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
1971 Reqs.addExtension(
1972 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
1973 }
1974 break;
1975 case SPIRV::OpSDot:
1976 case SPIRV::OpUDot:
1977 case SPIRV::OpSUDot:
1978 case SPIRV::OpSDotAccSat:
1979 case SPIRV::OpUDotAccSat:
1980 case SPIRV::OpSUDotAccSat:
1981 AddDotProductRequirements(MI, Reqs, ST);
1982 break;
1983 case SPIRV::OpImageRead: {
1984 Register ImageReg = MI.getOperand(2).getReg();
1985 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
1986 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
1987 // OpImageRead and OpImageWrite can use Unknown Image Formats
1988 // when the Kernel capability is declared. In the OpenCL environment we are
1989 // not allowed to produce
1990 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
1991 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
1992
1993 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
1994 Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
1995 break;
1996 }
1997 case SPIRV::OpImageWrite: {
1998 Register ImageReg = MI.getOperand(0).getReg();
1999 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2000 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
2001 // OpImageRead and OpImageWrite can use Unknown Image Formats
2002 // when the Kernel capability is declared. In the OpenCL environment we are
2003 // not allowed to produce
2004 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2005 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2006
2007 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
2008 Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
2009 break;
2010 }
2011 case SPIRV::OpTypeStructContinuedINTEL:
2012 case SPIRV::OpConstantCompositeContinuedINTEL:
2013 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
2014 case SPIRV::OpCompositeConstructContinuedINTEL: {
2015 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_long_composites))
2017 "Continued instructions require the "
2018 "following SPIR-V extension: SPV_INTEL_long_composites",
2019 false);
2020 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_long_composites);
2021 Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
2022 break;
2023 }
2024 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
2025 if (!ST.canUseExtension(
2026 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
2028 "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
2029 "following SPIR-V "
2030 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
2031 false);
2032 Reqs.addExtension(
2033 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
2034 Reqs.addCapability(
2035 SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
2036 break;
2037 }
2038 case SPIRV::OpBitwiseFunctionINTEL: {
2039 if (!ST.canUseExtension(
2040 SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
2042 "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
2043 "extension: SPV_INTEL_ternary_bitwise_function",
2044 false);
2045 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
2046 Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
2047 break;
2048 }
2049 case SPIRV::OpCopyMemorySized: {
2050 Reqs.addCapability(SPIRV::Capability::Addresses);
2051 // TODO: Add UntypedPointersKHR when implemented.
2052 break;
2053 }
2054 case SPIRV::OpPredicatedLoadINTEL:
2055 case SPIRV::OpPredicatedStoreINTEL: {
2056 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_predicated_io))
2058 "OpPredicated[Load/Store]INTEL instructions require "
2059 "the following SPIR-V extension: SPV_INTEL_predicated_io",
2060 false);
2061 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_predicated_io);
2062 Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
2063 break;
2064 }
2065 case SPIRV::OpFAddS:
2066 case SPIRV::OpFSubS:
2067 case SPIRV::OpFMulS:
2068 case SPIRV::OpFDivS:
2069 case SPIRV::OpFRemS:
2070 case SPIRV::OpFMod:
2071 case SPIRV::OpFNegate:
2072 case SPIRV::OpFAddV:
2073 case SPIRV::OpFSubV:
2074 case SPIRV::OpFMulV:
2075 case SPIRV::OpFDivV:
2076 case SPIRV::OpFRemV:
2077 case SPIRV::OpFNegateV: {
2078 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2079 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
2080 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2081 TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2082 if (isBFloat16Type(TypeDef)) {
2083 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2085 "Arithmetic instructions with bfloat16 arguments require the "
2086 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2087 false);
2088 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2089 Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2090 }
2091 break;
2092 }
2093 case SPIRV::OpOrdered:
2094 case SPIRV::OpUnordered:
2095 case SPIRV::OpFOrdEqual:
2096 case SPIRV::OpFOrdNotEqual:
2097 case SPIRV::OpFOrdLessThan:
2098 case SPIRV::OpFOrdLessThanEqual:
2099 case SPIRV::OpFOrdGreaterThan:
2100 case SPIRV::OpFOrdGreaterThanEqual:
2101 case SPIRV::OpFUnordEqual:
2102 case SPIRV::OpFUnordNotEqual:
2103 case SPIRV::OpFUnordLessThan:
2104 case SPIRV::OpFUnordLessThanEqual:
2105 case SPIRV::OpFUnordGreaterThan:
2106 case SPIRV::OpFUnordGreaterThanEqual: {
2107 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2108 MachineInstr *OperandDef = MRI.getVRegDef(MI.getOperand(2).getReg());
2109 SPIRVType *TypeDef = MRI.getVRegDef(OperandDef->getOperand(1).getReg());
2110 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2111 TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2112 if (isBFloat16Type(TypeDef)) {
2113 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2115 "Relational instructions with bfloat16 arguments require the "
2116 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2117 false);
2118 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2119 Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2120 }
2121 break;
2122 }
2123 default:
2124 break;
2125 }
2126
2127 // If we require capability Shader, then we can remove the requirement for
2128 // the BitInstructions capability, since Shader is a superset capability
2129 // of BitInstructions.
2130 Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
2131 SPIRV::Capability::Shader);
2132}
2133
2134static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
2135 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
2136 // Collect requirements for existing instructions.
2137 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2139 if (!MF)
2140 continue;
2141 for (const MachineBasicBlock &MBB : *MF)
2142 for (const MachineInstr &MI : MBB)
2143 addInstrRequirements(MI, MAI, ST);
2144 }
2145 // Collect requirements for OpExecutionMode instructions.
2146 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2147 if (Node) {
2148 bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
2149 RequireKHRFloatControls2 = false,
2150 VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
2151 bool HasIntelFloatControls2 =
2152 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2153 bool HasKHRFloatControls2 =
2154 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2155 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2156 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2157 const MDOperand &MDOp = MDN->getOperand(1);
2158 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
2159 Constant *C = CMeta->getValue();
2160 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
2161 auto EM = Const->getZExtValue();
2162 // SPV_KHR_float_controls is not available until v1.4:
2163 // add SPV_KHR_float_controls if the version is too low
2164 switch (EM) {
2165 case SPIRV::ExecutionMode::DenormPreserve:
2166 case SPIRV::ExecutionMode::DenormFlushToZero:
2167 case SPIRV::ExecutionMode::RoundingModeRTE:
2168 case SPIRV::ExecutionMode::RoundingModeRTZ:
2169 RequireFloatControls = VerLower14;
2171 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2172 break;
2173 case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
2174 case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
2175 case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
2176 case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
2177 if (HasIntelFloatControls2) {
2178 RequireIntelFloatControls2 = true;
2180 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2181 }
2182 break;
2183 case SPIRV::ExecutionMode::FPFastMathDefault: {
2184 if (HasKHRFloatControls2) {
2185 RequireKHRFloatControls2 = true;
2187 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2188 }
2189 break;
2190 }
2191 case SPIRV::ExecutionMode::ContractionOff:
2192 case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
2193 if (HasKHRFloatControls2) {
2194 RequireKHRFloatControls2 = true;
2196 SPIRV::OperandCategory::ExecutionModeOperand,
2197 SPIRV::ExecutionMode::FPFastMathDefault, ST);
2198 } else {
2200 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2201 }
2202 break;
2203 default:
2205 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2206 }
2207 }
2208 }
2209 }
2210 if (RequireFloatControls &&
2211 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
2212 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
2213 if (RequireIntelFloatControls2)
2214 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2215 if (RequireKHRFloatControls2)
2216 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2217 }
2218 for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
2219 const Function &F = *FI;
2220 if (F.isDeclaration())
2221 continue;
2222 if (F.getMetadata("reqd_work_group_size"))
2224 SPIRV::OperandCategory::ExecutionModeOperand,
2225 SPIRV::ExecutionMode::LocalSize, ST);
2226 if (F.getFnAttribute("hlsl.numthreads").isValid()) {
2228 SPIRV::OperandCategory::ExecutionModeOperand,
2229 SPIRV::ExecutionMode::LocalSize, ST);
2230 }
2231 if (F.getFnAttribute("enable-maximal-reconvergence").getValueAsBool()) {
2232 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence);
2233 }
2234 if (F.getMetadata("work_group_size_hint"))
2236 SPIRV::OperandCategory::ExecutionModeOperand,
2237 SPIRV::ExecutionMode::LocalSizeHint, ST);
2238 if (F.getMetadata("intel_reqd_sub_group_size"))
2240 SPIRV::OperandCategory::ExecutionModeOperand,
2241 SPIRV::ExecutionMode::SubgroupSize, ST);
2242 if (F.getMetadata("max_work_group_size"))
2244 SPIRV::OperandCategory::ExecutionModeOperand,
2245 SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, ST);
2246 if (F.getMetadata("vec_type_hint"))
2248 SPIRV::OperandCategory::ExecutionModeOperand,
2249 SPIRV::ExecutionMode::VecTypeHint, ST);
2250
2251 if (F.hasOptNone()) {
2252 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
2253 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
2254 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
2255 } else if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) {
2256 MAI.Reqs.addExtension(SPIRV::Extension::SPV_EXT_optnone);
2257 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneEXT);
2258 }
2259 }
2260 }
2261}
2262
2263static unsigned getFastMathFlags(const MachineInstr &I,
2264 const SPIRVSubtarget &ST) {
2265 unsigned Flags = SPIRV::FPFastMathMode::None;
2266 bool CanUseKHRFloatControls2 =
2267 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2268 if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
2269 Flags |= SPIRV::FPFastMathMode::NotNaN;
2270 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
2271 Flags |= SPIRV::FPFastMathMode::NotInf;
2272 if (I.getFlag(MachineInstr::MIFlag::FmNsz))
2273 Flags |= SPIRV::FPFastMathMode::NSZ;
2274 if (I.getFlag(MachineInstr::MIFlag::FmArcp))
2275 Flags |= SPIRV::FPFastMathMode::AllowRecip;
2276 if (I.getFlag(MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
2277 Flags |= SPIRV::FPFastMathMode::AllowContract;
2278 if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) {
2279 if (CanUseKHRFloatControls2)
2280 // LLVM reassoc maps to SPIRV transform, see
2281 // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
2282 // Because we are enabling AllowTransform, we must enable AllowReassoc and
2283 // AllowContract too, as required by SPIRV spec. Also, we used to map
2284 // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
2285 // replaced by turning all the other bits instead. Therefore, we're
2286 // enabling every bit here except None and Fast.
2287 Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
2288 SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
2289 SPIRV::FPFastMathMode::AllowTransform |
2290 SPIRV::FPFastMathMode::AllowReassoc |
2291 SPIRV::FPFastMathMode::AllowContract;
2292 else
2293 Flags |= SPIRV::FPFastMathMode::Fast;
2294 }
2295
2296 if (CanUseKHRFloatControls2) {
2297 // Error out if SPIRV::FPFastMathMode::Fast is enabled.
2298 assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
2299 "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
2300 "anymore.");
2301
2302 // Error out if AllowTransform is enabled without AllowReassoc and
2303 // AllowContract.
2304 assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
2305 ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
2306 Flags & SPIRV::FPFastMathMode::AllowContract))) &&
2307 "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
2308 "AllowContract flags to be enabled as well.");
2309 }
2310
2311 return Flags;
2312}
2313
2314static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
2315 if (ST.isKernel())
2316 return true;
2317 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
2318 return false;
2319 return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2320}
2321
2322static void handleMIFlagDecoration(
2323 MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
2325 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
2326 if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
2327 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2328 SPIRV::Decoration::NoSignedWrap, ST, Reqs)
2329 .IsSatisfiable) {
2330 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2331 SPIRV::Decoration::NoSignedWrap, {});
2332 }
2333 if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
2334 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2335 SPIRV::Decoration::NoUnsignedWrap, ST,
2336 Reqs)
2337 .IsSatisfiable) {
2338 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2339 SPIRV::Decoration::NoUnsignedWrap, {});
2340 }
2341 if (!TII.canUseFastMathFlags(
2342 I, ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)))
2343 return;
2344
2345 unsigned FMFlags = getFastMathFlags(I, ST);
2346 if (FMFlags == SPIRV::FPFastMathMode::None) {
2347 // We also need to check if any FPFastMathDefault info was set for the
2348 // types used in this instruction.
2349 if (FPFastMathDefaultInfoVec.empty())
2350 return;
2351
2352 // There are three types of instructions that can use fast math flags:
2353 // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
2354 // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
2355 // 3. Extended instructions (ExtInst)
2356 // For arithmetic instructions, the floating point type can be in the
2357 // result type or in the operands, but they all must be the same.
2358 // For the relational and logical instructions, the floating point type
2359 // can only be in the operands 1 and 2, not the result type. Also, the
2360 // operands must have the same type. For the extended instructions, the
2361 // floating point type can be in the result type or in the operands. It's
2362 // unclear if the operands and the result type must be the same. Let's
2363 // assume they must be. Therefore, for 1. and 2., we can check the first
2364 // operand type, and for 3. we can check the result type.
2365 assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
2366 Register ResReg = I.getOpcode() == SPIRV::OpExtInst
2367 ? I.getOperand(1).getReg()
2368 : I.getOperand(2).getReg();
2369 SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResReg, I.getMF());
2370 const Type *Ty = GR->getTypeForSPIRVType(ResType);
2371 Ty = Ty->isVectorTy() ? cast<VectorType>(Ty)->getElementType() : Ty;
2372
2373 // Match instruction type with the FPFastMathDefaultInfoVec.
2374 bool Emit = false;
2375 for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
2376 if (Ty == Elem.Ty) {
2377 FMFlags = Elem.FastMathFlags;
2378 Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
2379 Elem.FPFastMathDefault;
2380 break;
2381 }
2382 }
2383
2384 if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
2385 return;
2386 }
2387 if (isFastMathModeAvailable(ST)) {
2388 Register DstReg = I.getOperand(0).getReg();
2389 buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode,
2390 {FMFlags});
2391 }
2392}
2393
2394// Walk all functions and add decorations related to MI flags.
2395static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2396 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2398 const SPIRVGlobalRegistry *GR) {
2399 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2401 if (!MF)
2402 continue;
2403
2404 for (auto &MBB : *MF)
2405 for (auto &MI : MBB)
2406 handleMIFlagDecoration(MI, ST, TII, MAI.Reqs, GR,
2407 MAI.FPFastMathDefaultInfoMap[&(*F)]);
2408 }
2409}
2410
// Emit an OpName for every named, non-empty machine basic block so the
// SPIR-V output carries readable labels: each such block gets a fresh
// virtual register in the ID class, an OpName built just before the block's
// last instruction, and a module-global register alias recorded in MAI.
// NOTE(review): numbering gaps 2412 -> 2414, 2414 -> 2416 and 2417 -> 2419
// show that the MAI parameter, the MachineFunction lookup and the
// `MachineRegisterInfo &MRI = MF->getRegInfo();` lines are missing from this
// extraction -- confirm against the original file.
2411 static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2412 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2414 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2416 if (!MF)
2417 continue;
2419 for (auto &MBB : *MF) {
// Only label blocks with a name and at least one instruction --
// std::prev(MBB.end()) below needs a valid insertion point.
2420 if (!MBB.hasName() || MBB.empty())
2421 continue;
2422 // Emit basic block names.
2423 Register Reg = MRI.createGenericVirtualRegister(LLT::scalar(64));
2424 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
2425 buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
// Alias the function-local vreg to the module-global register number
// used for this MBB so AsmPrinter can renumber consistently.
2426 MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2427 MAI.setRegisterAlias(MF, Reg, GlobalReg);
2428 }
2429 }
2430 }
2431
2432 // patching Instruction::PHI to SPIRV::OpPhi
// Rewrites every generic PHI into a SPIR-V OpPhi: the opcode is swapped and
// the result-type id register is inserted as operand #1, right after the
// result register, matching OpPhi's <Result Type> <Result> <operands...>
// layout. The function is then marked as containing no PHIs.
// NOTE(review): numbering gap 2435 -> 2437 -- the
// `MachineFunction *MF = ...` lookup line is missing from this extraction.
2433 static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2434 const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2435 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2437 if (!MF)
2438 continue;
2439 for (auto &MBB : *MF) {
2440 for (MachineInstr &MI : MBB.phis()) {
2441 MI.setDesc(TII.get(SPIRV::OpPhi));
// The PHI's def (operand 0) has a registered SPIR-V type; its type-id
// register becomes the OpPhi Result Type operand.
2442 Register ResTypeReg = GR->getSPIRVTypeID(
2443 GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF));
2444 MI.insert(MI.operands_begin() + 1,
2445 {MachineOperand::CreateReg(ResTypeReg, false)});
2446 }
2447 }
2448
// All PHIs are now OpPhi pseudo-instructions; record the property.
2449 MF->getProperties().setNoPHIs();
2450 }
2451 }
2452
// Returns the per-function vector of FPFastMathDefaultInfo entries from
// MAI.FPFastMathDefaultInfoMap, creating and caching a default-initialized
// one on first use. The vector always holds exactly three entries, ordered
// by target-type bit width: half, float, double, each starting with
// FPFastMathMode::None.
// NOTE(review): numbering gap 2452 -> 2454 -- the function's signature line
// (static ...FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec,
// per the symbol index) is missing from this extraction.
2454 const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
// Fast path: return the cached vector if this function was seen before.
2455 auto it = MAI.FPFastMathDefaultInfoMap.find(F);
2456 if (it != MAI.FPFastMathDefaultInfoMap.end())
2457 return it->second;
2458
2459 // If the map does not contain the entry, create a new one. Initialize it to
2460 // contain all 3 elements sorted by bit width of target type: {half, float,
2461 // double}.
2462 SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
2463 FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()),
2464 SPIRV::FPFastMathMode::None);
2465 FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()),
2466 SPIRV::FPFastMathMode::None);
2467 FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()),
2468 SPIRV::FPFastMathMode::None);
// Insert into the map and return a reference to the stored copy.
2469 return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
2470 }
2471
// Returns the FPFastMathDefaultInfo entry matching Ty's scalar bit width
// from the fixed 3-element {half, float, double} vector.
// NOTE(review): numbering gaps 2471 -> 2473 and 2476 -> 2478 -- the
// signature line (static ...FPFastMathDefaultInfo &getFPFastMathDefaultInfo,
// per the symbol index) and the call that computes Index (presumably
// FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex) are
// missing from this extraction.
2473 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
2474 const Type *Ty) {
// Scalar bit width selects the slot; vectors use their element width.
2475 size_t BitWidth = Ty->getScalarSizeInBits();
2476 int Index =
2478 BitWidth);
2479 assert(Index >= 0 && Index < 3 &&
2480 "Expected FPFastMathDefaultInfo for half, float, or double");
2481 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2482 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2483 return FPFastMathDefaultInfoVec[Index];
2484 }
2485
// Pre-populates MAI.FPFastMathDefaultInfoMap from the module-level
// "spirv.ExecutionMode" named metadata, but only when the
// SPV_KHR_float_controls2 extension may be used. Three execution modes are
// recognized: FPFastMathDefault (per-type fast-math flags), and the
// deprecated ContractionOff (applies to every FP type) and
// SignedZeroInfNanPreserve (applies to the FP type of the given width).
// NOTE(review): numbering gaps (missing lines 2487, 2508, 2517, 2521-2522,
// 2533, 2546-2548) indicate the MAI parameter, the
// mdconst::extract<ConstantInt>(...) calls and the
// getOrCreateFPFastMathDefaultInfoVec(...) calls are absent from this
// extraction -- consult the original file before editing.
2486 static void collectFPFastMathDefaults(const Module &M,
2488 const SPIRVSubtarget &ST) {
// Without the extension these execution modes cannot be emitted; skip.
2489 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2))
2490 return;
2491
2492 // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
2493 // We need the entry point (function) as the key, and the target
2494 // type and flags as the value.
2495 // We also need to check ContractionOff and SignedZeroInfNanPreserve
2496 // execution modes, as they are now deprecated and must be replaced
2497 // with FPFastMathDefaultInfo.
2498 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2499 if (!Node)
2500 return;
2501
// Each operand is an MDNode of the form: {function, mode, [args...]}.
2502 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2503 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2504 assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
2505 const Function *F = cast<Function>(
2506 cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue());
// Operand 1 is the execution-mode enum value.
2507 const auto EM =
2509 cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue())
2510 ->getZExtValue();
2511 if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
2512 assert(MDN->getNumOperands() == 4 &&
2513 "Expected 4 operands for FPFastMathDefault");
2514
// Operand 2 carries the target FP type, operand 3 the fast-math flags.
2515 const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType();
2516 unsigned Flags =
2518 cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue())
2519 ->getZExtValue();
2520 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2523 getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T);
2524 Info.FastMathFlags = Flags;
2525 Info.FPFastMathDefault = true;
2526 } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
2527 assert(MDN->getNumOperands() == 2 &&
2528 "Expected no operands for ContractionOff");
2529
2530 // We need to save this info for every possible FP type, i.e. {half,
2531 // float, double, fp128}.
2532 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2534 for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
2535 Info.ContractionOff = true;
2536 }
2537 } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
2538 assert(MDN->getNumOperands() == 3 &&
2539 "Expected 1 operand for SignedZeroInfNanPreserve");
// Operand 2 is the FP bit width this preservation applies to.
2540 unsigned TargetWidth =
2542 cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue())
2543 ->getZExtValue();
2544 // We need to save this info only for the FP type with TargetWidth.
2545 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2549 assert(Index >= 0 && Index < 3 &&
2550 "Expected FPFastMathDefaultInfo for half, float, or double");
2551 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2552 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2553 FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
2554 }
2555 }
2556 }
2557
2559
// Declares the analyses this pass depends on: TargetPassConfig (to reach the
// SPIRVTargetMachine/subtarget) and MachineModuleInfo (to map IR functions
// to their MachineFunctions).
// NOTE(review): numbering gap 2559 -> 2561 -- the signature line
// (SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const, per the
// symbol index) is missing from this extraction.
2561 AU.addRequired<TargetPassConfig>();
2562 AU.addRequired<MachineModuleInfoWrapperPass>();
2563 }
2564
// Entry point of the module analysis. Caches the subtarget, global registry
// and instruction info, then runs the collection phases in order: base info,
// PHI patching, MBB naming, FP fast-math defaults, MI-flag decorations,
// requirement collection, declaration processing, global register numbering,
// and the remaining instruction walk; finally records the maximum id used.
// NOTE(review): numbering gaps 2564 -> 2566, 2566 -> 2568 and 2571 -> 2573
// -- the runOnModule signature, the getAnalysis<TargetPassConfig>() line and
// the MMI initialization line are missing from this extraction.
2566 SPIRVTargetMachine &TM =
2568 ST = TM.getSubtargetImpl();
2569 GR = ST->getSPIRVGlobalRegistry();
2570 TII = ST->getInstrInfo();
2571
2573
2574 setBaseInfo(M);
2575
2576 patchPhis(M, GR, *TII, MMI);
2577
2578 addMBBNames(M, *TII, MMI, *ST, MAI);
// Runs before addDecorations(), which reads the per-function defaults
// collected here via MAI.FPFastMathDefaultInfoMap.
2579 collectFPFastMathDefaults(M, MAI, *ST);
2580 addDecorations(M, *TII, MMI, *ST, MAI, GR);
2581
// NOTE(review): collectReqs is invoked both here (original line 2582) and
// again below (line 2586) with identical arguments. The second call looks
// redundant -- confirm against the original file whether one should be
// removed or the repetition is intentional.
2582 collectReqs(M, MAI, MMI, *ST);
2583
2584 // Process type/const/global var/func decl instructions, number their
2585 // destination registers from 0 to N, collect Extensions and Capabilities.
2586 collectReqs(M, MAI, MMI, *ST);
2587 collectDeclarations(M);
2588
2589 // Number rest of registers from N+1 onwards.
2590 numberRegistersGlobally(M);
2591
2592 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
2593 processOtherInstrs(M);
2594
2595 // If there are no entry points, we need the Linkage capability.
2596 if (MAI.MS[SPIRV::MB_EntryPoints].empty())
2597 MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
2598
2599 // Set maximum ID used.
2600 GR->setBound(MAI.MaxID);
2601
// A ModulePass analysis: reports no modification to the IR itself.
2602 return false;
2603 }
unsigned const MachineRegisterInfo * MRI
MachineInstrBuilder & UseMI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
aarch64 promote const
ReachingDefInfo InstSet & ToRemove
MachineBasicBlock & MBB
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
Analysis containing CSE Info
Definition CSEInfo.cpp:27
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
#define DEBUG_TYPE
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
Register Reg
Promote Memory to Register
Definition Mem2Reg.cpp:110
#define T
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
static SPIRV::FPFastMathDefaultInfoVector & getOrCreateFPFastMathDefaultInfoVec(const Module &M, DenseMap< Function *, SPIRV::FPFastMathDefaultInfoVector > &FPFastMathDefaultInfoMap, Function *F)
static SPIRV::FPFastMathDefaultInfo & getFPFastMathDefaultInfo(SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec, const Type *Ty)
#define ATOM_FLT_REQ_EXT_MSG(ExtName)
static cl::opt< bool > SPVDumpDeps("spv-dump-deps", cl::desc("Dump MIR with SPIR-V dependencies info"), cl::Optional, cl::init(false))
unsigned unsigned DefaultVal
unsigned OpIndex
static cl::list< SPIRV::Capability::Capability > AvoidCapabilities("avoid-spirv-capabilities", cl::desc("SPIR-V capabilities to avoid if there are " "other options enabling a feature"), cl::ZeroOrMore, cl::Hidden, cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader", "SPIR-V Shader capability")))
This file contains some templates that are useful if you are working with the STL at all.
#define LLVM_DEBUG(...)
Definition Debug.h:114
Target-Independent Code Generator Pass Configuration Options pass.
The Input class is used to parse a yaml document into in-memory structs and vectors.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
This is an important base class in LLVM.
Definition Constant.h:43
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Wrapper class representing physical registers. Should be passed by value.
Definition MCRegister.h:33
constexpr bool isValid() const
Definition MCRegister.h:76
Metadata node.
Definition Metadata.h:1078
const MDOperand & getOperand(unsigned I) const
Definition Metadata.h:1442
unsigned getNumOperands() const
Return number of MDNode operands.
Definition Metadata.h:1448
Tracking metadata reference owned by Metadata.
Definition Metadata.h:900
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
const MachineFunctionProperties & getProperties() const
Get the function properties.
Register getReg(unsigned Idx) const
Get the register for the operand index.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Retuns the total number of operands.
LLVM_ABI const MachineFunction * getMF() const
Return the function that contains the basic block that this instruction belongs to.
const MachineOperand & getOperand(unsigned i) const
This class contains meta information specific to a module.
LLVM_ABI MachineFunction * getMachineFunction(const Function &F) const
Returns the MachineFunction associated to IR function F if there is one, otherwise nullptr.
MachineOperand class - Representation of each machine instruction operand.
unsigned getSubReg() const
int64_t getImm() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
bool isImm() const
isImm - Tests if this is a MO_Immediate operand.
LLVM_ABI void print(raw_ostream &os, const TargetRegisterInfo *TRI=nullptr) const
Print the MachineOperand to os.
MachineInstr * getParent()
getParent - Return the instruction that this operand belongs to.
static MachineOperand CreateImm(int64_t Val)
MachineOperandType getType() const
getType - Returns the MachineOperandType for this operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Definition Pass.cpp:140
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
Wrapper class representing virtual and physical registers.
Definition Register.h:19
constexpr bool isValid() const
Definition Register.h:107
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
const Type * getTypeForSPIRVType(const SPIRVType *Ty) const
Register getSPIRVTypeID(const SPIRVType *SpirvType) const
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const
bool isConstantInstr(const MachineInstr &MI) const
const SPIRVInstrInfo * getInstrInfo() const override
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition SmallSet.h:133
bool contains(const T &V) const
Check if the SmallSet contains the given element.
Definition SmallSet.h:228
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition SmallSet.h:183
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:273
static LLVM_ABI Type * getDoubleTy(LLVMContext &C)
Definition Type.cpp:286
static LLVM_ABI Type * getFloatTy(LLVMContext &C)
Definition Type.cpp:285
static LLVM_ABI Type * getHalfTy(LLVMContext &C)
Definition Type.cpp:283
Represents a version number in the form major[.minor[.subminor[.build]]].
bool empty() const
Determine whether this version information is empty (e.g., all version components are zero).
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
SmallVector< const MachineInstr * > InstrList
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
Definition Metadata.h:667
NodeAddr< InstrNode * > Instr
Definition RDFGraph.h:389
This is an optimization pass for GlobalISel generic memory operations.
void buildOpName(Register Target, const StringRef &Name, MachineIRBuilder &MIRBuilder)
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
std::string getStringImm(const MachineInstr &MI, unsigned StartIndex)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1725
hash_code hash_value(const FixedPointSemantics &Val)
ExtensionList getSymbolicOperandExtensions(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
CapabilityList getSymbolicOperandCapabilities(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
SmallVector< SPIRV::Extension::Extension, 8 > ExtensionList
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
SmallVector< size_t > InstrSignature
VersionTuple getSymbolicOperandMaxVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder, SPIRV::Decoration::Decoration Dec, const std::vector< uint32_t > &DecArgs, StringRef StrImm)
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
CapabilityList getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:167
const MachineInstr SPIRVType
std::string getSymbolicOperandMnemonic(SPIRV::OperandCategory::OperandCategory Category, int32_t Value)
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
DWARFExpression::Operation Op
VersionTuple getSymbolicOperandMinVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
SmallVector< SPIRV::Capability::Capability, 8 > CapabilityList
std::set< InstrSignature > InstrTraces
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:592
std::map< SmallVector< size_t >, unsigned > InstrGRegsMap
#define N
SmallSet< SPIRV::Capability::Capability, 4 > S
static struct SPIRV::ModuleAnalysisInfo MAI
bool runOnModule(Module &M) override
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
static size_t computeFPFastMathDefaultInfoVecIndex(size_t BitWidth)
Definition SPIRVUtils.h:146
void setSkipEmission(const MachineInstr *MI)
MCRegister getRegisterAlias(const MachineFunction *MF, Register Reg)
MCRegister getOrCreateMBBRegister(const MachineBasicBlock &MBB)
InstrList MS[NUM_MODULE_SECTIONS]
AddressingModel::AddressingModel Addr
void setRegisterAlias(const MachineFunction *MF, Register Reg, MCRegister AliasReg)
DenseMap< const Function *, SPIRV::FPFastMathDefaultInfoVector > FPFastMathDefaultInfoMap
void addCapabilities(const CapabilityList &ToAdd)
bool isCapabilityAvailable(Capability::Capability Cap) const
void checkSatisfiable(const SPIRVSubtarget &ST) const
void getAndAddRequirements(SPIRV::OperandCategory::OperandCategory Category, uint32_t i, const SPIRVSubtarget &ST)
void addExtension(Extension::Extension ToAdd)
void initAvailableCapabilities(const SPIRVSubtarget &ST)
void removeCapabilityIf(const Capability::Capability ToRemove, const Capability::Capability IfPresent)
void addCapability(Capability::Capability ToAdd)
void addAvailableCaps(const CapabilityList &ToAdd)
void addRequirements(const Requirements &Req)
const std::optional< Capability::Capability > Cap