// LLVM 22.0.0git — SPIRVModuleAnalysis.cpp
// (Captured from the generated doxygen source listing; the numbers embedded
// at the start of each line below are the original file's line numbers.)
1//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis collects instructions that should be output at the module level
10// and performs the global register numbering.
11//
12// The results of this analysis are used in AsmPrinter to rename registers
13// globally and to output required instructions at the module level.
14//
15//===----------------------------------------------------------------------===//
16
17#include "SPIRVModuleAnalysis.h"
20#include "SPIRV.h"
21#include "SPIRVSubtarget.h"
22#include "SPIRVTargetMachine.h"
23#include "SPIRVUtils.h"
24#include "llvm/ADT/STLExtras.h"
27
28using namespace llvm;
29
30#define DEBUG_TYPE "spirv-module-analysis"
31
// Debug aid: when set, dump the MIR annotated with SPIR-V dependency info.
32static cl::opt<bool>
33 SPVDumpDeps("spv-dump-deps",
34 cl::desc("Dump MIR with SPIR-V dependencies info"),
35 cl::Optional, cl::init(false));
36
// NOTE(review): the declaration head of this option (original line 37) is
// missing from this captured view — presumably a cl::list of capabilities.
38 AvoidCapabilities("avoid-spirv-capabilities",
39 cl::desc("SPIR-V capabilities to avoid if there are "
40 "other options enabling a feature"),
// NOTE(review): original line 41 (additional cl:: modifiers) is missing from
// this captured view.
42 cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
43 "SPIR-V Shader capability")));
44// Use sets instead of cl::list to check "if contains" condition
// NOTE(review): original lines 45-51 (the set wrapper built from the
// AvoidCapabilities option) are missing from this captured view.
49
51
// Register the module-analysis pass with LLVM's pass machinery.
52INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
53 true)
54
55// Retrieve an unsigned from an MDNode with a list of them as operands.
56static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
57 unsigned DefaultVal = 0) {
58 if (MdNode && OpIndex < MdNode->getNumOperands()) {
59 const auto &Op = MdNode->getOperand(OpIndex);
60 return mdconst::extract<ConstantInt>(Op)->getZExtValue();
61 }
62 return DefaultVal;
63}
64
// Computes the requirement set (capabilities, extensions, SPIR-V version
// bounds) for symbolic operand `i` of the given operand Category on
// subtarget ST, preferring capabilities not in the avoid-set. Returns an
// unsatisfiable result when versions/capabilities cannot be met.
// NOTE(review): the return-type line (original 65) and the trailing `Reqs`
// parameter line (original 68) are missing from this captured view.
66getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
67 unsigned i, const SPIRVSubtarget &ST,
69 // A set of capabilities to avoid if there is another option.
70 AvoidCapabilitiesSet AvoidCaps;
71 if (!ST.isShader())
72 AvoidCaps.S.insert(SPIRV::Capability::Shader);
73 else
74 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
75
76 VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
77 VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
78 VersionTuple SPIRVVersion = ST.getSPIRVVersion();
// An empty target version means "no constraint", so either bound passes.
79 bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
80 bool MaxVerOK =
81 ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
// NOTE(review): the line declaring ReqCaps (original 82) is missing from
// this captured view.
83 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
84 if (ReqCaps.empty()) {
85 if (ReqExts.empty()) {
86 if (MinVerOK && MaxVerOK)
87 return {true, {}, {}, ReqMinVer, ReqMaxVer};
88 return {false, {}, {}, VersionTuple(), VersionTuple()};
89 }
90 } else if (MinVerOK && MaxVerOK) {
91 if (ReqCaps.size() == 1) {
92 auto Cap = ReqCaps[0];
93 if (Reqs.isCapabilityAvailable(Cap)) {
// NOTE(review): a line is missing here (original 94), apparently appending
// the capability's own extensions to ReqExts.
95 SPIRV::OperandCategory::CapabilityOperand, Cap));
96 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
97 }
98 } else {
99 // By SPIR-V specification: "If an instruction, enumerant, or other
100 // feature specifies multiple enabling capabilities, only one such
101 // capability needs to be declared to use the feature." However, one
102 // capability may be preferred over another. We use command line
103 // argument(s) and AvoidCapabilities to avoid selection of certain
104 // capabilities if there are other options.
105 CapabilityList UseCaps;
106 for (auto Cap : ReqCaps)
107 if (Reqs.isCapabilityAvailable(Cap))
108 UseCaps.push_back(Cap);
// Prefer the first non-avoided available capability; the last entry is
// taken unconditionally as a fallback.
109 for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
110 auto Cap = UseCaps[i];
111 if (i == Sz - 1 || !AvoidCaps.S.contains(Cap)) {
// NOTE(review): a line is missing here (original 112), mirroring the
// single-capability case above.
113 SPIRV::OperandCategory::CapabilityOperand, Cap));
114 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
115 }
116 }
117 }
118 }
119 // If there are no capabilities, or we can't satisfy the version or
120 // capability requirements, use the list of extensions (if the subtarget
121 // can handle them all).
122 if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
123 return ST.canUseExtension(Ext);
124 })) {
125 return {true,
126 {},
127 std::move(ReqExts),
128 VersionTuple(),
129 VersionTuple()}; // TODO: add versions to extensions.
130 }
131 return {false, {}, {}, VersionTuple(), VersionTuple()};
132}
133
// Resets all per-module analysis state in MAI and seeds the module-level
// facts: addressing/memory model, source language and version, used source
// extensions, and the requirements they imply.
134void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
135 MAI.MaxID = 0;
136 for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
137 MAI.MS[i].clear();
138 MAI.RegisterAliasTable.clear();
139 MAI.InstrsToDelete.clear();
140 MAI.FuncMap.clear();
141 MAI.GlobalVarList.clear();
142 MAI.ExtInstSetMap.clear();
143 MAI.Reqs.clear();
144 MAI.Reqs.initAvailableCapabilities(*ST);
145
146 // TODO: determine memory model and source language from the configuration.
147 if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
// Explicit module metadata wins: operand 0 is the addressing model,
// operand 1 the memory model.
148 auto MemMD = MemModel->getOperand(0);
149 MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
150 getMetadataUInt(MemMD, 0));
151 MAI.Mem =
152 static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
153 } else {
154 // TODO: Add support for VulkanMemoryModel.
155 MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
156 : SPIRV::MemoryModel::OpenCL;
157 if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
158 unsigned PtrSize = ST->getPointerSize();
159 MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
160 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
161 : SPIRV::AddressingModel::Logical;
162 } else {
163 // TODO: Add support for PhysicalStorageBufferAddress.
164 MAI.Addr = SPIRV::AddressingModel::Logical;
165 }
166 }
167 // Get the OpenCL version number from metadata.
168 // TODO: support other source languages.
169 if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
170 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
171 // Construct version literal in accordance with SPIRV-LLVM-Translator.
172 // TODO: support multiple OCL version metadata.
173 assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
174 auto VersionMD = VerNode->getOperand(0);
175 unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
176 unsigned MinorNum = getMetadataUInt(VersionMD, 1);
177 unsigned RevNum = getMetadataUInt(VersionMD, 2);
178 // Prevent Major part of OpenCL version to be 0
179 MAI.SrcLangVersion =
180 (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
181 } else {
182 // If there is no information about OpenCL version we are forced to generate
183 // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
184 // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
185 // Translator avoids potential issues with run-times in a similar manner.
186 if (!ST->isShader()) {
187 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
188 MAI.SrcLangVersion = 100000;
189 } else {
190 MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
191 MAI.SrcLangVersion = 0;
192 }
193 }
194
// Collect the names of all OpenCL extensions the source module declares.
195 if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
196 for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
197 MDNode *MD = ExtNode->getOperand(I);
198 if (!MD || MD->getNumOperands() == 0)
199 continue;
200 for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
201 MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
202 }
203 }
204
205 // Update required capabilities for this memory model, addressing model and
206 // source language.
207 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
208 MAI.Mem, *ST);
209 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
210 MAI.SrcLang, *ST);
211 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
212 MAI.Addr, *ST);
213
// The OpenCL environment always pre-registers the OpenCL.std ext-inst set.
214 if (!ST->isShader()) {
215 // TODO: check if it's required by default.
216 MAI.ExtInstSetMap[static_cast<unsigned>(
217 SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
218 }
219}
220
221// Appends the signature of the decoration instructions that decorate R to
222// Signature.
223static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
224 InstrSignature &Signature) {
225 for (MachineInstr &UseMI : MRI.use_instructions(R)) {
226 // We don't handle OpDecorateId because getting the register alias for the
227 // ID can cause problems, and we do not need it for now.
228 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
229 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
230 continue;
231
232 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
233 const MachineOperand &MO = UseMI.getOperand(I);
234 if (MO.isReg())
235 continue;
236 Signature.push_back(hash_value(MO));
237 }
238 }
239}
240
241// Returns a representation of an instruction as a vector of MachineOperand
242// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
243// This creates a signature of the instruction with the same content
244// that MachineOperand::isIdenticalTo uses for comparison.
// NOTE(review): a parameter line (original 246, presumably the
// ModuleAnalysisInfo argument `MAI`) is missing from this captured view.
245static InstrSignature instrToSignature(const MachineInstr &MI,
247 bool UseDefReg) {
248 Register DefReg;
249 InstrSignature Signature{MI.getOpcode()};
250 for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
251 // The only decorations that can be applied more than once to a given <id>
252 // or structure member are UserSemantic(5635), CacheControlLoadINTEL (6442),
253 // and CacheControlStoreINTEL (6443). For all the rest of decorations, we
254 // will only add to the signature the Opcode, the id to which it applies,
255 // and the decoration id, disregarding any decoration flags. This will
256 // ensure that any subsequent decoration with the same id will be deemed as
257 // a duplicate. Then, at the call site, we will be able to handle duplicates
258 // in the best way.
259 unsigned Opcode = MI.getOpcode();
260 if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
261 unsigned DecorationID = MI.getOperand(1).getImm();
262 if (DecorationID != SPIRV::Decoration::UserSemantic &&
263 DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
264 DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
265 continue;
266 }
267 const MachineOperand &MO = MI.getOperand(i);
268 size_t h;
269 if (MO.isReg()) {
// When UseDefReg is false the def operand is excluded from the signature
// and remembered so its decorations can be appended below.
270 if (!UseDefReg && MO.isDef()) {
271 assert(!DefReg.isValid() && "Multiple def registers.");
272 DefReg = MO.getReg();
273 continue;
274 }
// Registers are hashed by their global alias, not the local vreg number.
275 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
276 if (!RegAlias.isValid()) {
277 LLVM_DEBUG({
278 dbgs() << "Unexpectedly, no global id found for the operand ";
279 MO.print(dbgs());
280 dbgs() << "\nInstruction: ";
281 MI.print(dbgs());
282 dbgs() << "\n";
283 });
284 report_fatal_error("All v-regs must have been mapped to global id's");
285 }
286 // mimic llvm::hash_value(const MachineOperand &MO)
287 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
288 MO.isDef());
289 } else {
290 h = hash_value(MO);
291 }
292 Signature.push_back(h);
293 }
294
295 if (DefReg.isValid()) {
296 // Decorations change the semantics of the current instruction. So two
297 // identical instruction with different decorations cannot be merged. That
298 // is why we add the decorations to the signature.
299 appendDecorationsForReg(MI.getMF()->getRegInfo(), DefReg, Signature);
300 }
301 return Signature;
302}
303
304bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
305 const MachineInstr &MI) {
306 unsigned Opcode = MI.getOpcode();
307 switch (Opcode) {
308 case SPIRV::OpTypeForwardPointer:
309 // omit now, collect later
310 return false;
311 case SPIRV::OpVariable:
312 return static_cast<SPIRV::StorageClass::StorageClass>(
313 MI.getOperand(2).getImm()) != SPIRV::StorageClass::Function;
314 case SPIRV::OpFunction:
315 case SPIRV::OpFunctionParameter:
316 return true;
317 }
318 if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
319 Register DefReg = MI.getOperand(0).getReg();
320 for (MachineInstr &UseMI : MRI.use_instructions(DefReg)) {
321 if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
322 continue;
323 // it's a dummy definition, FP constant refers to a function,
324 // and this is resolved in another way; let's skip this definition
325 assert(UseMI.getOperand(2).isReg() &&
326 UseMI.getOperand(2).getReg() == DefReg);
327 MAI.setSkipEmission(&MI);
328 return false;
329 }
330 }
331 return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
332 TII->isInlineAsmDefInstr(MI);
333}
334
335// This is a special case of a function pointer referring to a possibly
336// forward function declaration. The operand is a dummy OpUndef that
337// requires a special treatment.
338void SPIRVModuleAnalysis::visitFunPtrUse(
339 Register OpReg, InstrGRegsMap &SignatureToGReg,
340 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
341 const MachineInstr &MI) {
// Resolve the dummy OpUndef use back to the actual function definition.
342 const MachineOperand *OpFunDef =
343 GR->getFunctionDefinitionByUse(&MI.getOperand(2));
344 assert(OpFunDef && OpFunDef->isReg());
345 // find the actual function definition and number it globally in advance
346 const MachineInstr *OpDefMI = OpFunDef->getParent();
347 assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
348 const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
349 const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
// Number the whole function header: the OpFunction plus the run of
// OpFunctionParameter instructions that immediately follows it.
350 do {
351 visitDecl(FunDefMRI, SignatureToGReg, GlobalToGReg, FunDefMF, *OpDefMI);
352 OpDefMI = OpDefMI->getNextNode();
353 } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
354 OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
355 // associate the function pointer with the newly assigned global number
356 MCRegister GlobalFunDefReg =
357 MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
358 assert(GlobalFunDefReg.isValid() &&
359 "Function definition must refer to a global register");
360 MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg);
361}
362
363// Depth first recursive traversal of dependencies. Repeated visits are guarded
364// by MAI.hasRegisterAlias().
365void SPIRVModuleAnalysis::visitDecl(
366 const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
367 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
368 const MachineInstr &MI) {
369 unsigned Opcode = MI.getOpcode();
370
371 // Process each operand of the instruction to resolve dependencies
372 for (const MachineOperand &MO : MI.operands()) {
373 if (!MO.isReg() || MO.isDef())
374 continue;
375 Register OpReg = MO.getReg();
376 // Handle function pointers special case
377 if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
378 MRI.getRegClass(OpReg) == &SPIRV::pIDRegClass) {
379 visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI);
380 continue;
381 }
382 // Skip already processed instructions
383 if (MAI.hasRegisterAlias(MF, MO.getReg()))
384 continue;
385 // Recursively visit dependencies
386 if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(OpReg)) {
387 if (isDeclSection(MRI, *OpDefMI))
388 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, *OpDefMI);
389 continue;
390 }
391 // Handle the unexpected case of no unique definition for the SPIR-V
392 // instruction
393 LLVM_DEBUG({
394 dbgs() << "Unexpectedly, no unique definition for the operand ";
395 MO.print(dbgs());
396 dbgs() << "\nInstruction: ";
397 MI.print(dbgs());
398 dbgs() << "\n";
399 });
// NOTE(review): a line is missing here (original 400) — evidently the
// opening of a report_fatal_error( call completed by the next line.
401 "No unique definition is found for the virtual register");
402 }
403
// All dependencies are numbered; now assign this instruction's global id.
404 MCRegister GReg;
405 bool IsFunDef = false;
406 if (TII->isSpecConstantInstr(MI)) {
// Spec constants are never deduplicated: each gets a fresh id.
407 GReg = MAI.getNextIDRegister();
408 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
409 } else if (Opcode == SPIRV::OpFunction ||
410 Opcode == SPIRV::OpFunctionParameter) {
411 GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
412 } else if (Opcode == SPIRV::OpTypeStruct ||
413 Opcode == SPIRV::OpConstantComposite) {
414 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
// "Continued" INTEL forms directly follow their parent instruction and are
// numbered here so they can be skipped during normal emission.
415 const MachineInstr *NextInstr = MI.getNextNode();
416 while (NextInstr &&
417 ((Opcode == SPIRV::OpTypeStruct &&
418 NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
419 (Opcode == SPIRV::OpConstantComposite &&
420 NextInstr->getOpcode() ==
421 SPIRV::OpConstantCompositeContinuedINTEL))) {
422 MCRegister Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
423 MAI.setRegisterAlias(MF, NextInstr->getOperand(0).getReg(), Tmp);
424 MAI.setSkipEmission(NextInstr);
425 NextInstr = NextInstr->getNextNode();
426 }
427 } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
428 TII->isInlineAsmDefInstr(MI)) {
429 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
430 } else if (Opcode == SPIRV::OpVariable) {
431 GReg = handleVariable(MF, MI, GlobalToGReg);
432 } else {
433 LLVM_DEBUG({
434 dbgs() << "\nInstruction: ";
435 MI.print(dbgs());
436 dbgs() << "\n";
437 });
438 llvm_unreachable("Unexpected instruction is visited");
439 }
440 MAI.setRegisterAlias(MF, MI.getOperand(0).getReg(), GReg);
// Function definitions still get emitted in place; everything else moves to
// a module-level section and is skipped during per-function emission.
441 if (!IsFunDef)
442 MAI.setSkipEmission(&MI);
443}
444
445MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
446 const MachineFunction *MF, const MachineInstr &MI,
447 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
448 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
449 assert(GObj && "Unregistered global definition");
450 const Function *F = dyn_cast<Function>(GObj);
451 if (!F)
452 F = dyn_cast<Argument>(GObj)->getParent();
453 assert(F && "Expected a reference to a function or an argument");
454 IsFunDef = !F->isDeclaration();
455 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
456 if (!Inserted)
457 return It->second;
458 MCRegister GReg = MAI.getNextIDRegister();
459 It->second = GReg;
460 if (!IsFunDef)
461 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI);
462 return GReg;
463}
464
// Numbers a type/constant/inline-asm declaration, deduplicating through the
// instruction signature so identical declarations share one global id.
// NOTE(review): the MCRegister return-type line (original 465) is missing
// from this captured view.
466SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
467 InstrGRegsMap &SignatureToGReg) {
// Signature excludes the def register so structurally equal decls collide.
468 InstrSignature MISign = instrToSignature(MI, MAI, false);
469 auto [It, Inserted] = SignatureToGReg.try_emplace(MISign);
470 if (!Inserted)
471 return It->second;
472 MCRegister GReg = MAI.getNextIDRegister();
473 It->second = GReg;
474 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
475 return GReg;
476}
477
478MCRegister SPIRVModuleAnalysis::handleVariable(
479 const MachineFunction *MF, const MachineInstr &MI,
480 std::map<const Value *, unsigned> &GlobalToGReg) {
481 MAI.GlobalVarList.push_back(&MI);
482 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
483 assert(GObj && "Unregistered global definition");
484 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
485 if (!Inserted)
486 return It->second;
487 MCRegister GReg = MAI.getNextIDRegister();
488 It->second = GReg;
489 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
490 return GReg;
491}
492
// Walks every machine function and globally numbers all declaration-section
// instructions (types, constants, globals, function headers), folding
// OpExtension/OpCapability operands into the requirements on the way.
493void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
494 InstrGRegsMap SignatureToGReg;
495 std::map<const Value *, unsigned> GlobalToGReg;
496 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
497 MachineFunction *MF = MMI->getMachineFunction(*F);
498 if (!MF)
499 continue;
500 const MachineRegisterInfo &MRI = MF->getRegInfo();
// PastHeader state machine: 0 = before the first OpFunction, 1 = inside the
// function header, 2 = past the header. The first OpFunction and the header
// parameters are not re-collected here.
501 unsigned PastHeader = 0;
502 for (MachineBasicBlock &MBB : *MF) {
503 for (MachineInstr &MI : MBB) {
504 if (MI.getNumOperands() == 0)
505 continue;
506 unsigned Opcode = MI.getOpcode();
507 if (Opcode == SPIRV::OpFunction) {
508 if (PastHeader == 0) {
509 PastHeader = 1;
510 continue;
511 }
512 } else if (Opcode == SPIRV::OpFunctionParameter) {
513 if (PastHeader < 2)
514 continue;
515 } else if (PastHeader > 0) {
516 PastHeader = 2;
517 }
518
519 const MachineOperand &DefMO = MI.getOperand(0);
520 switch (Opcode) {
521 case SPIRV::OpExtension:
522 MAI.Reqs.addExtension(SPIRV::Extension::Extension(DefMO.getImm()));
523 MAI.setSkipEmission(&MI);
524 break;
525 case SPIRV::OpCapability:
526 MAI.Reqs.addCapability(SPIRV::Capability::Capability(DefMO.getImm()));
527 MAI.setSkipEmission(&MI);
528 if (PastHeader > 0)
529 PastHeader = 2;
530 break;
531 default:
// Only not-yet-numbered declaration-section defs are visited.
532 if (DefMO.isReg() && isDeclSection(MRI, MI) &&
533 !MAI.hasRegisterAlias(MF, DefMO.getReg()))
534 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
535 }
536 }
537 }
538 }
539}
540
541// Look for IDs declared with Import linkage, and map the corresponding function
542// to the register defining that variable (which will usually be the result of
543// an OpFunction). This lets us call externally imported functions using
544// the correct ID registers.
545void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
546 const Function *F) {
547 if (MI.getOpcode() == SPIRV::OpDecorate) {
548 // If it's got Import linkage.
549 auto Dec = MI.getOperand(1).getImm();
550 if (Dec == SPIRV::Decoration::LinkageAttributes) {
551 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
552 if (Lnk == SPIRV::LinkageType::Import) {
553 // Map imported function name to function ID register.
554 const Function *ImportedFunc =
555 F->getParent()->getFunction(getStringImm(MI, 2));
556 Register Target = MI.getOperand(0).getReg();
557 MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
558 }
559 }
560 } else if (MI.getOpcode() == SPIRV::OpFunction) {
561 // Record all internal OpFunction declarations.
562 Register Reg = MI.defs().begin()->getReg();
563 MCRegister GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
564 assert(GlobalReg.isValid());
565 MAI.FuncMap[F] = GlobalReg;
566 }
567}
568
569// Collect the given instruction in the specified MS. We assume global register
570// numbering has already occurred by this point. We can directly compare reg
571// arguments when detecting duplicates.
572static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
// NOTE(review): a parameter line is missing here (original 573), carrying
// the module-section type (MSType) and the signature set (IS).
574 bool Append = true) {
575 MAI.setSkipEmission(&MI);
// Signature includes the def register (UseDefReg=true): duplicates must
// match on the exact global ids they reference.
576 InstrSignature MISign = instrToSignature(MI, MAI, true);
577 auto FoundMI = IS.insert(std::move(MISign));
578 if (!FoundMI.second) {
579 if (MI.getOpcode() == SPIRV::OpDecorate) {
580 assert(MI.getNumOperands() >= 2 &&
581 "Decoration instructions must have at least 2 operands");
582 assert(MSType == SPIRV::MB_Annotations &&
583 "Only OpDecorate instructions can be duplicates");
584 // For FPFastMathMode decoration, we need to merge the flags of the
585 // duplicate decoration with the original one, so we need to find the
586 // original instruction that has the same signature. For the rest of
587 // instructions, we will simply skip the duplicate.
588 if (MI.getOperand(1).getImm() != SPIRV::Decoration::FPFastMathMode)
589 return; // Skip duplicates of other decorations.
590
591 const SPIRV::InstrList &Decorations = MAI.MS[MSType];
592 for (const MachineInstr *OrigMI : Decorations) {
593 if (instrToSignature(*OrigMI, MAI, true) == MISign) {
594 assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
595 "Original instruction must have the same number of operands");
596 assert(
597 OrigMI->getNumOperands() == 3 &&
598 "FPFastMathMode decoration must have 3 operands for OpDecorate");
599 unsigned OrigFlags = OrigMI->getOperand(2).getImm();
600 unsigned NewFlags = MI.getOperand(2).getImm();
601 if (OrigFlags == NewFlags)
602 return; // No need to merge, the flags are the same.
603
604 // Emit warning about possible conflict between flags.
605 unsigned FinalFlags = OrigFlags | NewFlags;
606 llvm::errs()
607 << "Warning: Conflicting FPFastMathMode decoration flags "
608 "in instruction: "
609 << *OrigMI << "Original flags: " << OrigFlags
610 << ", new flags: " << NewFlags
611 << ". They will be merged on a best effort basis, but not "
612 "validated. Final flags: "
613 << FinalFlags << "\n";
// The already-collected decoration is patched in place with the merged
// flags, so the duplicate can simply be dropped.
614 MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
615 MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2);
616 OrigFlagsOp = MachineOperand::CreateImm(FinalFlags);
617 return; // Merge done, so we found a duplicate; don't add it to MAI.MS
618 }
619 }
620 assert(false && "No original instruction found for the duplicate "
621 "OpDecorate, but we found one in IS.");
622 }
623 return; // insert failed, so we found a duplicate; don't add it to MAI.MS
624 }
625 // No duplicates, so add it.
626 if (Append)
627 MAI.MS[MSType].push_back(&MI);
628 else
629 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
630}
631
632// Some global instructions make reference to function-local ID regs, so cannot
633// be correctly collected until these registers are globally numbered.
634void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
// IS deduplicates across all functions in the module.
635 InstrTraces IS;
636 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
637 if (F->isDeclaration())
638 continue;
639 MachineFunction *MF = MMI->getMachineFunction(*F);
640 assert(MF);
641
642 for (MachineBasicBlock &MBB : *MF)
643 for (MachineInstr &MI : MBB) {
644 if (MAI.getSkipEmission(&MI))
645 continue;
// Dispatch each instruction to the module section it belongs to. The order
// of these tests matters (e.g. decorations are matched before constants).
646 const unsigned OpCode = MI.getOpcode();
647 if (OpCode == SPIRV::OpString) {
648 collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
649 } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
650 MI.getOperand(2).getImm() ==
651 SPIRV::InstructionSet::
652 NonSemantic_Shader_DebugInfo_100) {
// Only a fixed subset of non-semantic debug-info instructions is global.
653 MachineOperand Ins = MI.getOperand(3);
654 namespace NS = SPIRV::NonSemanticExtInst;
655 static constexpr int64_t GlobalNonSemanticDITy[] = {
656 NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
657 NS::DebugTypeBasic, NS::DebugTypePointer};
658 bool IsGlobalDI = false;
659 for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
660 IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
661 if (IsGlobalDI)
662 collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
663 } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
664 collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
665 } else if (OpCode == SPIRV::OpEntryPoint) {
666 collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
667 } else if (TII->isAliasingInstr(MI)) {
668 collectOtherInstr(MI, MAI, SPIRV::MB_AliasingInsts, IS);
669 } else if (TII->isDecorationInstr(MI)) {
670 collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
671 collectFuncNames(MI, &*F);
672 } else if (TII->isConstantInstr(MI)) {
673 // Now OpSpecConstant*s are not in DT,
674 // but they need to be collected anyway.
675 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
676 } else if (OpCode == SPIRV::OpFunction) {
677 collectFuncNames(MI, &*F);
678 } else if (OpCode == SPIRV::OpTypeForwardPointer) {
// Forward pointers are prepended (Append=false) so they precede the types
// that reference them in the type/const section.
679 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
680 }
681 }
682 }
683}
684
685// Number registers in all functions globally from 0 onwards and store
686// the result in global register alias table. Some registers are already
687// numbered.
688void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
689 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
690 if ((*F).isDeclaration())
691 continue;
692 MachineFunction *MF = MMI->getMachineFunction(*F);
693 assert(MF);
694 for (MachineBasicBlock &MBB : *MF) {
695 for (MachineInstr &MI : MBB) {
696 for (MachineOperand &Op : MI.operands()) {
697 if (!Op.isReg())
698 continue;
699 Register Reg = Op.getReg();
700 if (MAI.hasRegisterAlias(MF, Reg))
701 continue;
702 MCRegister NewReg = MAI.getNextIDRegister();
703 MAI.setRegisterAlias(MF, Reg, NewReg);
704 }
705 if (MI.getOpcode() != SPIRV::OpExtInst)
706 continue;
707 auto Set = MI.getOperand(2).getImm();
708 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Set);
709 if (Inserted)
710 It->second = MAI.getNextIDRegister();
711 }
712 }
713 }
714}
715
716// RequirementHandler implementations.
// Resolves the requirements of symbolic operand `i` in Category against
// subtarget ST and folds them into this handler.
// NOTE(review): the function signature line (original 717, presumably
// `void SPIRV::RequirementHandler::getAndAddRequirements(`) is missing from
// this captured view.
718 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
719 const SPIRVSubtarget &ST) {
720 addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
721}
722
723void SPIRV::RequirementHandler::recursiveAddCapabilities(
724 const CapabilityList &ToPrune) {
725 for (const auto &Cap : ToPrune) {
726 AllCaps.insert(Cap);
727 CapabilityList ImplicitDecls =
728 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
729 recursiveAddCapabilities(ImplicitDecls);
730 }
731}
732
// Adds each capability (plus its implicit declarations) to AllCaps, and
// records capabilities seen for the first time in MinimalCaps.
// NOTE(review): the function signature line (original 733) is missing from
// this captured view.
734 for (const auto &Cap : ToAdd) {
735 bool IsNewlyInserted = AllCaps.insert(Cap).second;
736 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
737 continue;
// Implicitly declared capabilities go to AllCaps only, not MinimalCaps.
738 CapabilityList ImplicitDecls =
739 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
740 recursiveAddCapabilities(ImplicitDecls);
741 MinimalCaps.push_back(Cap);
742 }
743}
744
// Folds one Requirements record into the handler: its capability, its
// extensions, and its min/max SPIR-V version bounds, aborting when the
// bounds contradict what has already been collected.
// NOTE(review): the function signature line (original 745) is missing from
// this captured view.
746 const SPIRV::Requirements &Req) {
747 if (!Req.IsSatisfiable)
748 report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
749
750 if (Req.Cap.has_value())
751 addCapabilities({Req.Cap.value()});
752
753 addExtensions(Req.Exts);
754
755 if (!Req.MinVer.empty()) {
756 if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
757 LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
758 << " and <= " << MaxVersion << "\n");
759 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
760 }
761
// Tighten the lower bound.
762 if (MinVersion.empty() || Req.MinVer > MinVersion)
763 MinVersion = Req.MinVer;
764 }
765
766 if (!Req.MaxVer.empty()) {
767 if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
768 LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
769 << " and >= " << MinVersion << "\n");
770 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
771 }
772
// Tighten the upper bound.
773 if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
774 MaxVersion = Req.MaxVer;
775 }
776}
777
// Verifies that the collected requirements (version window, minimal
// capabilities, extensions) are all supported by subtarget ST; reports every
// violation it finds, then aborts if any was found.
// NOTE(review): the function signature line (original 778) is missing from
// this captured view.
779 const SPIRVSubtarget &ST) const {
780 // Report as many errors as possible before aborting the compilation.
781 bool IsSatisfiable = true;
782 auto TargetVer = ST.getSPIRVVersion();
783
784 if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
// NOTE(review): an LLVM_DEBUG( opener line (original 785) is missing here.
786 dbgs() << "Target SPIR-V version too high for required features\n"
787 << "Required max version: " << MaxVersion << " target version "
788 << TargetVer << "\n");
789 IsSatisfiable = false;
790 }
791
792 if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
793 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
794 << "Required min version: " << MinVersion
795 << " target version " << TargetVer << "\n");
796 IsSatisfiable = false;
797 }
798
799 if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
// NOTE(review): an LLVM_DEBUG( opener line (original 800) is missing here.
801 dbgs()
802 << "Version is too low for some features and too high for others.\n"
803 << "Required SPIR-V min version: " << MinVersion
804 << " required SPIR-V max version " << MaxVersion << "\n");
805 IsSatisfiable = false;
806 }
807
// Mirror of the avoid-set used when requirements were computed.
808 AvoidCapabilitiesSet AvoidCaps;
809 if (!ST.isShader())
810 AvoidCaps.S.insert(SPIRV::Capability::Shader);
811 else
812 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
813
814 for (auto Cap : MinimalCaps) {
815 if (AvailableCaps.contains(Cap) && !AvoidCaps.S.contains(Cap))
816 continue;
817 LLVM_DEBUG(dbgs() << "Capability not supported: "
// NOTE(review): a continuation line (original 818) is missing here,
// apparently streaming the capability's mnemonic.
819 OperandCategory::CapabilityOperand, Cap)
820 << "\n");
821 IsSatisfiable = false;
822 }
823
824 for (auto Ext : AllExtensions) {
825 if (ST.canUseExtension(Ext))
826 continue;
827 LLVM_DEBUG(dbgs() << "Extension not supported: "
// NOTE(review): a continuation line (original 828) is missing here,
// apparently streaming the extension's mnemonic.
829 OperandCategory::ExtensionOperand, Ext)
830 << "\n");
831 IsSatisfiable = false;
832 }
833
834 if (!IsSatisfiable)
835 report_fatal_error("Unable to meet SPIR-V requirements for this target.");
836}
837
838// Add the given capabilities and all their implicitly defined capabilities too.
// NOTE(review): the function signature line (original 839, presumably
// `void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList
// &ToAdd) {`) is missing from this captured view.
840 for (const auto Cap : ToAdd)
841 if (AvailableCaps.insert(Cap).second)
842 addAvailableCaps(getSymbolicOperandCapabilities(
843 SPIRV::OperandCategory::CapabilityOperand, Cap));
844}
845
// Drops ToRemove from AllCaps, but only when IfPresent has been declared.
// NOTE(review): the function signature line (original 846) is missing from
// this captured view.
847 const Capability::Capability ToRemove,
848 const Capability::Capability IfPresent) {
849 if (AllCaps.contains(IfPresent))
850 AllCaps.erase(ToRemove);
851}
852
853namespace llvm {
854namespace SPIRV {
855void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
856 // Provided by both all supported Vulkan versions and OpenCl.
857 addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
858 Capability::Int16});
859
860 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
861 addAvailableCaps({Capability::GroupNonUniform,
862 Capability::GroupNonUniformVote,
863 Capability::GroupNonUniformArithmetic,
864 Capability::GroupNonUniformBallot,
865 Capability::GroupNonUniformClustered,
866 Capability::GroupNonUniformShuffle,
867 Capability::GroupNonUniformShuffleRelative});
868
869 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
870 addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
871 Capability::DotProductInput4x8Bit,
872 Capability::DotProductInput4x8BitPacked,
873 Capability::DemoteToHelperInvocation});
874
875 // Add capabilities enabled by extensions.
876 for (auto Extension : ST.getAllAvailableExtensions()) {
877 CapabilityList EnabledCapabilities =
879 addAvailableCaps(EnabledCapabilities);
880 }
881
882 if (!ST.isShader()) {
883 initAvailableCapabilitiesForOpenCL(ST);
884 return;
885 }
886
887 if (ST.isShader()) {
888 initAvailableCapabilitiesForVulkan(ST);
889 return;
890 }
891
892 report_fatal_error("Unimplemented environment for SPIR-V generation.");
893}
894
895void RequirementHandler::initAvailableCapabilitiesForOpenCL(
896 const SPIRVSubtarget &ST) {
897 // Add the min requirements for different OpenCL and SPIR-V versions.
898 addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
899 Capability::Kernel, Capability::Vector16,
900 Capability::Groups, Capability::GenericPointer,
901 Capability::StorageImageWriteWithoutFormat,
902 Capability::StorageImageReadWithoutFormat});
903 if (ST.hasOpenCLFullProfile())
904 addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
905 if (ST.hasOpenCLImageSupport()) {
906 addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
907 Capability::Image1D, Capability::SampledBuffer,
908 Capability::ImageBuffer});
909 if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
910 addAvailableCaps({Capability::ImageReadWrite});
911 }
912 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
913 ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
914 addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
915 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
916 addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
917 Capability::SignedZeroInfNanPreserve,
918 Capability::RoundingModeRTE,
919 Capability::RoundingModeRTZ});
920 // TODO: verify if this needs some checks.
921 addAvailableCaps({Capability::Float16, Capability::Float64});
922
923 // TODO: add OpenCL extensions.
924}
925
926void RequirementHandler::initAvailableCapabilitiesForVulkan(
927 const SPIRVSubtarget &ST) {
928
929 // Core in Vulkan 1.1 and earlier.
930 addAvailableCaps({Capability::Int64, Capability::Float16, Capability::Float64,
931 Capability::GroupNonUniform, Capability::Image1D,
932 Capability::SampledBuffer, Capability::ImageBuffer,
933 Capability::UniformBufferArrayDynamicIndexing,
934 Capability::SampledImageArrayDynamicIndexing,
935 Capability::StorageBufferArrayDynamicIndexing,
936 Capability::StorageImageArrayDynamicIndexing});
937
938 // Became core in Vulkan 1.2
939 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
941 {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
942 Capability::InputAttachmentArrayDynamicIndexingEXT,
943 Capability::UniformTexelBufferArrayDynamicIndexingEXT,
944 Capability::StorageTexelBufferArrayDynamicIndexingEXT,
945 Capability::UniformBufferArrayNonUniformIndexingEXT,
946 Capability::SampledImageArrayNonUniformIndexingEXT,
947 Capability::StorageBufferArrayNonUniformIndexingEXT,
948 Capability::StorageImageArrayNonUniformIndexingEXT,
949 Capability::InputAttachmentArrayNonUniformIndexingEXT,
950 Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
951 Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
952 }
953
954 // Became core in Vulkan 1.3
955 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
956 addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
957 Capability::StorageImageReadWithoutFormat});
958}
959
960} // namespace SPIRV
961} // namespace llvm
962
963// Add the required capabilities from a decoration instruction (including
964// BuiltIns).
965static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
967 const SPIRVSubtarget &ST) {
968 int64_t DecOp = MI.getOperand(DecIndex).getImm();
969 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
970 Reqs.addRequirements(getSymbolicOperandRequirements(
971 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
972
973 if (Dec == SPIRV::Decoration::BuiltIn) {
974 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
975 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
976 Reqs.addRequirements(getSymbolicOperandRequirements(
977 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
978 } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
979 int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
980 SPIRV::LinkageType::LinkageType LnkType =
981 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
982 if (LnkType == SPIRV::LinkageType::LinkOnceODR)
983 Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
984 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
985 Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
986 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
987 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
988 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
989 } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
990 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
991 Reqs.addExtension(
992 SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
993 } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
994 Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT);
995 } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
996 Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL);
997 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error);
998 } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
999 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
1000 Reqs.addRequirements(SPIRV::Capability::FloatControls2);
1001 Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
1002 }
1003 }
1004}
1005
1006// Add requirements for image handling.
1007static void addOpTypeImageReqs(const MachineInstr &MI,
1009 const SPIRVSubtarget &ST) {
1010 assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
1011 // The operand indices used here are based on the OpTypeImage layout, which
1012 // the MachineInstr follows as well.
1013 int64_t ImgFormatOp = MI.getOperand(7).getImm();
1014 auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
1015 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
1016 ImgFormat, ST);
1017
1018 bool IsArrayed = MI.getOperand(4).getImm() == 1;
1019 bool IsMultisampled = MI.getOperand(5).getImm() == 1;
1020 bool NoSampler = MI.getOperand(6).getImm() == 2;
1021 // Add dimension requirements.
1022 assert(MI.getOperand(2).isImm());
1023 switch (MI.getOperand(2).getImm()) {
1024 case SPIRV::Dim::DIM_1D:
1025 Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
1026 : SPIRV::Capability::Sampled1D);
1027 break;
1028 case SPIRV::Dim::DIM_2D:
1029 if (IsMultisampled && NoSampler)
1030 Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
1031 break;
1032 case SPIRV::Dim::DIM_Cube:
1033 Reqs.addRequirements(SPIRV::Capability::Shader);
1034 if (IsArrayed)
1035 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
1036 : SPIRV::Capability::SampledCubeArray);
1037 break;
1038 case SPIRV::Dim::DIM_Rect:
1039 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
1040 : SPIRV::Capability::SampledRect);
1041 break;
1042 case SPIRV::Dim::DIM_Buffer:
1043 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
1044 : SPIRV::Capability::SampledBuffer);
1045 break;
1046 case SPIRV::Dim::DIM_SubpassData:
1047 Reqs.addRequirements(SPIRV::Capability::InputAttachment);
1048 break;
1049 }
1050
1051 // Has optional access qualifier.
1052 if (!ST.isShader()) {
1053 if (MI.getNumOperands() > 8 &&
1054 MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
1055 Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
1056 else
1057 Reqs.addRequirements(SPIRV::Capability::ImageBasic);
1058 }
1059}
1060
1061// Add requirements for handling atomic float instructions
1062#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
1063 "The atomic float instruction requires the following SPIR-V " \
1064 "extension: SPV_EXT_shader_atomic_float" ExtName
1065static void AddAtomicFloatRequirements(const MachineInstr &MI,
1067 const SPIRVSubtarget &ST) {
1068 assert(MI.getOperand(1).isReg() &&
1069 "Expect register operand in atomic float instruction");
1070 Register TypeReg = MI.getOperand(1).getReg();
1071 SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
1072 if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
1073 report_fatal_error("Result type of an atomic float instruction must be a "
1074 "floating-point type scalar");
1075
1076 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1077 unsigned Op = MI.getOpcode();
1078 if (Op == SPIRV::OpAtomicFAddEXT) {
1079 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
1081 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
1082 switch (BitWidth) {
1083 case 16:
1084 if (!ST.canUseExtension(
1085 SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1086 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
1087 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1088 Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
1089 break;
1090 case 32:
1091 Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
1092 break;
1093 case 64:
1094 Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
1095 break;
1096 default:
1098 "Unexpected floating-point type width in atomic float instruction");
1099 }
1100 } else {
1101 if (!ST.canUseExtension(
1102 SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
1103 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
1104 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
1105 switch (BitWidth) {
1106 case 16:
1107 Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
1108 break;
1109 case 32:
1110 Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
1111 break;
1112 case 64:
1113 Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
1114 break;
1115 default:
1117 "Unexpected floating-point type width in atomic float instruction");
1118 }
1119 }
1120}
1121
1122bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1123 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1124 return false;
1125 uint32_t Dim = ImageInst->getOperand(2).getImm();
1126 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1127 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1128}
1129
1130bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1131 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1132 return false;
1133 uint32_t Dim = ImageInst->getOperand(2).getImm();
1134 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1135 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1136}
1137
1138bool isSampledImage(MachineInstr *ImageInst) {
1139 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1140 return false;
1141 uint32_t Dim = ImageInst->getOperand(2).getImm();
1142 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1143 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1144}
1145
1146bool isInputAttachment(MachineInstr *ImageInst) {
1147 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1148 return false;
1149 uint32_t Dim = ImageInst->getOperand(2).getImm();
1150 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1151 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1152}
1153
1154bool isStorageImage(MachineInstr *ImageInst) {
1155 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1156 return false;
1157 uint32_t Dim = ImageInst->getOperand(2).getImm();
1158 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1159 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1160}
1161
1162bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1163 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1164 return false;
1165
1166 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1167 Register ImageReg = SampledImageInst->getOperand(1).getReg();
1168 auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
1169 return isSampledImage(ImageInst);
1170}
1171
1172bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1173 for (const auto &MI : MRI.reg_instructions(Reg)) {
1174 if (MI.getOpcode() != SPIRV::OpDecorate)
1175 continue;
1176
1177 uint32_t Dec = MI.getOperand(1).getImm();
1178 if (Dec == SPIRV::Decoration::NonUniformEXT)
1179 return true;
1180 }
1181 return false;
1182}
1183
1184void addOpAccessChainReqs(const MachineInstr &Instr,
1186 const SPIRVSubtarget &Subtarget) {
1187 const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
1188 // Get the result type. If it is an image type, then the shader uses
1189 // descriptor indexing. The appropriate capabilities will be added based
1190 // on the specifics of the image.
1191 Register ResTypeReg = Instr.getOperand(1).getReg();
1192 MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);
1193
1194 assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
1195 uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
1196 if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
1197 StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
1198 StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
1199 return;
1200 }
1201
1202 bool IsNonUniform =
1203 hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
1204
1205 auto FirstIndexReg = Instr.getOperand(3).getReg();
1206 bool FirstIndexIsConstant =
1207 Subtarget.getInstrInfo()->isConstantInstr(*MRI.getVRegDef(FirstIndexReg));
1208
1209 if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
1210 if (IsNonUniform)
1211 Handler.addRequirements(
1212 SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
1213 else if (!FirstIndexIsConstant)
1214 Handler.addRequirements(
1215 SPIRV::Capability::StorageBufferArrayDynamicIndexing);
1216 return;
1217 }
1218
1219 Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
1220 MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
1221 if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
1222 PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
1223 PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
1224 return;
1225 }
1226
1227 if (isUniformTexelBuffer(PointeeType)) {
1228 if (IsNonUniform)
1229 Handler.addRequirements(
1230 SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
1231 else if (!FirstIndexIsConstant)
1232 Handler.addRequirements(
1233 SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
1234 } else if (isInputAttachment(PointeeType)) {
1235 if (IsNonUniform)
1236 Handler.addRequirements(
1237 SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
1238 else if (!FirstIndexIsConstant)
1239 Handler.addRequirements(
1240 SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
1241 } else if (isStorageTexelBuffer(PointeeType)) {
1242 if (IsNonUniform)
1243 Handler.addRequirements(
1244 SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
1245 else if (!FirstIndexIsConstant)
1246 Handler.addRequirements(
1247 SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
1248 } else if (isSampledImage(PointeeType) ||
1249 isCombinedImageSampler(PointeeType) ||
1250 PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
1251 if (IsNonUniform)
1252 Handler.addRequirements(
1253 SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
1254 else if (!FirstIndexIsConstant)
1255 Handler.addRequirements(
1256 SPIRV::Capability::SampledImageArrayDynamicIndexing);
1257 } else if (isStorageImage(PointeeType)) {
1258 if (IsNonUniform)
1259 Handler.addRequirements(
1260 SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
1261 else if (!FirstIndexIsConstant)
1262 Handler.addRequirements(
1263 SPIRV::Capability::StorageImageArrayDynamicIndexing);
1264 }
1265}
1266
1267static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
1268 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1269 return false;
1270 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1271 return TypeInst->getOperand(7).getImm() == 0;
1272}
1273
1274static void AddDotProductRequirements(const MachineInstr &MI,
1276 const SPIRVSubtarget &ST) {
1277 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1278 Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1279 Reqs.addCapability(SPIRV::Capability::DotProduct);
1280
1281 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1282 assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1283 // We do not consider what the previous instruction is. This is just used
1284 // to get the input register and to check the type.
1285 const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
1286 assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
1287 Register InputReg = Input->getOperand(1).getReg();
1288
1289 SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
1290 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1291 assert(TypeDef->getOperand(1).getImm() == 32);
1292 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
1293 } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1294 SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1295 assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1296 if (ScalarTypeDef->getOperand(1).getImm() == 8) {
1297 assert(TypeDef->getOperand(2).getImm() == 4 &&
1298 "Dot operand of 8-bit integer type requires 4 components");
1299 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
1300 } else {
1301 Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
1302 }
1303 }
1304}
1305
1306void addPrintfRequirements(const MachineInstr &MI,
1308 const SPIRVSubtarget &ST) {
1309 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1310 const SPIRVType *PtrType = GR->getSPIRVTypeForVReg(MI.getOperand(4).getReg());
1311 if (PtrType) {
1312 MachineOperand ASOp = PtrType->getOperand(1);
1313 if (ASOp.isImm()) {
1314 unsigned AddrSpace = ASOp.getImm();
1315 if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
1316 if (!ST.canUseExtension(
1318 SPV_EXT_relaxed_printf_string_address_space)) {
1319 report_fatal_error("SPV_EXT_relaxed_printf_string_address_space is "
1320 "required because printf uses a format string not "
1321 "in constant address space.",
1322 false);
1323 }
1324 Reqs.addExtension(
1325 SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
1326 }
1327 }
1328 }
1329}
1330
1331static bool isBFloat16Type(const SPIRVType *TypeDef) {
1332 return TypeDef && TypeDef->getNumOperands() == 3 &&
1333 TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1334 TypeDef->getOperand(1).getImm() == 16 &&
1335 TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1336}
1337
1338void addInstrRequirements(const MachineInstr &MI,
1340 const SPIRVSubtarget &ST) {
1341 SPIRV::RequirementHandler &Reqs = MAI.Reqs;
1342 switch (MI.getOpcode()) {
1343 case SPIRV::OpMemoryModel: {
1344 int64_t Addr = MI.getOperand(0).getImm();
1345 Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
1346 Addr, ST);
1347 int64_t Mem = MI.getOperand(1).getImm();
1348 Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
1349 ST);
1350 break;
1351 }
1352 case SPIRV::OpEntryPoint: {
1353 int64_t Exe = MI.getOperand(0).getImm();
1354 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
1355 Exe, ST);
1356 break;
1357 }
1358 case SPIRV::OpExecutionMode:
1359 case SPIRV::OpExecutionModeId: {
1360 int64_t Exe = MI.getOperand(1).getImm();
1361 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
1362 Exe, ST);
1363 break;
1364 }
1365 case SPIRV::OpTypeMatrix:
1366 Reqs.addCapability(SPIRV::Capability::Matrix);
1367 break;
1368 case SPIRV::OpTypeInt: {
1369 unsigned BitWidth = MI.getOperand(1).getImm();
1370 if (BitWidth == 64)
1371 Reqs.addCapability(SPIRV::Capability::Int64);
1372 else if (BitWidth == 16)
1373 Reqs.addCapability(SPIRV::Capability::Int16);
1374 else if (BitWidth == 8)
1375 Reqs.addCapability(SPIRV::Capability::Int8);
1376 break;
1377 }
1378 case SPIRV::OpDot: {
1379 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1380 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1381 if (isBFloat16Type(TypeDef))
1382 Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
1383 break;
1384 }
1385 case SPIRV::OpTypeFloat: {
1386 unsigned BitWidth = MI.getOperand(1).getImm();
1387 if (BitWidth == 64)
1388 Reqs.addCapability(SPIRV::Capability::Float64);
1389 else if (BitWidth == 16) {
1390 if (isBFloat16Type(&MI)) {
1391 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
1392 report_fatal_error("OpTypeFloat type with bfloat requires the "
1393 "following SPIR-V extension: SPV_KHR_bfloat16",
1394 false);
1395 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
1396 Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
1397 } else {
1398 Reqs.addCapability(SPIRV::Capability::Float16);
1399 }
1400 }
1401 break;
1402 }
1403 case SPIRV::OpTypeVector: {
1404 unsigned NumComponents = MI.getOperand(2).getImm();
1405 if (NumComponents == 8 || NumComponents == 16)
1406 Reqs.addCapability(SPIRV::Capability::Vector16);
1407 break;
1408 }
1409 case SPIRV::OpTypePointer: {
1410 auto SC = MI.getOperand(1).getImm();
1411 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
1412 ST);
1413 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1414 // capability.
1415 if (ST.isShader())
1416 break;
1417 assert(MI.getOperand(2).isReg());
1418 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1419 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1420 if ((TypeDef->getNumOperands() == 2) &&
1421 (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1422 (TypeDef->getOperand(1).getImm() == 16))
1423 Reqs.addCapability(SPIRV::Capability::Float16Buffer);
1424 break;
1425 }
1426 case SPIRV::OpExtInst: {
1427 if (MI.getOperand(2).getImm() ==
1428 static_cast<int64_t>(
1429 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1430 Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
1431 break;
1432 }
1433 if (MI.getOperand(3).getImm() ==
1434 static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
1435 addPrintfRequirements(MI, Reqs, ST);
1436 break;
1437 }
1438 break;
1439 }
1440 case SPIRV::OpAliasDomainDeclINTEL:
1441 case SPIRV::OpAliasScopeDeclINTEL:
1442 case SPIRV::OpAliasScopeListDeclINTEL: {
1443 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1444 Reqs.addCapability(SPIRV::Capability::MemoryAccessAliasingINTEL);
1445 break;
1446 }
1447 case SPIRV::OpBitReverse:
1448 case SPIRV::OpBitFieldInsert:
1449 case SPIRV::OpBitFieldSExtract:
1450 case SPIRV::OpBitFieldUExtract:
1451 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
1452 Reqs.addCapability(SPIRV::Capability::Shader);
1453 break;
1454 }
1455 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
1456 Reqs.addCapability(SPIRV::Capability::BitInstructions);
1457 break;
1458 case SPIRV::OpTypeRuntimeArray:
1459 Reqs.addCapability(SPIRV::Capability::Shader);
1460 break;
1461 case SPIRV::OpTypeOpaque:
1462 case SPIRV::OpTypeEvent:
1463 Reqs.addCapability(SPIRV::Capability::Kernel);
1464 break;
1465 case SPIRV::OpTypePipe:
1466 case SPIRV::OpTypeReserveId:
1467 Reqs.addCapability(SPIRV::Capability::Pipes);
1468 break;
1469 case SPIRV::OpTypeDeviceEvent:
1470 case SPIRV::OpTypeQueue:
1471 case SPIRV::OpBuildNDRange:
1472 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
1473 break;
1474 case SPIRV::OpDecorate:
1475 case SPIRV::OpDecorateId:
1476 case SPIRV::OpDecorateString:
1477 addOpDecorateReqs(MI, 1, Reqs, ST);
1478 break;
1479 case SPIRV::OpMemberDecorate:
1480 case SPIRV::OpMemberDecorateString:
1481 addOpDecorateReqs(MI, 2, Reqs, ST);
1482 break;
1483 case SPIRV::OpInBoundsPtrAccessChain:
1484 Reqs.addCapability(SPIRV::Capability::Addresses);
1485 break;
1486 case SPIRV::OpConstantSampler:
1487 Reqs.addCapability(SPIRV::Capability::LiteralSampler);
1488 break;
1489 case SPIRV::OpInBoundsAccessChain:
1490 case SPIRV::OpAccessChain:
1491 addOpAccessChainReqs(MI, Reqs, ST);
1492 break;
1493 case SPIRV::OpTypeImage:
1494 addOpTypeImageReqs(MI, Reqs, ST);
1495 break;
1496 case SPIRV::OpTypeSampler:
1497 if (!ST.isShader()) {
1498 Reqs.addCapability(SPIRV::Capability::ImageBasic);
1499 }
1500 break;
1501 case SPIRV::OpTypeForwardPointer:
1502 // TODO: check if it's OpenCL's kernel.
1503 Reqs.addCapability(SPIRV::Capability::Addresses);
1504 break;
1505 case SPIRV::OpAtomicFlagTestAndSet:
1506 case SPIRV::OpAtomicLoad:
1507 case SPIRV::OpAtomicStore:
1508 case SPIRV::OpAtomicExchange:
1509 case SPIRV::OpAtomicCompareExchange:
1510 case SPIRV::OpAtomicIIncrement:
1511 case SPIRV::OpAtomicIDecrement:
1512 case SPIRV::OpAtomicIAdd:
1513 case SPIRV::OpAtomicISub:
1514 case SPIRV::OpAtomicUMin:
1515 case SPIRV::OpAtomicUMax:
1516 case SPIRV::OpAtomicSMin:
1517 case SPIRV::OpAtomicSMax:
1518 case SPIRV::OpAtomicAnd:
1519 case SPIRV::OpAtomicOr:
1520 case SPIRV::OpAtomicXor: {
1521 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1522 const MachineInstr *InstrPtr = &MI;
1523 if (MI.getOpcode() == SPIRV::OpAtomicStore) {
1524 assert(MI.getOperand(3).isReg());
1525 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
1526 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1527 }
1528 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1529 Register TypeReg = InstrPtr->getOperand(1).getReg();
1530 SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
1531 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1532 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1533 if (BitWidth == 64)
1534 Reqs.addCapability(SPIRV::Capability::Int64Atomics);
1535 }
1536 break;
1537 }
1538 case SPIRV::OpGroupNonUniformIAdd:
1539 case SPIRV::OpGroupNonUniformFAdd:
1540 case SPIRV::OpGroupNonUniformIMul:
1541 case SPIRV::OpGroupNonUniformFMul:
1542 case SPIRV::OpGroupNonUniformSMin:
1543 case SPIRV::OpGroupNonUniformUMin:
1544 case SPIRV::OpGroupNonUniformFMin:
1545 case SPIRV::OpGroupNonUniformSMax:
1546 case SPIRV::OpGroupNonUniformUMax:
1547 case SPIRV::OpGroupNonUniformFMax:
1548 case SPIRV::OpGroupNonUniformBitwiseAnd:
1549 case SPIRV::OpGroupNonUniformBitwiseOr:
1550 case SPIRV::OpGroupNonUniformBitwiseXor:
1551 case SPIRV::OpGroupNonUniformLogicalAnd:
1552 case SPIRV::OpGroupNonUniformLogicalOr:
1553 case SPIRV::OpGroupNonUniformLogicalXor: {
1554 assert(MI.getOperand(3).isImm());
1555 int64_t GroupOp = MI.getOperand(3).getImm();
1556 switch (GroupOp) {
1557 case SPIRV::GroupOperation::Reduce:
1558 case SPIRV::GroupOperation::InclusiveScan:
1559 case SPIRV::GroupOperation::ExclusiveScan:
1560 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1561 break;
1562 case SPIRV::GroupOperation::ClusteredReduce:
1563 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1564 break;
1565 case SPIRV::GroupOperation::PartitionedReduceNV:
1566 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1567 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1568 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1569 break;
1570 }
1571 break;
1572 }
1573 case SPIRV::OpGroupNonUniformShuffle:
1574 case SPIRV::OpGroupNonUniformShuffleXor:
1575 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1576 break;
1577 case SPIRV::OpGroupNonUniformShuffleUp:
1578 case SPIRV::OpGroupNonUniformShuffleDown:
1579 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1580 break;
1581 case SPIRV::OpGroupAll:
1582 case SPIRV::OpGroupAny:
1583 case SPIRV::OpGroupBroadcast:
1584 case SPIRV::OpGroupIAdd:
1585 case SPIRV::OpGroupFAdd:
1586 case SPIRV::OpGroupFMin:
1587 case SPIRV::OpGroupUMin:
1588 case SPIRV::OpGroupSMin:
1589 case SPIRV::OpGroupFMax:
1590 case SPIRV::OpGroupUMax:
1591 case SPIRV::OpGroupSMax:
1592 Reqs.addCapability(SPIRV::Capability::Groups);
1593 break;
1594 case SPIRV::OpGroupNonUniformElect:
1595 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1596 break;
1597 case SPIRV::OpGroupNonUniformAll:
1598 case SPIRV::OpGroupNonUniformAny:
1599 case SPIRV::OpGroupNonUniformAllEqual:
1600 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1601 break;
1602 case SPIRV::OpGroupNonUniformBroadcast:
1603 case SPIRV::OpGroupNonUniformBroadcastFirst:
1604 case SPIRV::OpGroupNonUniformBallot:
1605 case SPIRV::OpGroupNonUniformInverseBallot:
1606 case SPIRV::OpGroupNonUniformBallotBitExtract:
1607 case SPIRV::OpGroupNonUniformBallotBitCount:
1608 case SPIRV::OpGroupNonUniformBallotFindLSB:
1609 case SPIRV::OpGroupNonUniformBallotFindMSB:
1610 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1611 break;
1612 case SPIRV::OpSubgroupShuffleINTEL:
1613 case SPIRV::OpSubgroupShuffleDownINTEL:
1614 case SPIRV::OpSubgroupShuffleUpINTEL:
1615 case SPIRV::OpSubgroupShuffleXorINTEL:
1616 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1617 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1618 Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1619 }
1620 break;
1621 case SPIRV::OpSubgroupBlockReadINTEL:
1622 case SPIRV::OpSubgroupBlockWriteINTEL:
1623 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1624 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1625 Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1626 }
1627 break;
1628 case SPIRV::OpSubgroupImageBlockReadINTEL:
1629 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1630 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1631 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1632 Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1633 }
1634 break;
1635 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1636 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1637 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
1638 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
1639 Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1640 }
1641 break;
1642 case SPIRV::OpAssumeTrueKHR:
1643 case SPIRV::OpExpectKHR:
1644 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1645 Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1646 Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1647 }
1648 break;
1649 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1650 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1651 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1652 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1653 Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1654 }
1655 break;
1656 case SPIRV::OpConstantFunctionPointerINTEL:
1657 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1658 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1659 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1660 }
1661 break;
1662 case SPIRV::OpGroupNonUniformRotateKHR:
1663 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1664 report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1665 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1666 false);
1667 Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1668 Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1669 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1670 break;
1671 case SPIRV::OpGroupIMulKHR:
1672 case SPIRV::OpGroupFMulKHR:
1673 case SPIRV::OpGroupBitwiseAndKHR:
1674 case SPIRV::OpGroupBitwiseOrKHR:
1675 case SPIRV::OpGroupBitwiseXorKHR:
1676 case SPIRV::OpGroupLogicalAndKHR:
1677 case SPIRV::OpGroupLogicalOrKHR:
1678 case SPIRV::OpGroupLogicalXorKHR:
1679 if (ST.canUseExtension(
1680 SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1681 Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1682 Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1683 }
1684 break;
1685 case SPIRV::OpReadClockKHR:
1686 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1687 report_fatal_error("OpReadClockKHR instruction requires the "
1688 "following SPIR-V extension: SPV_KHR_shader_clock",
1689 false);
1690 Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1691 Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1692 break;
1693 case SPIRV::OpFunctionPointerCallINTEL:
1694 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1695 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1696 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1697 }
1698 break;
1699 case SPIRV::OpAtomicFAddEXT:
1700 case SPIRV::OpAtomicFMinEXT:
1701 case SPIRV::OpAtomicFMaxEXT:
1702 AddAtomicFloatRequirements(MI, Reqs, ST);
1703 break;
1704 case SPIRV::OpConvertBF16ToFINTEL:
1705 case SPIRV::OpConvertFToBF16INTEL:
1706 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1707 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1708 Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1709 }
1710 break;
1711 case SPIRV::OpRoundFToTF32INTEL:
1712 if (ST.canUseExtension(
1713 SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
1714 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1715 Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
1716 }
1717 break;
1718 case SPIRV::OpVariableLengthArrayINTEL:
1719 case SPIRV::OpSaveMemoryINTEL:
1720 case SPIRV::OpRestoreMemoryINTEL:
1721 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1722 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
1723 Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
1724 }
1725 break;
1726 case SPIRV::OpAsmTargetINTEL:
1727 case SPIRV::OpAsmINTEL:
1728 case SPIRV::OpAsmCallINTEL:
1729 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1730 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
1731 Reqs.addCapability(SPIRV::Capability::AsmINTEL);
1732 }
1733 break;
1734 case SPIRV::OpTypeCooperativeMatrixKHR: {
1735 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1737 "OpTypeCooperativeMatrixKHR type requires the "
1738 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1739 false);
1740 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1741 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1742 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1743 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1744 if (isBFloat16Type(TypeDef))
1745 Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
1746 break;
1747 }
1748 case SPIRV::OpArithmeticFenceEXT:
1749 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
1750 report_fatal_error("OpArithmeticFenceEXT requires the "
1751 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1752 false);
1753 Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
1754 Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
1755 break;
1756 case SPIRV::OpControlBarrierArriveINTEL:
1757 case SPIRV::OpControlBarrierWaitINTEL:
1758 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
1759 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
1760 Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
1761 }
1762 break;
1763 case SPIRV::OpCooperativeMatrixMulAddKHR: {
1764 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1765 report_fatal_error("Cooperative matrix instructions require the "
1766 "following SPIR-V extension: "
1767 "SPV_KHR_cooperative_matrix",
1768 false);
1769 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1770 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1771 constexpr unsigned MulAddMaxSize = 6;
1772 if (MI.getNumOperands() != MulAddMaxSize)
1773 break;
1774 const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
1775 if (CoopOperands &
1776 SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
1777 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1778 report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
1779 "require the following SPIR-V extension: "
1780 "SPV_INTEL_joint_matrix",
1781 false);
1782 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1783 Reqs.addCapability(
1784 SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
1785 }
1786 if (CoopOperands & SPIRV::CooperativeMatrixOperands::
1787 MatrixAAndBBFloat16ComponentsINTEL ||
1788 CoopOperands &
1789 SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
1790 CoopOperands & SPIRV::CooperativeMatrixOperands::
1791 MatrixResultBFloat16ComponentsINTEL) {
1792 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1793 report_fatal_error("***BF16ComponentsINTEL type interpretations "
1794 "require the following SPIR-V extension: "
1795 "SPV_INTEL_joint_matrix",
1796 false);
1797 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1798 Reqs.addCapability(
1799 SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
1800 }
1801 break;
1802 }
1803 case SPIRV::OpCooperativeMatrixLoadKHR:
1804 case SPIRV::OpCooperativeMatrixStoreKHR:
1805 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1806 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1807 case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
1808 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1809 report_fatal_error("Cooperative matrix instructions require the "
1810 "following SPIR-V extension: "
1811 "SPV_KHR_cooperative_matrix",
1812 false);
1813 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1814 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1815
1816 // Check Layout operand in case if it's not a standard one and add the
1817 // appropriate capability.
1818 std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
1819 {SPIRV::OpCooperativeMatrixLoadKHR, 3},
1820 {SPIRV::OpCooperativeMatrixStoreKHR, 2},
1821 {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
1822 {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
1823 {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
1824
1825 const auto OpCode = MI.getOpcode();
1826 const unsigned LayoutNum = LayoutToInstMap[OpCode];
1827 Register RegLayout = MI.getOperand(LayoutNum).getReg();
1828 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1829 MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
1830 if (MILayout->getOpcode() == SPIRV::OpConstantI) {
1831 const unsigned LayoutVal = MILayout->getOperand(2).getImm();
1832 if (LayoutVal ==
1833 static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
1834 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1835 report_fatal_error("PackedINTEL layout require the following SPIR-V "
1836 "extension: SPV_INTEL_joint_matrix",
1837 false);
1838 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1839 Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
1840 }
1841 }
1842
1843 // Nothing to do.
1844 if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
1845 OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
1846 break;
1847
1848 std::string InstName;
1849 switch (OpCode) {
1850 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
1851 InstName = "OpCooperativeMatrixPrefetchINTEL";
1852 break;
1853 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1854 InstName = "OpCooperativeMatrixLoadCheckedINTEL";
1855 break;
1856 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1857 InstName = "OpCooperativeMatrixStoreCheckedINTEL";
1858 break;
1859 }
1860
1861 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
1862 const std::string ErrorMsg =
1863 InstName + " instruction requires the "
1864 "following SPIR-V extension: SPV_INTEL_joint_matrix";
1865 report_fatal_error(ErrorMsg.c_str(), false);
1866 }
1867 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1868 if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
1869 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
1870 break;
1871 }
1872 Reqs.addCapability(
1873 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1874 break;
1875 }
1876 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
1877 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1878 report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
1879 "instructions require the following SPIR-V extension: "
1880 "SPV_INTEL_joint_matrix",
1881 false);
1882 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1883 Reqs.addCapability(
1884 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1885 break;
1886 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
1887 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1888 report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
1889 "following SPIR-V extension: SPV_INTEL_joint_matrix",
1890 false);
1891 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1892 Reqs.addCapability(
1893 SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
1894 break;
1895 case SPIRV::OpConvertHandleToImageINTEL:
1896 case SPIRV::OpConvertHandleToSamplerINTEL:
1897 case SPIRV::OpConvertHandleToSampledImageINTEL: {
1898 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bindless_images))
1899 report_fatal_error("OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
1900 "instructions require the following SPIR-V extension: "
1901 "SPV_INTEL_bindless_images",
1902 false);
1903 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1904 SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
1905 SPIRVType *TyDef = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
1906 if (MI.getOpcode() == SPIRV::OpConvertHandleToImageINTEL &&
1907 TyDef->getOpcode() != SPIRV::OpTypeImage) {
1908 report_fatal_error("Incorrect return type for the instruction "
1909 "OpConvertHandleToImageINTEL",
1910 false);
1911 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSamplerINTEL &&
1912 TyDef->getOpcode() != SPIRV::OpTypeSampler) {
1913 report_fatal_error("Incorrect return type for the instruction "
1914 "OpConvertHandleToSamplerINTEL",
1915 false);
1916 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSampledImageINTEL &&
1917 TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
1918 report_fatal_error("Incorrect return type for the instruction "
1919 "OpConvertHandleToSampledImageINTEL",
1920 false);
1921 }
1922 SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg());
1923 unsigned Bitwidth = GR->getScalarOrVectorBitWidth(SpvTy);
1924 if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
1925 !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
1927 "Parameter value must be a 32-bit scalar in case of "
1928 "Physical32 addressing model or a 64-bit scalar in case of "
1929 "Physical64 addressing model",
1930 false);
1931 }
1932 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bindless_images);
1933 Reqs.addCapability(SPIRV::Capability::BindlessImagesINTEL);
1934 break;
1935 }
1936 case SPIRV::OpSubgroup2DBlockLoadINTEL:
1937 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
1938 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
1939 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
1940 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
1941 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_2d_block_io))
1942 report_fatal_error("OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
1943 "Prefetch/Store]INTEL instructions require the "
1944 "following SPIR-V extension: SPV_INTEL_2d_block_io",
1945 false);
1946 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_2d_block_io);
1947 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockIOINTEL);
1948
1949 const auto OpCode = MI.getOpcode();
1950 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
1951 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
1952 break;
1953 }
1954 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
1955 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransformINTEL);
1956 break;
1957 }
1958 break;
1959 }
1960 case SPIRV::OpKill: {
1961 Reqs.addCapability(SPIRV::Capability::Shader);
1962 } break;
1963 case SPIRV::OpDemoteToHelperInvocation:
1964 Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);
1965
1966 if (ST.canUseExtension(
1967 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
1968 if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
1969 Reqs.addExtension(
1970 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
1971 }
1972 break;
1973 case SPIRV::OpSDot:
1974 case SPIRV::OpUDot:
1975 case SPIRV::OpSUDot:
1976 case SPIRV::OpSDotAccSat:
1977 case SPIRV::OpUDotAccSat:
1978 case SPIRV::OpSUDotAccSat:
1979 AddDotProductRequirements(MI, Reqs, ST);
1980 break;
1981 case SPIRV::OpImageRead: {
1982 Register ImageReg = MI.getOperand(2).getReg();
1983 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
1984 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
1985 // OpImageRead and OpImageWrite can use Unknown Image Formats
1986 // when the Kernel capability is declared. In the OpenCL environment we are
1987 // not allowed to produce
1988 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
1989 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
1990
1991 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
1992 Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
1993 break;
1994 }
1995 case SPIRV::OpImageWrite: {
1996 Register ImageReg = MI.getOperand(0).getReg();
1997 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
1998 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
1999 // OpImageRead and OpImageWrite can use Unknown Image Formats
2000 // when the Kernel capability is declared. In the OpenCL environment we are
2001 // not allowed to produce
2002 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2003 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2004
2005 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
2006 Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
2007 break;
2008 }
2009 case SPIRV::OpTypeStructContinuedINTEL:
2010 case SPIRV::OpConstantCompositeContinuedINTEL:
2011 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
2012 case SPIRV::OpCompositeConstructContinuedINTEL: {
2013 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_long_composites))
2015 "Continued instructions require the "
2016 "following SPIR-V extension: SPV_INTEL_long_composites",
2017 false);
2018 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_long_composites);
2019 Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
2020 break;
2021 }
2022 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
2023 if (!ST.canUseExtension(
2024 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
2026 "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
2027 "following SPIR-V "
2028 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
2029 false);
2030 Reqs.addExtension(
2031 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
2032 Reqs.addCapability(
2033 SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
2034 break;
2035 }
2036 case SPIRV::OpBitwiseFunctionINTEL: {
2037 if (!ST.canUseExtension(
2038 SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
2040 "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
2041 "extension: SPV_INTEL_ternary_bitwise_function",
2042 false);
2043 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
2044 Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
2045 break;
2046 }
2047 case SPIRV::OpCopyMemorySized: {
2048 Reqs.addCapability(SPIRV::Capability::Addresses);
2049 // TODO: Add UntypedPointersKHR when implemented.
2050 break;
2051 }
2052 case SPIRV::OpPredicatedLoadINTEL:
2053 case SPIRV::OpPredicatedStoreINTEL: {
2054 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_predicated_io))
2056 "OpPredicated[Load/Store]INTEL instructions require "
2057 "the following SPIR-V extension: SPV_INTEL_predicated_io",
2058 false);
2059 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_predicated_io);
2060 Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
2061 break;
2062 }
2063
2064 default:
2065 break;
2066 }
2067
2068 // If we require capability Shader, then we can remove the requirement for
2069 // the BitInstructions capability, since Shader is a superset capability
2070 // of BitInstructions.
2071 Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
2072 SPIRV::Capability::Shader);
2073}
2074
2075static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
2076 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
2077 // Collect requirements for existing instructions.
2078 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2080 if (!MF)
2081 continue;
2082 for (const MachineBasicBlock &MBB : *MF)
2083 for (const MachineInstr &MI : MBB)
2084 addInstrRequirements(MI, MAI, ST);
2085 }
2086 // Collect requirements for OpExecutionMode instructions.
2087 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2088 if (Node) {
2089 bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
2090 RequireKHRFloatControls2 = false,
2091 VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
2092 bool HasIntelFloatControls2 =
2093 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2094 bool HasKHRFloatControls2 =
2095 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2096 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2097 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2098 const MDOperand &MDOp = MDN->getOperand(1);
2099 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
2100 Constant *C = CMeta->getValue();
2101 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
2102 auto EM = Const->getZExtValue();
2103 // SPV_KHR_float_controls is not available until v1.4:
2104 // add SPV_KHR_float_controls if the version is too low
2105 switch (EM) {
2106 case SPIRV::ExecutionMode::DenormPreserve:
2107 case SPIRV::ExecutionMode::DenormFlushToZero:
2108 case SPIRV::ExecutionMode::RoundingModeRTE:
2109 case SPIRV::ExecutionMode::RoundingModeRTZ:
2110 RequireFloatControls = VerLower14;
2112 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2113 break;
2114 case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
2115 case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
2116 case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
2117 case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
2118 if (HasIntelFloatControls2) {
2119 RequireIntelFloatControls2 = true;
2121 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2122 }
2123 break;
2124 case SPIRV::ExecutionMode::FPFastMathDefault: {
2125 if (HasKHRFloatControls2) {
2126 RequireKHRFloatControls2 = true;
2128 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2129 }
2130 break;
2131 }
2132 case SPIRV::ExecutionMode::ContractionOff:
2133 case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
2134 if (HasKHRFloatControls2) {
2135 RequireKHRFloatControls2 = true;
2137 SPIRV::OperandCategory::ExecutionModeOperand,
2138 SPIRV::ExecutionMode::FPFastMathDefault, ST);
2139 } else {
2141 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2142 }
2143 break;
2144 default:
2146 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2147 }
2148 }
2149 }
2150 }
2151 if (RequireFloatControls &&
2152 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
2153 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
2154 if (RequireIntelFloatControls2)
2155 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2156 if (RequireKHRFloatControls2)
2157 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2158 }
2159 for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
2160 const Function &F = *FI;
2161 if (F.isDeclaration())
2162 continue;
2163 if (F.getMetadata("reqd_work_group_size"))
2165 SPIRV::OperandCategory::ExecutionModeOperand,
2166 SPIRV::ExecutionMode::LocalSize, ST);
2167 if (F.getFnAttribute("hlsl.numthreads").isValid()) {
2169 SPIRV::OperandCategory::ExecutionModeOperand,
2170 SPIRV::ExecutionMode::LocalSize, ST);
2171 }
2172 if (F.getFnAttribute("enable-maximal-reconvergence").getValueAsBool()) {
2173 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence);
2174 }
2175 if (F.getMetadata("work_group_size_hint"))
2177 SPIRV::OperandCategory::ExecutionModeOperand,
2178 SPIRV::ExecutionMode::LocalSizeHint, ST);
2179 if (F.getMetadata("intel_reqd_sub_group_size"))
2181 SPIRV::OperandCategory::ExecutionModeOperand,
2182 SPIRV::ExecutionMode::SubgroupSize, ST);
2183 if (F.getMetadata("vec_type_hint"))
2185 SPIRV::OperandCategory::ExecutionModeOperand,
2186 SPIRV::ExecutionMode::VecTypeHint, ST);
2187
2188 if (F.hasOptNone()) {
2189 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
2190 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
2191 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
2192 } else if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) {
2193 MAI.Reqs.addExtension(SPIRV::Extension::SPV_EXT_optnone);
2194 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneEXT);
2195 }
2196 }
2197 }
2198}
2199
2200static unsigned getFastMathFlags(const MachineInstr &I,
2201 const SPIRVSubtarget &ST) {
2202 unsigned Flags = SPIRV::FPFastMathMode::None;
2203 bool CanUseKHRFloatControls2 =
2204 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2205 if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
2206 Flags |= SPIRV::FPFastMathMode::NotNaN;
2207 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
2208 Flags |= SPIRV::FPFastMathMode::NotInf;
2209 if (I.getFlag(MachineInstr::MIFlag::FmNsz))
2210 Flags |= SPIRV::FPFastMathMode::NSZ;
2211 if (I.getFlag(MachineInstr::MIFlag::FmArcp))
2212 Flags |= SPIRV::FPFastMathMode::AllowRecip;
2213 if (I.getFlag(MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
2214 Flags |= SPIRV::FPFastMathMode::AllowContract;
2215 if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) {
2216 if (CanUseKHRFloatControls2)
2217 // LLVM reassoc maps to SPIRV transform, see
2218 // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
2219 // Because we are enabling AllowTransform, we must enable AllowReassoc and
2220 // AllowContract too, as required by SPIRV spec. Also, we used to map
2221 // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
2222 // replaced by turning all the other bits instead. Therefore, we're
2223 // enabling every bit here except None and Fast.
2224 Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
2225 SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
2226 SPIRV::FPFastMathMode::AllowTransform |
2227 SPIRV::FPFastMathMode::AllowReassoc |
2228 SPIRV::FPFastMathMode::AllowContract;
2229 else
2230 Flags |= SPIRV::FPFastMathMode::Fast;
2231 }
2232
2233 if (CanUseKHRFloatControls2) {
2234 // Error out if SPIRV::FPFastMathMode::Fast is enabled.
2235 assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
2236 "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
2237 "anymore.");
2238
2239 // Error out if AllowTransform is enabled without AllowReassoc and
2240 // AllowContract.
2241 assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
2242 ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
2243 Flags & SPIRV::FPFastMathMode::AllowContract))) &&
2244 "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
2245 "AllowContract flags to be enabled as well.");
2246 }
2247
2248 return Flags;
2249}
2250
2251static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
2252 if (ST.isKernel())
2253 return true;
2254 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
2255 return false;
2256 return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2257}
2258
2259static void handleMIFlagDecoration(
2260 MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
2262 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
2263 if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
2264 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2265 SPIRV::Decoration::NoSignedWrap, ST, Reqs)
2266 .IsSatisfiable) {
2267 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2268 SPIRV::Decoration::NoSignedWrap, {});
2269 }
2270 if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
2271 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2272 SPIRV::Decoration::NoUnsignedWrap, ST,
2273 Reqs)
2274 .IsSatisfiable) {
2275 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2276 SPIRV::Decoration::NoUnsignedWrap, {});
2277 }
2278 if (!TII.canUseFastMathFlags(
2279 I, ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)))
2280 return;
2281
2282 unsigned FMFlags = getFastMathFlags(I, ST);
2283 if (FMFlags == SPIRV::FPFastMathMode::None) {
2284 // We also need to check if any FPFastMathDefault info was set for the
2285 // types used in this instruction.
2286 if (FPFastMathDefaultInfoVec.empty())
2287 return;
2288
2289 // There are three types of instructions that can use fast math flags:
2290 // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
2291 // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
2292 // 3. Extended instructions (ExtInst)
2293 // For arithmetic instructions, the floating point type can be in the
2294 // result type or in the operands, but they all must be the same.
2295 // For the relational and logical instructions, the floating point type
2296 // can only be in the operands 1 and 2, not the result type. Also, the
2297 // operands must have the same type. For the extended instructions, the
2298 // floating point type can be in the result type or in the operands. It's
2299 // unclear if the operands and the result type must be the same. Let's
2300 // assume they must be. Therefore, for 1. and 2., we can check the first
2301 // operand type, and for 3. we can check the result type.
2302 assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
2303 Register ResReg = I.getOpcode() == SPIRV::OpExtInst
2304 ? I.getOperand(1).getReg()
2305 : I.getOperand(2).getReg();
2306 SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResReg, I.getMF());
2307 const Type *Ty = GR->getTypeForSPIRVType(ResType);
2308 Ty = Ty->isVectorTy() ? cast<VectorType>(Ty)->getElementType() : Ty;
2309
2310 // Match instruction type with the FPFastMathDefaultInfoVec.
2311 bool Emit = false;
2312 for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
2313 if (Ty == Elem.Ty) {
2314 FMFlags = Elem.FastMathFlags;
2315 Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
2316 Elem.FPFastMathDefault;
2317 break;
2318 }
2319 }
2320
2321 if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
2322 return;
2323 }
2324 if (isFastMathModeAvailable(ST)) {
2325 Register DstReg = I.getOperand(0).getReg();
2326 buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode,
2327 {FMFlags});
2328 }
2329}
2330
2331// Walk all functions and add decorations related to MI flags.
2332static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2333 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2335 const SPIRVGlobalRegistry *GR) {
2336 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2338 if (!MF)
2339 continue;
2340
2341 for (auto &MBB : *MF)
2342 for (auto &MI : MBB)
2343 handleMIFlagDecoration(MI, ST, TII, MAI.Reqs, GR,
2344 MAI.FPFastMathDefaultInfoMap[&(*F)]);
2345 }
2346}
2347
2348static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2349 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2351 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2353 if (!MF)
2354 continue;
2356 for (auto &MBB : *MF) {
2357 if (!MBB.hasName() || MBB.empty())
2358 continue;
2359 // Emit basic block names.
2360 Register Reg = MRI.createGenericVirtualRegister(LLT::scalar(64));
2361 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
2362 buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
2363 MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2364 MAI.setRegisterAlias(MF, Reg, GlobalReg);
2365 }
2366 }
2367}
2368
2369// patching Instruction::PHI to SPIRV::OpPhi
2370static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2371 const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2372 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2374 if (!MF)
2375 continue;
2376 for (auto &MBB : *MF) {
2377 for (MachineInstr &MI : MBB.phis()) {
2378 MI.setDesc(TII.get(SPIRV::OpPhi));
2379 Register ResTypeReg = GR->getSPIRVTypeID(
2380 GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF));
2381 MI.insert(MI.operands_begin() + 1,
2382 {MachineOperand::CreateReg(ResTypeReg, false)});
2383 }
2384 }
2385
2386 MF->getProperties().setNoPHIs();
2387 }
2388}
2389
2391 const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
2392 auto it = MAI.FPFastMathDefaultInfoMap.find(F);
2393 if (it != MAI.FPFastMathDefaultInfoMap.end())
2394 return it->second;
2395
2396 // If the map does not contain the entry, create a new one. Initialize it to
2397 // contain all 3 elements sorted by bit width of target type: {half, float,
2398 // double}.
2399 SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
2400 FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()),
2401 SPIRV::FPFastMathMode::None);
2402 FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()),
2403 SPIRV::FPFastMathMode::None);
2404 FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()),
2405 SPIRV::FPFastMathMode::None);
2406 return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
2407}
2408
2410 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
2411 const Type *Ty) {
2412 size_t BitWidth = Ty->getScalarSizeInBits();
2413 int Index =
2415 BitWidth);
2416 assert(Index >= 0 && Index < 3 &&
2417 "Expected FPFastMathDefaultInfo for half, float, or double");
2418 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2419 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2420 return FPFastMathDefaultInfoVec[Index];
2421}
2422
2423static void collectFPFastMathDefaults(const Module &M,
2425 const SPIRVSubtarget &ST) {
2426 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2))
2427 return;
2428
2429 // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
2430 // We need the entry point (function) as the key, and the target
2431 // type and flags as the value.
2432 // We also need to check ContractionOff and SignedZeroInfNanPreserve
2433 // execution modes, as they are now deprecated and must be replaced
2434 // with FPFastMathDefaultInfo.
2435 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2436 if (!Node)
2437 return;
2438
2439 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2440 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2441 assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
2442 const Function *F = cast<Function>(
2443 cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue());
2444 const auto EM =
2446 cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue())
2447 ->getZExtValue();
2448 if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
2449 assert(MDN->getNumOperands() == 4 &&
2450 "Expected 4 operands for FPFastMathDefault");
2451
2452 const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType();
2453 unsigned Flags =
2455 cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue())
2456 ->getZExtValue();
2457 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2460 getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T);
2461 Info.FastMathFlags = Flags;
2462 Info.FPFastMathDefault = true;
2463 } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
2464 assert(MDN->getNumOperands() == 2 &&
2465 "Expected no operands for ContractionOff");
2466
2467 // We need to save this info for every possible FP type, i.e. {half,
2468 // float, double, fp128}.
2469 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2471 for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
2472 Info.ContractionOff = true;
2473 }
2474 } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
2475 assert(MDN->getNumOperands() == 3 &&
2476 "Expected 1 operand for SignedZeroInfNanPreserve");
2477 unsigned TargetWidth =
2479 cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue())
2480 ->getZExtValue();
2481 // We need to save this info only for the FP type with TargetWidth.
2482 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2486 assert(Index >= 0 && Index < 3 &&
2487 "Expected FPFastMathDefaultInfo for half, float, or double");
2488 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2489 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2490 FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
2491 }
2492 }
2493}
2494
2496
2498 AU.addRequired<TargetPassConfig>();
2499 AU.addRequired<MachineModuleInfoWrapperPass>();
2500}
2501
2503 SPIRVTargetMachine &TM =
2505 ST = TM.getSubtargetImpl();
2506 GR = ST->getSPIRVGlobalRegistry();
2507 TII = ST->getInstrInfo();
2508
2510
2511 setBaseInfo(M);
2512
2513 patchPhis(M, GR, *TII, MMI);
2514
2515 addMBBNames(M, *TII, MMI, *ST, MAI);
2516 collectFPFastMathDefaults(M, MAI, *ST);
2517 addDecorations(M, *TII, MMI, *ST, MAI, GR);
2518
2519 collectReqs(M, MAI, MMI, *ST);
2520
2521 // Process type/const/global var/func decl instructions, number their
2522 // destination registers from 0 to N, collect Extensions and Capabilities.
2523 collectReqs(M, MAI, MMI, *ST);
2524 collectDeclarations(M);
2525
2526 // Number rest of registers from N+1 onwards.
2527 numberRegistersGlobally(M);
2528
2529 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
2530 processOtherInstrs(M);
2531
2532 // If there are no entry points, we need the Linkage capability.
2533 if (MAI.MS[SPIRV::MB_EntryPoints].empty())
2534 MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
2535
2536 // Set maximum ID used.
2537 GR->setBound(MAI.MaxID);
2538
2539 return false;
2540}
unsigned const MachineRegisterInfo * MRI
MachineInstrBuilder & UseMI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
aarch64 promote const
ReachingDefInfo InstSet & ToRemove
MachineBasicBlock & MBB
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
Analysis containing CSE Info
Definition CSEInfo.cpp:27
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
#define DEBUG_TYPE
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
Register Reg
Promote Memory to Register
Definition Mem2Reg.cpp:110
#define T
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
static SPIRV::FPFastMathDefaultInfoVector & getOrCreateFPFastMathDefaultInfoVec(const Module &M, DenseMap< Function *, SPIRV::FPFastMathDefaultInfoVector > &FPFastMathDefaultInfoMap, Function *F)
static SPIRV::FPFastMathDefaultInfo & getFPFastMathDefaultInfo(SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec, const Type *Ty)
#define ATOM_FLT_REQ_EXT_MSG(ExtName)
static cl::opt< bool > SPVDumpDeps("spv-dump-deps", cl::desc("Dump MIR with SPIR-V dependencies info"), cl::Optional, cl::init(false))
unsigned unsigned DefaultVal
unsigned OpIndex
static cl::list< SPIRV::Capability::Capability > AvoidCapabilities("avoid-spirv-capabilities", cl::desc("SPIR-V capabilities to avoid if there are " "other options enabling a feature"), cl::ZeroOrMore, cl::Hidden, cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader", "SPIR-V Shader capability")))
This file contains some templates that are useful if you are working with the STL at all.
#define LLVM_DEBUG(...)
Definition Debug.h:114
Target-Independent Code Generator Pass Configuration Options pass.
The Input class is used to parse a yaml document into in-memory structs and vectors.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
This is an important base class in LLVM.
Definition Constant.h:43
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Wrapper class representing physical registers. Should be passed by value.
Definition MCRegister.h:33
constexpr bool isValid() const
Definition MCRegister.h:76
Metadata node.
Definition Metadata.h:1078
const MDOperand & getOperand(unsigned I) const
Definition Metadata.h:1442
unsigned getNumOperands() const
Return number of MDNode operands.
Definition Metadata.h:1448
Tracking metadata reference owned by Metadata.
Definition Metadata.h:900
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
const MachineFunctionProperties & getProperties() const
Get the function properties.
Register getReg(unsigned Idx) const
Get the register for the operand index.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Retuns the total number of operands.
LLVM_ABI const MachineFunction * getMF() const
Return the function that contains the basic block that this instruction belongs to.
const MachineOperand & getOperand(unsigned i) const
This class contains meta information specific to a module.
LLVM_ABI MachineFunction * getMachineFunction(const Function &F) const
Returns the MachineFunction associated to IR function F if there is one, otherwise nullptr.
MachineOperand class - Representation of each machine instruction operand.
unsigned getSubReg() const
int64_t getImm() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
bool isImm() const
isImm - Tests if this is a MO_Immediate operand.
LLVM_ABI void print(raw_ostream &os, const TargetRegisterInfo *TRI=nullptr) const
Print the MachineOperand to os.
MachineInstr * getParent()
getParent - Return the instruction that this operand belongs to.
static MachineOperand CreateImm(int64_t Val)
MachineOperandType getType() const
getType - Returns the MachineOperandType for this operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Definition Pass.cpp:140
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
Wrapper class representing virtual and physical registers.
Definition Register.h:19
constexpr bool isValid() const
Definition Register.h:107
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
const Type * getTypeForSPIRVType(const SPIRVType *Ty) const
Register getSPIRVTypeID(const SPIRVType *SpirvType) const
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const
bool isConstantInstr(const MachineInstr &MI) const
const SPIRVInstrInfo * getInstrInfo() const override
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition SmallSet.h:133
bool contains(const T &V) const
Check if the SmallSet contains the given element.
Definition SmallSet.h:228
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition SmallSet.h:183
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:273
static LLVM_ABI Type * getDoubleTy(LLVMContext &C)
Definition Type.cpp:286
static LLVM_ABI Type * getFloatTy(LLVMContext &C)
Definition Type.cpp:285
static LLVM_ABI Type * getHalfTy(LLVMContext &C)
Definition Type.cpp:283
Represents a version number in the form major[.minor[.subminor[.build]]].
bool empty() const
Determine whether this version information is empty (e.g., all version components are zero).
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
SmallVector< const MachineInstr * > InstrList
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
Definition Metadata.h:667
NodeAddr< InstrNode * > Instr
Definition RDFGraph.h:389
This is an optimization pass for GlobalISel generic memory operations.
void buildOpName(Register Target, const StringRef &Name, MachineIRBuilder &MIRBuilder)
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
std::string getStringImm(const MachineInstr &MI, unsigned StartIndex)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1725
hash_code hash_value(const FixedPointSemantics &Val)
ExtensionList getSymbolicOperandExtensions(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
CapabilityList getSymbolicOperandCapabilities(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
SmallVector< SPIRV::Extension::Extension, 8 > ExtensionList
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
SmallVector< size_t > InstrSignature
VersionTuple getSymbolicOperandMaxVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder, SPIRV::Decoration::Decoration Dec, const std::vector< uint32_t > &DecArgs, StringRef StrImm)
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
CapabilityList getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:167
const MachineInstr SPIRVType
std::string getSymbolicOperandMnemonic(SPIRV::OperandCategory::OperandCategory Category, int32_t Value)
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
DWARFExpression::Operation Op
VersionTuple getSymbolicOperandMinVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
SmallVector< SPIRV::Capability::Capability, 8 > CapabilityList
std::set< InstrSignature > InstrTraces
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:592
std::map< SmallVector< size_t >, unsigned > InstrGRegsMap
#define N
SmallSet< SPIRV::Capability::Capability, 4 > S
static struct SPIRV::ModuleAnalysisInfo MAI
bool runOnModule(Module &M) override
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
static size_t computeFPFastMathDefaultInfoVecIndex(size_t BitWidth)
Definition SPIRVUtils.h:146
void setSkipEmission(const MachineInstr *MI)
MCRegister getRegisterAlias(const MachineFunction *MF, Register Reg)
MCRegister getOrCreateMBBRegister(const MachineBasicBlock &MBB)
InstrList MS[NUM_MODULE_SECTIONS]
AddressingModel::AddressingModel Addr
void setRegisterAlias(const MachineFunction *MF, Register Reg, MCRegister AliasReg)
DenseMap< const Function *, SPIRV::FPFastMathDefaultInfoVector > FPFastMathDefaultInfoMap
void addCapabilities(const CapabilityList &ToAdd)
bool isCapabilityAvailable(Capability::Capability Cap) const
void checkSatisfiable(const SPIRVSubtarget &ST) const
void getAndAddRequirements(SPIRV::OperandCategory::OperandCategory Category, uint32_t i, const SPIRVSubtarget &ST)
void addExtension(Extension::Extension ToAdd)
void initAvailableCapabilities(const SPIRVSubtarget &ST)
void removeCapabilityIf(const Capability::Capability ToRemove, const Capability::Capability IfPresent)
void addCapability(Capability::Capability ToAdd)
void addAvailableCaps(const CapabilityList &ToAdd)
void addRequirements(const Requirements &Req)
const std::optional< Capability::Capability > Cap