1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
16#include "NVPTX.h"
17#include "NVPTXSubtarget.h"
18#include "NVPTXTargetMachine.h"
20#include "NVPTXUtilities.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/StringRef.h"
36#include "llvm/IR/Argument.h"
37#include "llvm/IR/Attributes.h"
38#include "llvm/IR/Constants.h"
39#include "llvm/IR/DataLayout.h"
42#include "llvm/IR/FPEnv.h"
43#include "llvm/IR/Function.h"
44#include "llvm/IR/GlobalValue.h"
45#include "llvm/IR/Instruction.h"
47#include "llvm/IR/IntrinsicsNVPTX.h"
48#include "llvm/IR/Module.h"
49#include "llvm/IR/Type.h"
50#include "llvm/IR/Value.h"
59#include <algorithm>
60#include <cassert>
61#include <cmath>
62#include <cstdint>
63#include <iterator>
64#include <optional>
65#include <sstream>
66#include <string>
67#include <utility>
68#include <vector>
69
70#define DEBUG_TYPE "nvptx-lower"
71
72using namespace llvm;
73
74static std::atomic<unsigned> GlobalUniqueCallSite;
75
76static cl::opt<bool> sched4reg(
77 "nvptx-sched4reg",
78 cl::desc("NVPTX Specific: schedule for register pressure"), cl::init(false));
79
80static cl::opt<int> FMAContractLevelOpt(
81 "nvptx-fma-level", cl::Hidden,
82 cl::desc("NVPTX Specific: FMA contraction (0: don't do it,"
83 " 1: do it, 2: do it aggressively)"),
84 cl::init(2));
85
86static cl::opt<int> UsePrecDivF32(
87 "nvptx-prec-divf32", cl::Hidden,
88 cl::desc("NVPTX Specific: 0 use div.approx, 1 use div.full, 2 use"
89 " IEEE Compliant F32 div.rnd if available."),
90 cl::init(2));
91
92static cl::opt<bool> UsePrecSqrtF32(
93 "nvptx-prec-sqrtf32", cl::Hidden,
94 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
95 cl::init(true));
96
97static cl::opt<bool> ForceMinByValParamAlign(
98 "nvptx-force-min-byval-param-align", cl::Hidden,
99 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
100 " params of device functions."),
101 cl::init(false));
102
103int NVPTXTargetLowering::getDivF32Level() const {
104 if (UsePrecDivF32.getNumOccurrences() > 0) {
105 // If nvptx-prec-divf32=N is used on the command-line, always honor it
106 return UsePrecDivF32;
107 } else {
108 // Otherwise, use div.approx if fast math is enabled
109 if (getTargetMachine().Options.UnsafeFPMath)
110 return 0;
111 else
112 return 2;
113 }
114}
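// Usage example (illustrative): `llc -march=nvptx64 -nvptx-prec-divf32=0`
// forces div.approx.f32 for all f32 divisions, while the default level (2)
// requests IEEE-compliant div.rn.f32 when available; if the flag is absent,
// -enable-unsafe-fp-math selects level 0 as shown above.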
115
116bool NVPTXTargetLowering::usePrecSqrtF32() const {
117 if (UsePrecSqrtF32.getNumOccurrences() > 0) {
118 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
119 return UsePrecSqrtF32;
120 } else {
121 // Otherwise, use sqrt.approx if fast math is enabled
122 return !getTargetMachine().Options.UnsafeFPMath;
123 }
124}
125
129}
130
131static bool IsPTXVectorType(MVT VT) {
132 switch (VT.SimpleTy) {
133 default:
134 return false;
135 case MVT::v2i1:
136 case MVT::v4i1:
137 case MVT::v2i8:
138 case MVT::v4i8:
139 case MVT::v2i16:
140 case MVT::v4i16:
141 case MVT::v8i16: // <4 x i16x2>
142 case MVT::v2i32:
143 case MVT::v4i32:
144 case MVT::v2i64:
145 case MVT::v2f16:
146 case MVT::v4f16:
147 case MVT::v8f16: // <4 x f16x2>
148 case MVT::v2bf16:
149 case MVT::v4bf16:
150 case MVT::v8bf16: // <4 x bf16x2>
151 case MVT::v2f32:
152 case MVT::v4f32:
153 case MVT::v2f64:
154 return true;
155 }
156}
157
158static bool Is16bitsType(MVT VT) {
159 return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
160 VT.SimpleTy == MVT::i16);
161}
162
163/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
164/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
165/// into their primitive components.
166/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
167/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
168/// LowerCall, and LowerReturn.
169static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
170 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
171 SmallVectorImpl<uint64_t> *Offsets = nullptr,
172 uint64_t StartingOffset = 0) {
173 SmallVector<EVT, 16> TempVTs;
174 SmallVector<uint64_t, 16> TempOffsets;
175
176 // Special case for i128 - decompose to (i64, i64)
177 if (Ty->isIntegerTy(128)) {
178 ValueVTs.push_back(EVT(MVT::i64));
179 ValueVTs.push_back(EVT(MVT::i64));
180
181 if (Offsets) {
182 Offsets->push_back(StartingOffset + 0);
183 Offsets->push_back(StartingOffset + 8);
184 }
185
186 return;
187 }
188
189 // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
190 if (StructType *STy = dyn_cast<StructType>(Ty)) {
191 auto const *SL = DL.getStructLayout(STy);
192 auto ElementNum = 0;
193 for(auto *EI : STy->elements()) {
194 ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
195 StartingOffset + SL->getElementOffset(ElementNum));
196 ++ElementNum;
197 }
198 return;
199 }
200
201 ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
202 for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
203 EVT VT = TempVTs[i];
204 uint64_t Off = TempOffsets[i];
205 // Split vectors into individual elements, except for v2f16, which
206 // we will pass as a single scalar.
207 if (VT.isVector()) {
208 unsigned NumElts = VT.getVectorNumElements();
209 EVT EltVT = VT.getVectorElementType();
210 // Vectors with an even number of f16 elements will be passed to
211 // us as an array of v2f16/v2bf16 elements. We must match this so we
212 // stay in sync with Ins/Outs.
213 if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
214 switch (EltVT.getSimpleVT().SimpleTy) {
215 case MVT::f16:
216 EltVT = MVT::v2f16;
217 break;
218 case MVT::bf16:
219 EltVT = MVT::v2bf16;
220 break;
221 case MVT::i16:
222 EltVT = MVT::v2i16;
223 break;
224 default:
225 llvm_unreachable("Unexpected type");
226 }
227 NumElts /= 2;
228 } else if (EltVT.getSimpleVT() == MVT::i8 &&
229 (NumElts % 4 == 0 || NumElts == 3)) {
230 // v*i8 are formally lowered as v4i8
231 EltVT = MVT::v4i8;
232 NumElts = (NumElts + 3) / 4;
233 }
234 for (unsigned j = 0; j != NumElts; ++j) {
235 ValueVTs.push_back(EltVT);
236 if (Offsets)
237 Offsets->push_back(Off + j * EltVT.getStoreSize());
238 }
239 } else {
240 ValueVTs.push_back(VT);
241 if (Offsets)
242 Offsets->push_back(Off);
243 }
244 }
245}
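// Worked example (illustrative): for a struct type { i128, <4 x half>, float }
// this produces ValueVTs = { i64, i64, v2f16, v2f16, f32 } with byte offsets
// { 0, 8, 16, 20, 24 }: the i128 member is split into two i64 halves, the
// even-length f16 vector is repackaged as v2f16 pairs, and the f32 scalar is
// kept as-is.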
246
247/// PromoteScalarIntegerPTX
248/// Used to make sure the arguments/returns are suitable for passing
249/// and promote them to a larger size if they're not.
250///
251/// The promoted type is placed in \p PromotedVT if the function returns true.
252static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
253 if (VT.isScalarInteger()) {
254 switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
255 default:
256 llvm_unreachable(
257 "Promotion is not suitable for scalars of size larger than 64-bits");
258 case 1:
259 *PromotedVT = MVT::i1;
260 break;
261 case 2:
262 case 4:
263 case 8:
264 *PromotedVT = MVT::i8;
265 break;
266 case 16:
267 *PromotedVT = MVT::i16;
268 break;
269 case 32:
270 *PromotedVT = MVT::i32;
271 break;
272 case 64:
273 *PromotedVT = MVT::i64;
274 break;
275 }
276 return EVT(*PromotedVT) != VT;
277 }
278 return false;
279}
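// Example (illustrative): an i20 value is rounded up to the next power of two
// (32) and promoted to MVT::i32, so the function returns true; an i32 value
// already matches its promoted type, so the function returns false.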
280
281// Check whether we can merge loads/stores of some of the pieces of a
282// flattened function parameter or return value into a single vector
283// load/store.
284//
285// The flattened parameter is represented as a list of EVTs and
286// offsets, and the whole structure is aligned to ParamAlignment. This
287// function determines whether we can load/store pieces of the
288// parameter starting at index Idx using a single vectorized op of
289// size AccessSize. If so, it returns the number of param pieces
290// covered by the vector op. Otherwise, it returns 1.
291static unsigned CanMergeParamLoadStoresStartingAt(
292 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
293 const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
294
295 // Can't vectorize if param alignment is not sufficient.
296 if (ParamAlignment < AccessSize)
297 return 1;
298 // Can't vectorize if offset is not aligned.
299 if (Offsets[Idx] & (AccessSize - 1))
300 return 1;
301
302 EVT EltVT = ValueVTs[Idx];
303 unsigned EltSize = EltVT.getStoreSize();
304
305 // Element is too large to vectorize.
306 if (EltSize >= AccessSize)
307 return 1;
308
309 unsigned NumElts = AccessSize / EltSize;
310 // Can't vectorize if AccessSize is not a multiple of EltSize.
311 if (AccessSize != EltSize * NumElts)
312 return 1;
313
314 // We don't have enough elements to vectorize.
315 if (Idx + NumElts > ValueVTs.size())
316 return 1;
317
318 // PTX ISA can only deal with 2- and 4-element vector ops.
319 if (NumElts != 4 && NumElts != 2)
320 return 1;
321
322 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
323 // Types do not match.
324 if (ValueVTs[j] != EltVT)
325 return 1;
326
327 // Elements are not contiguous.
328 if (Offsets[j] - Offsets[j - 1] != EltSize)
329 return 1;
330 }
331 // OK. We can vectorize ValueVTs[Idx..Idx+NumElts).
332 return NumElts;
333}
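// Example (illustrative): with ValueVTs = {f32, f32, f32, f32}, Offsets =
// {0, 4, 8, 12} and ParamAlignment = 16, a query at Idx = 0 with AccessSize =
// 16 returns 4 (a single v4 access). If ParamAlignment were only 8, the same
// query would return 1 for AccessSize 16 but 2 for AccessSize 8.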
334
335// Flags for tracking per-element vectorization state of loads/stores
336// of a flattened function parameter or return value.
337enum ParamVectorizationFlags {
338 PVF_INNER = 0x0, // Middle elements of a vector.
339 PVF_FIRST = 0x1, // First element of the vector.
340 PVF_LAST = 0x2, // Last element of the vector.
341 // Scalar is effectively a 1-element vector.
342 PVF_SCALAR = PVF_FIRST | PVF_LAST
343};
344
345// Computes whether and how we can vectorize the loads/stores of a
346// flattened function parameter or return value.
347//
348// The flattened parameter is represented as the list of ValueVTs and
349// Offsets, and is aligned to ParamAlignment bytes. We return a vector
350// of the same size as ValueVTs indicating how each piece should be
351// loaded/stored (i.e. as a scalar, or as part of a vector
352// load/store).
353static SmallVector<ParamVectorizationFlags, 16>
354VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
355 const SmallVectorImpl<uint64_t> &Offsets,
356 Align ParamAlignment, bool IsVAArg = false) {
357 // Set vector size to match ValueVTs and mark all elements as
358 // scalars by default.
359 SmallVector<ParamVectorizationFlags, 16> VectorInfo;
360 VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
361
362 if (IsVAArg)
363 return VectorInfo;
364
365 // Check what we can vectorize using 128/64/32-bit accesses.
366 for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
367 // Skip elements we've already processed.
368 assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
369 for (unsigned AccessSize : {16, 8, 4, 2}) {
370 unsigned NumElts = CanMergeParamLoadStoresStartingAt(
371 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
372 // Mark vectorized elements.
373 switch (NumElts) {
374 default:
375 llvm_unreachable("Unexpected return value");
376 case 1:
377 // Can't vectorize using this size, try next smaller size.
378 continue;
379 case 2:
380 assert(I + 1 < E && "Not enough elements.");
381 VectorInfo[I] = PVF_FIRST;
382 VectorInfo[I + 1] = PVF_LAST;
383 I += 1;
384 break;
385 case 4:
386 assert(I + 3 < E && "Not enough elements.");
387 VectorInfo[I] = PVF_FIRST;
388 VectorInfo[I + 1] = PVF_INNER;
389 VectorInfo[I + 2] = PVF_INNER;
390 VectorInfo[I + 3] = PVF_LAST;
391 I += 3;
392 break;
393 }
394 // Break out of the inner loop because we've already succeeded
395 // using largest possible AccessSize.
396 break;
397 }
398 }
399 return VectorInfo;
400}
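// Example (illustrative): four contiguous f32 pieces at offsets 0/4/8/12 with
// 16-byte alignment yield VectorInfo = { PVF_FIRST, PVF_INNER, PVF_INNER,
// PVF_LAST }, i.e. one v4 load/store; with only 4-byte alignment every entry
// stays PVF_SCALAR and the pieces are accessed individually.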
401
402// NVPTXTargetLowering Constructor.
404 const NVPTXSubtarget &STI)
405 : TargetLowering(TM), nvTM(&TM), STI(STI) {
406 // Always lower memset, memcpy, and memmove intrinsics to load/store
407 // instructions, rather
408 // than generating calls to memset, memcpy, or memmove.
412
415
416 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
417 // condition branches.
418 setJumpIsExpensive(true);
419
420 // Wide divides are _very_ slow. Try to reduce the width of the divide if
421 // possible.
422 addBypassSlowDiv(64, 32);
423
424 // By default, use the Source scheduling
425 if (sched4reg)
427 else
429
430 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
431 LegalizeAction NoF16Action) {
432 setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
433 };
434
435 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
436 LegalizeAction NoBF16Action) {
437 bool IsOpSupported = STI.hasBF16Math();
438 // Few instructions are available on sm_90 only
439 switch(Op) {
440 case ISD::FADD:
441 case ISD::FMUL:
442 case ISD::FSUB:
443 case ISD::SELECT:
444 case ISD::SELECT_CC:
445 case ISD::SETCC:
446 case ISD::FEXP2:
447 case ISD::FCEIL:
448 case ISD::FFLOOR:
449 case ISD::FNEARBYINT:
450 case ISD::FRINT:
451 case ISD::FROUNDEVEN:
452 case ISD::FTRUNC:
453 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
454 break;
455 }
457 Op, VT, IsOpSupported ? Action : NoBF16Action);
458 };
459
460 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
461 LegalizeAction NoI16x2Action) {
462 bool IsOpSupported = false;
463 // instructions are available on sm_90 only
464 switch (Op) {
465 case ISD::ADD:
466 case ISD::SMAX:
467 case ISD::SMIN:
468 case ISD::UMIN:
469 case ISD::UMAX:
470 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
471 break;
472 }
473 setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
474 };
475
476 addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
477 addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
478 addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
479 addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
480 addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
481 addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
482 addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
483 addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
484 addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
485 addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
486 addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
487 addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
488
489 // Conversion to/from FP16/FP16x2 is always legal.
494
496 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
498
499 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
500 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
501
502 // Conversion to/from BFP16/BFP16x2 is always legal.
507
508 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
509 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
510 if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote)
511 AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32);
512
513 // Conversion to/from i16/i16x2 is always legal.
518
523 // Only logical ops can be done on v4i8 directly, others must be done
524 // elementwise.
541 MVT::v4i8, Expand);
542
543 // Operations not directly supported by NVPTX.
544 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
545 MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
546 MVT::i32, MVT::i64}) {
549 }
550
551 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
552 // For others we will expand to a SHL/SRA pair.
559
566
569
570 // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs
571 // that don't have h/w rotation we lower them to multi-instruction assembly.
572 // See ROT*_sw in NVPTXIntrInfo.td
577
579 setOperationAction(ISD::ROTL, MVT::v2i16, Expand);
581 setOperationAction(ISD::ROTR, MVT::v2i16, Expand);
585
588
591
592 // We want to legalize constant-related memmove and memcpy
593 // intrinsics.
595
596 // Turn FP extload into load/fpextend
597 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
598 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
599 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
600 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
601 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
602 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
603 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
604 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
605 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
606 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
607 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
608 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
609 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
610 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
611 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
612 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
613 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
614 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
615 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
616 // Turn FP truncstore into trunc + store.
617 // FIXME: vector types should also be expanded
618 setTruncStoreAction(MVT::f32, MVT::f16, Expand);
619 setTruncStoreAction(MVT::f64, MVT::f16, Expand);
620 setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
621 setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
622 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
623
624 // PTX does not support load / store predicate registers
627
628 for (MVT VT : MVT::integer_valuetypes()) {
632 setTruncStoreAction(VT, MVT::i1, Expand);
633 }
634
635 // expand extload of vector of integers.
637 MVT::v2i8, Expand);
638 setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
639
640 // This is legal in NVPTX
645
648
649 // TRAP can be lowered to PTX trap
650 setOperationAction(ISD::TRAP, MVT::Other, Legal);
651
652 // Register custom handling for vector loads/stores
654 if (IsPTXVectorType(VT)) {
658 }
659 }
660
661 // Support varargs.
666
667 // Custom handling for i8 intrinsics
669
670 for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) {
676
679 }
680
681 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
682 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
683 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
684 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
685 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
686 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
687 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
688
689 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
690 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
691 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
692 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
693 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
694 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
695
696 // Other arithmetic and logic ops are unsupported.
700 MVT::v2i16, Expand);
701
706 if (STI.getPTXVersion() >= 43) {
711 }
712
714 setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
717
718 // PTX does not directly support SELP of i1, so promote to i32 first
720
721 // PTX cannot multiply two i64s in a single instruction.
724
725 // We have some custom DAG combine patterns for these nodes
728 ISD::VSELECT});
729
730 // setcc for f16x2 and bf16x2 needs special handling to prevent
731 // legalizer's attempt to scalarize it due to v2i1 not being legal.
732 if (STI.allowFP16Math() || STI.hasBF16Math())
734
735 // Promote fp16 arithmetic if fp16 hardware isn't available or the
736 // user passed --nvptx-no-fp16-math. The flag is useful because,
737 // although sm_53+ GPUs have some sort of FP16 support in
738 // hardware, only sm_53 and sm_60 have a full implementation. Others
739 // only have a token amount of hardware and are likely to run faster
740 // by using fp32 units instead.
741 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
742 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
743 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
744 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
745 // bf16 must be promoted to f32.
746 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
747 if (getOperationAction(Op, MVT::bf16) == Promote)
748 AddPromotedToType(Op, MVT::bf16, MVT::f32);
749 }
750
751 // f16/f16x2 neg was introduced in PTX 60, SM_53.
752 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
753 STI.getPTXVersion() >= 60 &&
754 STI.allowFP16Math();
755 for (const auto &VT : {MVT::f16, MVT::v2f16})
757 IsFP16FP16x2NegAvailable ? Legal : Expand);
758
759 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
760 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
761 // (would be) Library functions.
762
763 // These map to conversion instructions for scalar FP types.
764 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
766 setOperationAction(Op, MVT::f16, Legal);
767 setOperationAction(Op, MVT::f32, Legal);
768 setOperationAction(Op, MVT::f64, Legal);
769 setOperationAction(Op, MVT::v2f16, Expand);
770 setOperationAction(Op, MVT::v2bf16, Expand);
771 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
772 if (getOperationAction(Op, MVT::bf16) == Promote)
773 AddPromotedToType(Op, MVT::bf16, MVT::f32);
774 }
775
776 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
778 }
779 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
780 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
783 }
784 }
785
786 // sm_80 only has conversions between f32 and bf16. Custom lower all other
787 // bf16 conversions.
788 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
789 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
792 VT, Custom);
793 }
796 MVT::bf16, Custom);
797 }
798
805 AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32);
806
807 // 'Expand' implements FCOPYSIGN without calling an external library.
814
815 // These map to corresponding instructions for f32/f64. f16 must be
816 // promoted to f32. v2f16 is expanded to f16, which is then promoted
817 // to f32.
818 for (const auto &Op :
820 setOperationAction(Op, MVT::f16, Promote);
821 setOperationAction(Op, MVT::f32, Legal);
822 setOperationAction(Op, MVT::f64, Legal);
823 setOperationAction(Op, MVT::v2f16, Expand);
824 setOperationAction(Op, MVT::v2bf16, Expand);
825 setOperationAction(Op, MVT::bf16, Promote);
826 AddPromotedToType(Op, MVT::bf16, MVT::f32);
827 }
828 for (const auto &Op : {ISD::FABS}) {
829 setOperationAction(Op, MVT::f16, Promote);
830 setOperationAction(Op, MVT::f32, Legal);
831 setOperationAction(Op, MVT::f64, Legal);
832 setOperationAction(Op, MVT::v2f16, Expand);
833 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
834 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
835 if (getOperationAction(Op, MVT::bf16) == Promote)
836 AddPromotedToType(Op, MVT::bf16, MVT::f32);
837 }
838
839 // max.f16, max.f16x2 and max.NaN are supported on sm_80+.
840 auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
841 bool IsAtLeastSm80 = STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
842 return IsAtLeastSm80 ? Legal : NotSm80Action;
843 };
844 for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
845 setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
846 setOperationAction(Op, MVT::f32, Legal);
847 setOperationAction(Op, MVT::f64, Legal);
848 setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
849 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
850 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
851 if (getOperationAction(Op, MVT::bf16) == Promote)
852 AddPromotedToType(Op, MVT::bf16, MVT::f32);
853 }
854 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
855 setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
856 setFP16OperationAction(Op, MVT::bf16, Legal, Expand);
857 setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
858 setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
859 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
860 }
861
862 // Custom lowering for inline asm with 128-bit operands
865
866 // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
867 // No FPOW or FREM in PTX.
868
869 // Now deduce the information based on the above-mentioned
870 // actions.
872
876}
877
878const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
879
880#define MAKE_CASE(V) \
881 case V: \
882 return #V;
883
884 switch ((NVPTXISD::NodeType)Opcode) {
886 break;
887
1034
1125
1137
1149
1161
1173
1185
1197
1209
1221
1233
1245
1257
1269
1281
1293
1305 }
1306 return nullptr;
1307
1308#undef MAKE_CASE
1309}
1310
1313 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1314 VT.getScalarType() == MVT::i1)
1315 return TypeSplitVector;
1316 if (Isv2x16VT(VT))
1317 return TypeLegal;
1319}
1320
1322 int Enabled, int &ExtraSteps,
1323 bool &UseOneConst,
1324 bool Reciprocal) const {
1327 return SDValue();
1328
1329 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1330 ExtraSteps = 0;
1331
1332 SDLoc DL(Operand);
1333 EVT VT = Operand.getValueType();
1334 bool Ftz = useF32FTZ(DAG.getMachineFunction());
1335
1336 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1337 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
1338 DAG.getConstant(IID, DL, MVT::i32), Operand);
1339 };
1340
1341 // The sqrt and rsqrt refinement processes assume we always start out with an
1342 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1343 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1344 // any refinement, we must return a regular sqrt.
1345 if (Reciprocal || ExtraSteps > 0) {
1346 if (VT == MVT::f32)
1347 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1348 : Intrinsic::nvvm_rsqrt_approx_f);
1349 else if (VT == MVT::f64)
1350 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1351 else
1352 return SDValue();
1353 } else {
1354 if (VT == MVT::f32)
1355 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1356 : Intrinsic::nvvm_sqrt_approx_f);
1357 else {
1358 // There's no sqrt.approx.f64 instruction, so we emit
1359 // reciprocal(rsqrt(x)). This is faster than
1360 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1361 // x * rsqrt(x).)
1362 return DAG.getNode(
1364 DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32),
1365 MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1366 }
1367 }
1368}
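// Example (illustrative): when an approximate square root is allowed,
// `sqrt(float x)` with no refinement steps selects a single sqrt.approx.f32
// (or sqrt.approx.ftz.f32 in FTZ mode), while the f64 case is emitted as
// rcp.approx.ftz.f64(rsqrt.approx.f64(x)) as described above.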
1369
1370SDValue
1372 SDLoc dl(Op);
1373 const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op);
1374 auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace());
1375 Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT);
1376 return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
1377}
1378
1379static bool IsTypePassedAsArray(const Type *Ty) {
1380 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
1381 Ty->isHalfTy() || Ty->isBFloatTy();
1382}
1383
1385 const DataLayout &DL, Type *retTy, const ArgListTy &Args,
1386 const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
1387 std::optional<std::pair<unsigned, const APInt &>> VAInfo,
1388 const CallBase &CB, unsigned UniqueCallSite) const {
1389 auto PtrVT = getPointerTy(DL);
1390
1391 bool isABI = (STI.getSmVersion() >= 20);
1392 assert(isABI && "Non-ABI compilation is not supported");
1393 if (!isABI)
1394 return "";
1395
1396 std::string Prototype;
1397 raw_string_ostream O(Prototype);
1398 O << "prototype_" << UniqueCallSite << " : .callprototype ";
1399
1400 if (retTy->getTypeID() == Type::VoidTyID) {
1401 O << "()";
1402 } else {
1403 O << "(";
1404 if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
1405 !IsTypePassedAsArray(retTy)) {
1406 unsigned size = 0;
1407 if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
1408 size = ITy->getBitWidth();
1409 } else {
1410 assert(retTy->isFloatingPointTy() &&
1411 "Floating point type expected here");
1412 size = retTy->getPrimitiveSizeInBits();
1413 }
1414 // PTX ABI requires all scalar return values to be at least 32
1415 // bits in size. fp16 normally uses .b16 as its storage type in
1416 // PTX, so its size must be adjusted here, too.
1418
1419 O << ".param .b" << size << " _";
1420 } else if (isa<PointerType>(retTy)) {
1421 O << ".param .b" << PtrVT.getSizeInBits() << " _";
1422 } else if (IsTypePassedAsArray(retTy)) {
1423 O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
1424 << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
1425 } else {
1426 llvm_unreachable("Unknown return type");
1427 }
1428 O << ") ";
1429 }
1430 O << "_ (";
1431
1432 bool first = true;
1433
1434 unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
1435 for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
1436 Type *Ty = Args[i].Ty;
1437 if (!first) {
1438 O << ", ";
1439 }
1440 first = false;
1441
1442 if (!Outs[OIdx].Flags.isByVal()) {
1443 if (IsTypePassedAsArray(Ty)) {
1444 Align ParamAlign =
1445 getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
1446 O << ".param .align " << ParamAlign.value() << " .b8 ";
1447 O << "_";
1448 O << "[" << DL.getTypeAllocSize(Ty) << "]";
1449 // update the index for Outs
1450 SmallVector<EVT, 16> vtparts;
1451 ComputeValueVTs(*this, DL, Ty, vtparts);
1452 if (unsigned len = vtparts.size())
1453 OIdx += len - 1;
1454 continue;
1455 }
1456 // i8 types in IR will be i16 types in SDAG
1457 assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
1458 (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
1459 "type mismatch between callee prototype and arguments");
1460 // scalar type
1461 unsigned sz = 0;
1462 if (isa<IntegerType>(Ty)) {
1463 sz = cast<IntegerType>(Ty)->getBitWidth();
1465 } else if (isa<PointerType>(Ty)) {
1466 sz = PtrVT.getSizeInBits();
1467 } else {
1468 sz = Ty->getPrimitiveSizeInBits();
1469 }
1470 O << ".param .b" << sz << " ";
1471 O << "_";
1472 continue;
1473 }
1474
1475 // Indirect calls need strict ABI alignment so we disable optimizations by
1476 // not providing a function to optimize.
1477 Type *ETy = Args[i].IndirectType;
1478 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1479 Align ParamByValAlign =
1480 getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
1481
1482 O << ".param .align " << ParamByValAlign.value() << " .b8 ";
1483 O << "_";
1484 O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
1485 }
1486
1487 if (VAInfo)
1488 O << (first ? "" : ",") << " .param .align " << VAInfo->second
1489 << " .b8 _[]\n";
1490 O << ")";
1492 O << " .noreturn";
1493 O << ";";
1494
1495 return Prototype;
1496}
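// Example (illustrative): for an indirect call to a function of IR type
// `float (i32, <4 x float>)`, the emitted prototype string looks roughly like
//   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .align 16 .b8 _[16]);
// and is later printed verbatim via the CallPrototype node.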
1497
1499 const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
1500 return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
1501}
1502
1503Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
1504 unsigned Idx,
1505 const DataLayout &DL) const {
1506 if (!CB) {
1507 // No call site; fall back to ABI type alignment.
1508 return DL.getABITypeAlign(Ty);
1509 }
1510
1511 const Function *DirectCallee = CB->getCalledFunction();
1512
1513 if (!DirectCallee) {
1514 // We don't have a direct function symbol, but that may be because of
1515 // constant cast instructions in the call.
1516
1517 // With bitcast'd call targets, the instruction will be the call
1518 if (const auto *CI = dyn_cast<CallInst>(CB)) {
1519 // Check if we have call alignment metadata
1520 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1521 return StackAlign.value();
1522 }
1523 DirectCallee = getMaybeBitcastedCallee(CB);
1524 }
1525
1526 // Check for function alignment information if we found that the
1527 // ultimate target is a Function
1528 if (DirectCallee)
1529 return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
1530
1531 // Call is indirect, fall back to the ABI type alignment
1532 return DL.getABITypeAlign(Ty);
1533}
1534
1535static bool adjustElementType(EVT &ElementType) {
1536 switch (ElementType.getSimpleVT().SimpleTy) {
1537 default:
1538 return false;
1539 case MVT::f16:
1540 case MVT::bf16:
1541 ElementType = MVT::i16;
1542 return true;
1543 case MVT::f32:
1544 case MVT::v2f16:
1545 case MVT::v2bf16:
1546 ElementType = MVT::i32;
1547 return true;
1548 case MVT::f64:
1549 ElementType = MVT::i64;
1550 return true;
1551 }
1552}
1553
1554// Use byte-store when the param address of the argument value is unaligned.
1555// This may happen when the argument value is a field of a packed structure.
1556//
1557// This is called in LowerCall() when passing the param values.
1559 uint64_t Offset, EVT ElementType,
1560 SDValue StVal, SDValue &InGlue,
1561 unsigned ArgID, const SDLoc &dl) {
1562 // Bit logic only works on integer types
1563 if (adjustElementType(ElementType))
1564 StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
1565
1566 // Store each byte
1567 SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1568 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1569 // Shift the byte to the last byte position
1570 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
1571 DAG.getConstant(i * 8, dl, MVT::i32));
1572 SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
1573 DAG.getConstant(Offset + i, dl, MVT::i32),
1574 ShiftVal, InGlue};
1575 // Trunc store only the last byte by using
1576 // st.param.b8
1577 // The register type can be larger than b8.
1578 Chain = DAG.getMemIntrinsicNode(
1579 NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
1581 InGlue = Chain.getValue(1);
1582 }
1583 return Chain;
1584}
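// Example (illustrative): an f32 piece that lands at byte offset 1 of a packed
// struct parameter is bitcast to i32, and four st.param.b8 stores are emitted
// at offsets 1..4, the i-th one writing (StVal >> (8 * i)) & 0xff.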
1585
1586// Use byte-load when the param address of the returned value is unaligned.
1587// This may happen when the returned value is a field of a packed structure.
1588static SDValue
1590 EVT ElementType, SDValue &InGlue,
1591 SmallVectorImpl<SDValue> &TempProxyRegOps,
1592 const SDLoc &dl) {
1593 // Bit logic only works on integer types
1594 EVT MergedType = ElementType;
1595 adjustElementType(MergedType);
1596
1597 // Load each byte and construct the whole value. Initial value to 0
1598 SDValue RetVal = DAG.getConstant(0, dl, MergedType);
1599 // LoadParamMemI8 loads into i16 register only
1600 SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
1601 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1602 SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
1603 DAG.getConstant(Offset + i, dl, MVT::i32),
1604 InGlue};
1605 // This will be selected to LoadParamMemI8
1606 SDValue LdVal =
1607 DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
1608 MVT::i8, MachinePointerInfo(), Align(1));
1609 SDValue TmpLdVal = LdVal.getValue(0);
1610 Chain = LdVal.getValue(1);
1611 InGlue = LdVal.getValue(2);
1612
1613 TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
1614 TmpLdVal.getSimpleValueType(), TmpLdVal);
1615 TempProxyRegOps.push_back(TmpLdVal);
1616
1617 SDValue CMask = DAG.getConstant(255, dl, MergedType);
1618 SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
1619 // Need to extend the i16 register to the whole width.
1620 TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
1621 // Mask off the high bits. Leave only the lower 8 bits.
1622 // Do this because we are using loadparam.b8.
1623 TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
1624 // Shift and merge
1625 TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
1626 RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
1627 }
1628 if (ElementType != MergedType)
1629 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
1630
1631 return RetVal;
1632}
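// Example (illustrative): an unaligned f32 result piece is assembled from four
// ld.param.b8 loads; each byte is zero-extended into the i32 MergedType,
// masked to its low 8 bits, shifted left by 8 * i, ORed into RetVal, and the
// final i32 is bitcast back to f32.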
1633
1634SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1635 SmallVectorImpl<SDValue> &InVals) const {
1636
1637 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1638 report_fatal_error(
1639 "Support for variadic functions (unsized array parameter) introduced "
1640 "in PTX ISA version 6.0 and requires target sm_30.");
1641
1642 SelectionDAG &DAG = CLI.DAG;
1643 SDLoc dl = CLI.DL;
1645 SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
1647 SDValue Chain = CLI.Chain;
1648 SDValue Callee = CLI.Callee;
1649 bool &isTailCall = CLI.IsTailCall;
1650 ArgListTy &Args = CLI.getArgs();
1651 Type *RetTy = CLI.RetTy;
1652 const CallBase *CB = CLI.CB;
1653 const DataLayout &DL = DAG.getDataLayout();
1654
1655 bool isABI = (STI.getSmVersion() >= 20);
1656 assert(isABI && "Non-ABI compilation is not supported");
1657 if (!isABI)
1658 return Chain;
1659
1660 // Variadic arguments.
1661 //
1662 // Normally, for each argument, we declare a param scalar or a param
1663 // byte array in the .param space, and store the argument value to that
1664 // param scalar or array starting at offset 0.
1665 //
1666 // In the case of the first variadic argument, we declare a vararg byte array
1667 // with size 0. The exact size of this array isn't known at this point, so
1668 // it'll be patched later. All the variadic arguments will be stored to this
1669 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1670 // initially set to 0, so it can be used for non-variadic arguments (which use
1671 // 0 offset) to simplify the code.
1672 //
1673 // After all variadic arguments are processed, 'VAOffset' holds the size of the
1674 // vararg byte array.
1675
1676 SDValue VADeclareParam; // vararg byte array
1677 unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
1678 unsigned VAOffset = 0; // current offset in the param array
1679
1680 unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
1681 SDValue TempChain = Chain;
1682 Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
1683 SDValue InGlue = Chain.getValue(1);
1684
1685 unsigned ParamCount = 0;
1686 // Args.size() and Outs.size() need not match.
1687 // Outs.size() will be larger
1688 // * if there is an aggregate argument with multiple fields (each field
1689 // showing up separately in Outs)
1690 // * if there is a vector argument with more than typical vector-length
1691 // elements (generally if more than 4) where each vector element is
1692 // individually present in Outs.
1693 // So a different index should be used for indexing into Outs/OutVals.
1694 // See similar issue in LowerFormalArguments.
1695 unsigned OIdx = 0;
1696 // Declare the .params or .reg needed to pass values
1697 // to the function
1698 for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
1699 EVT VT = Outs[OIdx].VT;
1700 Type *Ty = Args[i].Ty;
1701 bool IsVAArg = (i >= CLI.NumFixedArgs);
1702 bool IsByVal = Outs[OIdx].Flags.isByVal();
1703
1706
1707 assert((!IsByVal || Args[i].IndirectType) &&
1708 "byval arg must have indirect type");
1709 Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
1710 ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
1711
1712 Align ArgAlign;
1713 if (IsByVal) {
1714 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1715 // so we don't need to worry whether it's naturally aligned or not.
1716 // See TargetLowering::LowerCallTo().
1717 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1718 ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
1719 InitialAlign, DL);
1720 if (IsVAArg)
1721 VAOffset = alignTo(VAOffset, ArgAlign);
1722 } else {
1723 ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
1724 }
1725
1726 unsigned TypeSize =
1727 (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
1728 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1729
1730 bool NeedAlign; // Does argument declaration specify alignment?
1731 bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
1732 if (IsVAArg) {
1733 if (ParamCount == FirstVAArg) {
1734 SDValue DeclareParamOps[] = {
1735 Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
1736 DAG.getConstant(ParamCount, dl, MVT::i32),
1737 DAG.getConstant(1, dl, MVT::i32), InGlue};
1738 VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
1739 DeclareParamVTs, DeclareParamOps);
1740 }
1741 NeedAlign = PassAsArray;
1742 } else if (PassAsArray) {
1743 // declare .param .align <align> .b8 .param<n>[<size>];
1744 SDValue DeclareParamOps[] = {
1745 Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
1746 DAG.getConstant(ParamCount, dl, MVT::i32),
1747 DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
1748 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
1749 DeclareParamOps);
1750 NeedAlign = true;
1751 } else {
1752 // declare .param .b<size> .param<n>;
1753 if (VT.isInteger() || VT.isFloatingPoint()) {
1754 // PTX ABI requires integral types to be at least 32 bits in
1755 // size. FP16 is loaded/stored using i16, so it's handled
1756 // here as well.
1758 }
1759 SDValue DeclareScalarParamOps[] = {
1760 Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
1761 DAG.getConstant(TypeSize * 8, dl, MVT::i32),
1762 DAG.getConstant(0, dl, MVT::i32), InGlue};
1763 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
1764 DeclareScalarParamOps);
1765 NeedAlign = false;
1766 }
1767 InGlue = Chain.getValue(1);
1768
1769 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1770 // than 32-bits are sign extended or zero extended, depending on
1771 // whether they are signed or unsigned types. This case applies
1772 // only to scalar parameters and not to aggregate values.
1773 bool ExtendIntegerParam =
1774 Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
1775
1776 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
1777 SmallVector<SDValue, 6> StoreOperands;
1778 for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
1779 EVT EltVT = VTs[j];
1780 int CurOffset = Offsets[j];
1781 MaybeAlign PartAlign;
1782 if (NeedAlign)
1783 PartAlign = commonAlignment(ArgAlign, CurOffset);
1784
1785 SDValue StVal = OutVals[OIdx];
1786
1787 MVT PromotedVT;
1788 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
1789 EltVT = EVT(PromotedVT);
1790 }
1791 if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
1793 Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1794 StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
1795 }
1796
1797 if (IsByVal) {
1798 auto PtrVT = getPointerTy(DL);
1799 SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
1800 DAG.getConstant(CurOffset, dl, PtrVT));
1801 StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
1802 PartAlign);
1803 } else if (ExtendIntegerParam) {
1804 assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
1805 // zext/sext to i32
1806 StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
1808 dl, MVT::i32, StVal);
1809 }
1810
1811 if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
1812 // Use 16-bit registers for small stores as it's the
1813 // smallest general purpose register size supported by NVPTX.
1814 StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
1815 }
1816
1817 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1818 // scalar store. In such cases, fall back to byte stores.
1819 if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
1820 PartAlign.value() <
1821 DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
1822 assert(StoreOperands.empty() && "Unfinished preceding store.");
1824 DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
1825 StVal, InGlue, ParamCount, dl);
1826
1827 // LowerUnalignedStoreParam took care of inserting the necessary nodes
1828 // into the SDAG, so just move on to the next element.
1829 if (!IsByVal)
1830 ++OIdx;
1831 continue;
1832 }
1833
1834 // New store.
1835 if (VectorInfo[j] & PVF_FIRST) {
1836 assert(StoreOperands.empty() && "Unfinished preceding store.");
1837 StoreOperands.push_back(Chain);
1838 StoreOperands.push_back(
1839 DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
1840
1841 StoreOperands.push_back(DAG.getConstant(
1842 IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
1843 dl, MVT::i32));
1844 }
1845
1846 // Record the value to store.
1847 StoreOperands.push_back(StVal);
1848
1849 if (VectorInfo[j] & PVF_LAST) {
1850 unsigned NumElts = StoreOperands.size() - 3;
1852 switch (NumElts) {
1853 case 1:
1855 break;
1856 case 2:
1858 break;
1859 case 4:
1861 break;
1862 default:
1863 llvm_unreachable("Invalid vector info.");
1864 }
1865
1866 StoreOperands.push_back(InGlue);
1867
1868 // Adjust type of the store op if we've extended the scalar
1869 // return value.
1870 EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1871
1872 Chain = DAG.getMemIntrinsicNode(
1873 Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
1874 TheStoreType, MachinePointerInfo(), PartAlign,
1876 InGlue = Chain.getValue(1);
1877
1878 // Cleanup.
1879 StoreOperands.clear();
1880
1881 // TODO: We may need to support vector types that can be passed
1882 // as scalars in variadic arguments.
1883 if (!IsByVal && IsVAArg) {
1884 assert(NumElts == 1 &&
1885 "Vectorization is expected to be disabled for variadics.");
1886 VAOffset += DL.getTypeAllocSize(
1887 TheStoreType.getTypeForEVT(*DAG.getContext()));
1888 }
1889 }
1890 if (!IsByVal)
1891 ++OIdx;
1892 }
1893 assert(StoreOperands.empty() && "Unfinished parameter store.");
1894 if (!IsByVal && VTs.size() > 0)
1895 --OIdx;
1896 ++ParamCount;
1897 if (IsByVal && IsVAArg)
1898 VAOffset += TypeSize;
1899 }
1900
1901 GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
1902 MaybeAlign retAlignment = std::nullopt;
1903
1904 // Handle Result
1905 if (Ins.size() > 0) {
1906 SmallVector<EVT, 16> resvtparts;
1907 ComputeValueVTs(*this, DL, RetTy, resvtparts);
1908
1909 // Declare
1910 // .param .align N .b8 retval0[<size-in-bytes>], or
1911 // .param .b<size-in-bits> retval0
1912 unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
1913 if (!IsTypePassedAsArray(RetTy)) {
1914 resultsz = promoteScalarArgumentSize(resultsz);
1915 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1916 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
1917 DAG.getConstant(resultsz, dl, MVT::i32),
1918 DAG.getConstant(0, dl, MVT::i32), InGlue };
1919 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
1920 DeclareRetOps);
1921 InGlue = Chain.getValue(1);
1922 } else {
1923 retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
1924 assert(retAlignment && "retAlignment is guaranteed to be set");
1925 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1926 SDValue DeclareRetOps[] = {
1927 Chain, DAG.getConstant(retAlignment->value(), dl, MVT::i32),
1928 DAG.getConstant(resultsz / 8, dl, MVT::i32),
1929 DAG.getConstant(0, dl, MVT::i32), InGlue};
1930 Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
1931 DeclareRetOps);
1932 InGlue = Chain.getValue(1);
1933 }
1934 }
1935
1936 bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1937 // Set the size of the vararg param byte array if the callee is a variadic
1938 // function and the variadic part is not empty.
1939 if (HasVAArgs) {
1940 SDValue DeclareParamOps[] = {
1941 VADeclareParam.getOperand(0), VADeclareParam.getOperand(1),
1942 VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32),
1943 VADeclareParam.getOperand(4)};
1944 DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
1945 VADeclareParam->getVTList(), DeclareParamOps);
1946 }
1947
1948 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1949 // between them we must rely on the call site value which is valid for
1950 // indirect calls but is always null for libcalls.
1951 bool isIndirectCall = !Func && CB;
1952
1953 if (isa<ExternalSymbolSDNode>(Callee)) {
1954 Function* CalleeFunc = nullptr;
1955
1956 // Try to find the callee in the current module.
1957 Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc);
1958 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1959
1960 // Set the "libcall callee" attribute to indicate that the function
1961 // must always have a declaration.
1962 CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
1963 }
1964
1965 if (isIndirectCall) {
1966 // This is indirect function call case : PTX requires a prototype of the
1967 // form
1968 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1969 // to be emitted, and the label has to be used as the last arg of the call
1970 // instruction.
1971 // The prototype is embedded in a string and put as the operand for a
1972 // CallPrototype SDNode which will print out to the value of the string.
1973 SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1974 std::string Proto = getPrototype(
1975 DL, RetTy, Args, Outs, retAlignment,
1976 HasVAArgs
1977 ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
1978 CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
1979 : std::nullopt,
1980 *CB, UniqueCallSite);
1981 const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
1982 SDValue ProtoOps[] = {
1983 Chain,
1984 DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
1985 InGlue,
1986 };
1987 Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps);
1988 InGlue = Chain.getValue(1);
1989 }
1990 // Op to just print "call"
1991 SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1992 SDValue PrintCallOps[] = {
1993 Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InGlue
1994 };
1995 // We model convergent calls as separate opcodes.
1997 if (CLI.IsConvergent)
2000 Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
2001 InGlue = Chain.getValue(1);
2002
2003 // Ops to print out the function name
2004 SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2005 SDValue CallVoidOps[] = { Chain, Callee, InGlue };
2006 Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps);
2007 InGlue = Chain.getValue(1);
2008
2009 // Ops to print out the param list
2010 SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2011 SDValue CallArgBeginOps[] = { Chain, InGlue };
2012 Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
2013 CallArgBeginOps);
2014 InGlue = Chain.getValue(1);
2015
2016 for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e;
2017 ++i) {
2018 unsigned opcode;
2019 if (i == (e - 1))
2020 opcode = NVPTXISD::LastCallArg;
2021 else
2022 opcode = NVPTXISD::CallArg;
2023 SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2024 SDValue CallArgOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
2025 DAG.getConstant(i, dl, MVT::i32), InGlue };
2026 Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps);
2027 InGlue = Chain.getValue(1);
2028 }
2029 SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2030 SDValue CallArgEndOps[] = { Chain,
2031 DAG.getConstant(isIndirectCall ? 0 : 1, dl, MVT::i32),
2032 InGlue };
2033 Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps);
2034 InGlue = Chain.getValue(1);
2035
2036 if (isIndirectCall) {
2037 SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2038 SDValue PrototypeOps[] = {
2039 Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InGlue};
2040 Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
2041 InGlue = Chain.getValue(1);
2042 }
2043
2044 SmallVector<SDValue, 16> ProxyRegOps;
2045 SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
2046 // An item of the vector is filled if the element does not need a ProxyReg
2047 // operation on it and should be added to InVals as is. ProxyRegOps and
2048 // ProxyRegTruncates contain empty/none items at the same index.
2049 SmallVector<SDValue, 16> RetElts;
2050 // A temporary ProxyReg operation is inserted in `LowerUnalignedLoadRetParam()`
2051 // to use the values of `LoadParam`s and to be replaced later when
2052 // `CALLSEQ_END` is added.
2053 SmallVector<SDValue, 16> TempProxyRegOps;
2054
2055 // Generate loads from param memory/moves from registers for result
2056 if (Ins.size() > 0) {
2059 ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
2060 assert(VTs.size() == Ins.size() && "Bad value decomposition");
2061
2062 Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
2063 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
2064
2065 SmallVector<EVT, 6> LoadVTs;
2066 int VecIdx = -1; // Index of the first element of the vector.
2067
2068 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
2069 // 32-bits are sign extended or zero extended, depending on whether
2070 // they are signed or unsigned types.
2071 bool ExtendIntegerRetVal =
2072 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
2073
2074 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
2075 bool needTruncate = false;
2076 EVT TheLoadType = VTs[i];
2077 EVT EltType = Ins[i].VT;
2078 Align EltAlign = commonAlignment(RetAlign, Offsets[i]);
2079 MVT PromotedVT;
2080
2081 if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) {
2082 TheLoadType = EVT(PromotedVT);
2083 EltType = EVT(PromotedVT);
2084 needTruncate = true;
2085 }
2086
2087 if (ExtendIntegerRetVal) {
2088 TheLoadType = MVT::i32;
2089 EltType = MVT::i32;
2090 needTruncate = true;
2091 } else if (TheLoadType.getSizeInBits() < 16) {
2092 if (VTs[i].isInteger())
2093 needTruncate = true;
2094 EltType = MVT::i16;
2095 }
2096
2097 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
2098 // scalar load. In such cases, fall back to byte loads.
2099 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
2100 EltAlign < DL.getABITypeAlign(
2101 TheLoadType.getTypeForEVT(*DAG.getContext()))) {
2102 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
2104 DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
2105 ProxyRegOps.push_back(SDValue());
2106 ProxyRegTruncates.push_back(std::optional<MVT>());
2107 RetElts.resize(i);
2108 RetElts.push_back(Ret);
2109
2110 continue;
2111 }
2112
2113 // Record index of the very first element of the vector.
2114 if (VectorInfo[i] & PVF_FIRST) {
2115 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
2116 VecIdx = i;
2117 }
2118
2119 LoadVTs.push_back(EltType);
2120
2121 if (VectorInfo[i] & PVF_LAST) {
2122 unsigned NumElts = LoadVTs.size();
2123 LoadVTs.push_back(MVT::Other);
2124 LoadVTs.push_back(MVT::Glue);
2126 switch (NumElts) {
2127 case 1:
2129 break;
2130 case 2:
2132 break;
2133 case 4:
2135 break;
2136 default:
2137 llvm_unreachable("Invalid vector info.");
2138 }
2139
2140 SDValue LoadOperands[] = {
2141 Chain, DAG.getConstant(1, dl, MVT::i32),
2142 DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InGlue};
2143 SDValue RetVal = DAG.getMemIntrinsicNode(
2144 Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
2145 MachinePointerInfo(), EltAlign,
2147
2148 for (unsigned j = 0; j < NumElts; ++j) {
2149 ProxyRegOps.push_back(RetVal.getValue(j));
2150
2151 if (needTruncate)
2152 ProxyRegTruncates.push_back(std::optional<MVT>(Ins[VecIdx + j].VT));
2153 else
2154 ProxyRegTruncates.push_back(std::optional<MVT>());
2155 }
2156
2157 Chain = RetVal.getValue(NumElts);
2158 InGlue = RetVal.getValue(NumElts + 1);
2159
2160 // Cleanup
2161 VecIdx = -1;
2162 LoadVTs.clear();
2163 }
2164 }
2165 }
2166
2167 Chain =
2168 DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl);
2169 InGlue = Chain.getValue(1);
2170
2171 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
2172 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
2173 // dangling.
2174 for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
2175 if (i < RetElts.size() && RetElts[i]) {
2176 InVals.push_back(RetElts[i]);
2177 continue;
2178 }
2179
2180 SDValue Ret = DAG.getNode(
2182 DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
2183 { Chain, ProxyRegOps[i], InGlue }
2184 );
2185
2186 Chain = Ret.getValue(1);
2187 InGlue = Ret.getValue(2);
2188
2189 if (ProxyRegTruncates[i]) {
2190 Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[i], Ret);
2191 }
2192
2193 InVals.push_back(Ret);
2194 }
2195
2196 for (SDValue &T : TempProxyRegOps) {
2197 SDValue Repl = DAG.getNode(
2199 DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
2200 {Chain, T.getOperand(0), InGlue});
2201 DAG.ReplaceAllUsesWith(T, Repl);
2202 DAG.RemoveDeadNode(T.getNode());
2203
2204 Chain = Repl.getValue(1);
2205 InGlue = Repl.getValue(2);
2206 }
2207
2208 // set isTailCall to false for now, until we figure out how to express
2209 // tail call optimization in PTX
2210 isTailCall = false;
2211 return Chain;
2212}
2213
2215 SelectionDAG &DAG) const {
2216
2217 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2218 const Function &Fn = DAG.getMachineFunction().getFunction();
2219
2220 DiagnosticInfoUnsupported NoDynamicAlloca(
2221 Fn,
2222 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
2223 "requires target sm_52.",
2224 SDLoc(Op).getDebugLoc());
2225 DAG.getContext()->diagnose(NoDynamicAlloca);
2226 auto Ops = {DAG.getConstant(0, SDLoc(), Op.getValueType()),
2227 Op.getOperand(0)};
2228 return DAG.getMergeValues(Ops, SDLoc());
2229 }
2230
2231 SDValue Chain = Op.getOperand(0);
2232 SDValue Size = Op.getOperand(1);
2233 uint64_t Align = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue();
2234 SDLoc DL(Op.getNode());
2235
2236 // The size operand of the PTX alloca instruction is 64-bit for m64 and 32-bit for m32.
2237 MVT ValueSizeTy = nvTM->is64Bit() ? MVT::i64 : MVT::i32;
2238
2239 SDValue AllocOps[] = {Chain, DAG.getZExtOrTrunc(Size, DL, ValueSizeTy),
2240 DAG.getTargetConstant(Align, DL, MVT::i32)};
2241 EVT RetTypes[] = {ValueSizeTy, MVT::Other};
2242 return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps);
2243}
2244
2245// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
2246// (see LegalizeDAG.cpp). This is slow and uses local memory.
2247// We use extract/insert/build vector just as LegalizeOp() did in LLVM 2.5.
2248SDValue
2249NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2250 SDNode *Node = Op.getNode();
2251 SDLoc dl(Node);
2252 SmallVector<SDValue, 8> Ops;
2253 unsigned NumOperands = Node->getNumOperands();
2254 for (unsigned i = 0; i < NumOperands; ++i) {
2255 SDValue SubOp = Node->getOperand(i);
2256 EVT VVT = SubOp.getNode()->getValueType(0);
2257 EVT EltVT = VVT.getVectorElementType();
2258 unsigned NumSubElem = VVT.getVectorNumElements();
2259 for (unsigned j = 0; j < NumSubElem; ++j) {
2260 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
2261 DAG.getIntPtrConstant(j, dl)));
2262 }
2263 }
2264 return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
2265}
2266
2267// We can initialize a constant f16x2/v2i16/v4i8 with a single .b32 move. Normally
2268// it would get lowered as two constant loads and a vector-packing move.
2269// Instead we want just a constant move:
2270// mov.b32 %r2, 0x40003C00
2271SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2272 SelectionDAG &DAG) const {
2273 EVT VT = Op->getValueType(0);
2274 if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
2275 return Op;
2276
2277 SDLoc DL(Op);
2278
2279 if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
2280 return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
2281 isa<ConstantFPSDNode>(Operand);
2282 })) {
2283 // Lower a non-constant v4i8 vector as a byte-wise constructed i32, which
2284 // allows us to optimize the calculation of its constant parts.
2285 if (VT == MVT::v4i8) {
2286 SDValue C8 = DAG.getConstant(8, DL, MVT::i32);
2287 SDValue E01 = DAG.getNode(
2288 NVPTXISD::BFI, DL, MVT::i32,
2289 DAG.getAnyExtOrTrunc(Op->getOperand(1), DL, MVT::i32),
2290 DAG.getAnyExtOrTrunc(Op->getOperand(0), DL, MVT::i32), C8, C8);
2291 SDValue E012 =
2292 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2293 DAG.getAnyExtOrTrunc(Op->getOperand(2), DL, MVT::i32),
2294 E01, DAG.getConstant(16, DL, MVT::i32), C8);
2295 SDValue E0123 =
2296 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2297 DAG.getAnyExtOrTrunc(Op->getOperand(3), DL, MVT::i32),
2298 E012, DAG.getConstant(24, DL, MVT::i32), C8);
2299 return DAG.getNode(ISD::BITCAST, DL, VT, E0123);
2300 }
2301 return Op;
2302 }
2303
2304 // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
2305 auto GetOperand = [](SDValue Op, int N) -> APInt {
2306 const SDValue &Operand = Op->getOperand(N);
2307 EVT VT = Op->getValueType(0);
2308 if (Operand->isUndef())
2309 return APInt(32, 0);
2310 APInt Value;
2311 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2312 Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
2313 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2314 Value = Operand->getAsAPIntVal();
2315 else
2316 llvm_unreachable("Unsupported type");
2317 // i8 values are carried around as i16, so we need to zero out the upper bits
2318 // so that they do not get in the way of combining individual byte values.
2319 if (VT == MVT::v4i8)
2320 Value = Value.trunc(8);
2321 return Value.zext(32);
2322 };
2323 APInt Value;
2324 if (Isv2x16VT(VT)) {
2325 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
2326 } else if (VT == MVT::v4i8) {
2327 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
2328 GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
2329 } else {
2330 llvm_unreachable("Unsupported type");
2331 }
2332 SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
2333 return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
2334}
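// Editor's note: the following is a minimal standalone sketch (not part of
// NVPTXISelLowering.cpp) of the bit packing performed above, assuming only the
// C++ standard library. It shows how the halves 1.0h (0x3C00) and 2.0h (0x4000)
// combine into the single mov.b32 immediate 0x40003C00 quoted in the comment
// before LowerBUILD_VECTOR, with element 0 in the low 16 bits.
#include <cstdint>
#include <cstdio>

int main() {
  const uint16_t HalfOne = 0x3C00; // 1.0 in IEEE-754 binary16
  const uint16_t HalfTwo = 0x4000; // 2.0 in IEEE-754 binary16
  // The low element goes into the low half of the 32-bit immediate.
  uint32_t Packed = uint32_t(HalfOne) | (uint32_t(HalfTwo) << 16);
  std::printf("mov.b32 immediate: 0x%08X\n", (unsigned)Packed); // 0x40003C00
  return 0;
}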
2335
2336SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2337 SelectionDAG &DAG) const {
2338 SDValue Index = Op->getOperand(1);
2339 SDValue Vector = Op->getOperand(0);
2340 SDLoc DL(Op);
2341 EVT VectorVT = Vector.getValueType();
2342
2343 if (VectorVT == MVT::v4i8) {
2344 SDValue BFE =
2345 DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
2346 {Vector,
2347 DAG.getNode(ISD::MUL, DL, MVT::i32,
2348 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2349 DAG.getConstant(8, DL, MVT::i32)),
2350 DAG.getConstant(8, DL, MVT::i32)});
2351 return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
2352 }
2353
2354 // Constant index will be matched by tablegen.
2355 if (isa<ConstantSDNode>(Index.getNode()))
2356 return Op;
2357
2358 // Extract individual elements and select one of them.
2359 assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
2360 EVT EltVT = VectorVT.getVectorElementType();
2361
2362 SDLoc dl(Op.getNode());
2363 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2364 DAG.getIntPtrConstant(0, dl));
2365 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2366 DAG.getIntPtrConstant(1, dl));
2367 return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
2368 ISD::CondCode::SETEQ);
2369}
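// Editor's note: a standalone sketch (not part of this file) of the BFE-based
// extract above, assuming only the C++ standard library. A v4i8 vector is
// treated as one 32-bit word; element Idx occupies 8 bits starting at bit
// offset Idx * 8, which is exactly the field the BFE node extracts.
#include <cstdint>
#include <cstdio>

static uint8_t extractByte(uint32_t Vec, unsigned Idx) {
  return static_cast<uint8_t>((Vec >> (Idx * 8)) & 0xFF); // BFE(Vec, Idx*8, 8)
}

int main() {
  // Byte 2 of 0x40302010 is 0x30.
  std::printf("0x%02X\n", (unsigned)extractByte(0x40302010, 2)); // prints 0x30
  return 0;
}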
2370
2371SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2372 SelectionDAG &DAG) const {
2373 SDValue Vector = Op->getOperand(0);
2374 EVT VectorVT = Vector.getValueType();
2375
2376 if (VectorVT != MVT::v4i8)
2377 return Op;
2378 SDLoc DL(Op);
2379 SDValue Value = Op->getOperand(1);
2380 if (Value->isUndef())
2381 return Vector;
2382
2383 SDValue Index = Op->getOperand(2);
2384
2385 SDValue BFI =
2386 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2387 {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
2388 DAG.getNode(ISD::MUL, DL, MVT::i32,
2389 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2390 DAG.getConstant(8, DL, MVT::i32)),
2391 DAG.getConstant(8, DL, MVT::i32)});
2392 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
2393}
2394
2395SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2396 SelectionDAG &DAG) const {
2397 SDValue V1 = Op.getOperand(0);
2398 EVT VectorVT = V1.getValueType();
2399 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2400 return Op;
2401
2402 // Lower shuffle to PRMT instruction.
2403 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
2404 SDValue V2 = Op.getOperand(1);
2405 uint32_t Selector = 0;
2406 for (auto I : llvm::enumerate(SVN->getMask())) {
2407 if (I.value() != -1) // -1 is a placeholder for undef.
2408 Selector |= (I.value() << (I.index() * 4));
2409 }
2410
2411 SDLoc DL(Op);
2412 return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
2413 DAG.getConstant(Selector, DL, MVT::i32),
2414 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
2415}
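// Editor's note: a standalone sketch (not part of this file) of how the PRMT
// byte selector is assembled above, assuming only the C++ standard library.
// Each result byte gets a 4-bit nibble: values 0-3 pick a byte of the first
// v4i8 operand and 4-7 pick a byte of the second; undef lanes (-1) leave
// their nibble as zero, matching the loop over the shuffle mask.
#include <array>
#include <cstdint>
#include <cstdio>

static uint32_t buildPrmtSelector(const std::array<int, 4> &Mask) {
  uint32_t Selector = 0;
  for (unsigned I = 0; I < 4; ++I)
    if (Mask[I] != -1) // -1 is a placeholder for undef.
      Selector |= uint32_t(Mask[I]) << (I * 4);
  return Selector;
}

int main() {
  // Reversing the four bytes of the first operand uses mask <3, 2, 1, 0>.
  std::printf("selector = 0x%04X\n",
              (unsigned)buildPrmtSelector({3, 2, 1, 0})); // prints 0x0123
  return 0;
}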
2416/// LowerShiftRightParts - Lower SRL_PARTS and SRA_PARTS, which either
2417/// 1) return two i32 values and take a 2 x i32 value to shift plus a shift
2418/// amount, or
2419/// 2) return two i64 values and take a 2 x i64 value to shift plus a shift
2420/// amount.
2421SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2422 SelectionDAG &DAG) const {
2423 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2424 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2425
2426 EVT VT = Op.getValueType();
2427 unsigned VTBits = VT.getSizeInBits();
2428 SDLoc dl(Op);
2429 SDValue ShOpLo = Op.getOperand(0);
2430 SDValue ShOpHi = Op.getOperand(1);
2431 SDValue ShAmt = Op.getOperand(2);
2432 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2433
2434 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2435 // For 32-bit shifts on sm_35 and newer, we can use the funnel shift 'shf' instruction.
2436 // {dHi, dLo} = {aHi, aLo} >> Amt
2437 // dHi = aHi >> Amt
2438 // dLo = shf.r.clamp aLo, aHi, Amt
2439
2440 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2441 SDValue Lo = DAG.getNode(NVPTXISD::FUN_SHFR_CLAMP, dl, VT, ShOpLo, ShOpHi,
2442 ShAmt);
2443
2444 SDValue Ops[2] = { Lo, Hi };
2445 return DAG.getMergeValues(Ops, dl);
2446 }
2447 else {
2448 // {dHi, dLo} = {aHi, aLo} >> Amt
2449 // - if (Amt>=size) then
2450 // dLo = aHi >> (Amt-size)
2451 // dHi = aHi >> Amt (this is either all 0 or all 1)
2452 // else
2453 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2454 // dHi = aHi >> Amt
2455
2456 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2457 DAG.getConstant(VTBits, dl, MVT::i32),
2458 ShAmt);
2459 SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
2460 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2461 DAG.getConstant(VTBits, dl, MVT::i32));
2462 SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
2463 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2464 SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
2465
2466 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2467 DAG.getConstant(VTBits, dl, MVT::i32),
2468 ISD::SETGE);
2469 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2470 SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2471
2472 SDValue Ops[2] = { Lo, Hi };
2473 return DAG.getMergeValues(Ops, dl);
2474 }
2475}
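// Editor's note: a standalone sketch (not part of this file) of the generic
// (non-funnel-shift) path above for a logical right shift, assuming only the
// C++ standard library. It mirrors the select between the "Amt >= size" and
// "Amt < size" formulas; the extra Amt == 0 branch only avoids the shift-by-32
// that is undefined in C but clamped by the hardware in the DAG version.
#include <cstdint>
#include <cstdio>

static void shiftRightParts(uint32_t Hi, uint32_t Lo, unsigned Amt,
                            uint32_t &OutHi, uint32_t &OutLo) {
  if (Amt >= 32) {
    OutLo = Hi >> (Amt - 32);
    OutHi = 0; // SRL; an arithmetic shift would instead replicate the sign bit
  } else if (Amt == 0) {
    OutLo = Lo;
    OutHi = Hi;
  } else {
    OutLo = (Lo >> Amt) | (Hi << (32 - Amt));
    OutHi = Hi >> Amt;
  }
}

int main() {
  uint32_t Hi, Lo;
  shiftRightParts(0x89ABCDEF, 0x01234567, 8, Hi, Lo);
  std::printf("0x%08X%08X\n", (unsigned)Hi, (unsigned)Lo); // 0x0089ABCDEF012345
  return 0;
}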
2476
2477/// LowerShiftLeftParts - Lower SHL_PARTS, which either
2478/// 1) returns two i32 values and takes a 2 x i32 value to shift plus a shift
2479/// amount, or
2480/// 2) returns two i64 values and takes a 2 x i64 value to shift plus a shift
2481/// amount.
2482SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2483 SelectionDAG &DAG) const {
2484 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2485 assert(Op.getOpcode() == ISD::SHL_PARTS);
2486
2487 EVT VT = Op.getValueType();
2488 unsigned VTBits = VT.getSizeInBits();
2489 SDLoc dl(Op);
2490 SDValue ShOpLo = Op.getOperand(0);
2491 SDValue ShOpHi = Op.getOperand(1);
2492 SDValue ShAmt = Op.getOperand(2);
2493
2494 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2495 // For 32-bit shifts on sm_35 and newer, we can use the funnel shift 'shf' instruction.
2496 // {dHi, dLo} = {aHi, aLo} << Amt
2497 // dHi = shf.l.clamp aLo, aHi, Amt
2498 // dLo = aLo << Amt
2499
2500 SDValue Hi = DAG.getNode(NVPTXISD::FUN_SHFL_CLAMP, dl, VT, ShOpLo, ShOpHi,
2501 ShAmt);
2502 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2503
2504 SDValue Ops[2] = { Lo, Hi };
2505 return DAG.getMergeValues(Ops, dl);
2506 }
2507 else {
2508 // {dHi, dLo} = {aHi, aLo} << Amt
2509 // - if (Amt>=size) then
2510 // dLo = aLo << Amt (all 0)
2511 //    dHi = aLo << (Amt-size)
2512 // else
2513 // dLo = aLo << Amt
2514 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2515
2516 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2517 DAG.getConstant(VTBits, dl, MVT::i32),
2518 ShAmt);
2519 SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
2520 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2521 DAG.getConstant(VTBits, dl, MVT::i32));
2522 SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
2523 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2524 SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
2525
2526 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2527 DAG.getConstant(VTBits, dl, MVT::i32),
2528 ISD::SETGE);
2529 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2530 SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2531
2532 SDValue Ops[2] = { Lo, Hi };
2533 return DAG.getMergeValues(Ops, dl);
2534 }
2535}
2536
2537SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2538 EVT VT = Op.getValueType();
2539
2540 if (VT == MVT::f32)
2541 return LowerFROUND32(Op, DAG);
2542
2543 if (VT == MVT::f64)
2544 return LowerFROUND64(Op, DAG);
2545
2546 llvm_unreachable("unhandled type");
2547}
2548
2549// This is the rounding method used in CUDA libdevice, in C-like code:
2550// float roundf(float A)
2551// {
2552// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2553// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2554// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2555// }
2556SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2557 SelectionDAG &DAG) const {
2558 SDLoc SL(Op);
2559 SDValue A = Op.getOperand(0);
2560 EVT VT = Op.getValueType();
2561
2562 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2563
2564 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2565 SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
2566 const int SignBitMask = 0x80000000;
2567 SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
2568 DAG.getConstant(SignBitMask, SL, MVT::i32));
2569 const int PointFiveInBits = 0x3F000000;
2570 SDValue PointFiveWithSignRaw =
2571 DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
2572 DAG.getConstant(PointFiveInBits, SL, MVT::i32));
2573 SDValue PointFiveWithSign =
2574 DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
2575 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
2576 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2577
2578 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2579 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2580 SDValue IsLarge =
2581 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
2582 ISD::SETOGT);
2583 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2584
2585 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2586 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2587 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2588 SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
2589 return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
2590}
2591
2592// The implementation of round(double) is similar to that of round(float) in
2593// that they both separate the value range into three regions and use a method
2594// specific to the region to round the values. However, round(double) first
2595// calculates the round of the absolute value and then adds the sign back while
2596// round(float) directly rounds the value with sign.
2597SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2598 SelectionDAG &DAG) const {
2599 SDLoc SL(Op);
2600 SDValue A = Op.getOperand(0);
2601 EVT VT = Op.getValueType();
2602
2603 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2604
2605 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2606 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
2607 DAG.getConstantFP(0.5, SL, VT));
2608 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2609
2610 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2611 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2612 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2613 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2614 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
2615 DAG.getConstantFP(0, SL, VT),
2616 RoundedA);
2617
2618 // Add sign to rounded_A
2619 RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
2620 DAG.getNode(ISD::FTRUNC, SL, VT, A);
2621
2622 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2623 SDValue IsLarge =
2624 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
2625 ISD::SETOGT);
2626 return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2627}
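// Editor's note: a standalone sketch (not part of this file) of the
// three-region rounding built above for f64, assuming only <cmath>/<cstdio>.
// The absolute value is rounded, forced to zero when |A| < 0.5, the sign is
// re-attached with copysign, and values larger than 2^52 (already integral)
// pass through unchanged.
#include <cmath>
#include <cstdio>

static double roundSketch(double A) {
  double AbsA = std::fabs(A);
  double RoundedA = std::trunc(AbsA + 0.5);
  if (AbsA < 0.5)
    RoundedA = 0.0;
  RoundedA = std::copysign(RoundedA, A);
  return AbsA > 0x1.0p52 ? A : RoundedA;
}

int main() {
  std::printf("%g %g %g\n", roundSketch(2.5), roundSketch(-0.25),
              roundSketch(-3.5)); // prints 3 -0 -4
  return 0;
}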
2628
2629SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2630 SelectionDAG &DAG) const {
2631 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2632
2633 if (Op.getValueType() == MVT::bf16) {
2634 SDLoc Loc(Op);
2635 return DAG.getNode(
2636 ISD::FP_ROUND, Loc, MVT::bf16,
2637 DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
2638 DAG.getIntPtrConstant(0, Loc));
2639 }
2640
2641 // Everything else is considered legal.
2642 return Op;
2643}
2644
2645SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2646 SelectionDAG &DAG) const {
2647 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2648
2649 if (Op.getOperand(0).getValueType() == MVT::bf16) {
2650 SDLoc Loc(Op);
2651 return DAG.getNode(
2652 Op.getOpcode(), Loc, Op.getValueType(),
2653 DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
2654 }
2655
2656 // Everything else is considered legal.
2657 return Op;
2658}
2659
2660SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2661 SelectionDAG &DAG) const {
2662 EVT NarrowVT = Op.getValueType();
2663 SDValue Wide = Op.getOperand(0);
2664 EVT WideVT = Wide.getValueType();
2665 if (NarrowVT.getScalarType() == MVT::bf16) {
2666 const TargetLowering *TLI = STI.getTargetLowering();
2667 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2668 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2669 }
2670 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2671 // This combination was the first to support f32 -> bf16.
2672 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2673 if (WideVT.getScalarType() == MVT::f32) {
2674 return Op;
2675 }
2676 if (WideVT.getScalarType() == MVT::f64) {
2677 SDLoc Loc(Op);
2678 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2679 // the hardware f32 -> bf16 instruction.
2680 SDValue rod = TLI->expandRoundInexactToOdd(
2681 WideVT.isVector() ? WideVT.changeVectorElementType(MVT::f32)
2682 : MVT::f32,
2683 Wide, Loc, DAG);
2684 return DAG.getFPExtendOrRound(rod, Loc, NarrowVT);
2685 }
2686 }
2687 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2688 }
2689 }
2690
2691 // Everything else is considered legal.
2692 return Op;
2693}
2694
2695SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2696 SelectionDAG &DAG) const {
2697 SDValue Narrow = Op.getOperand(0);
2698 EVT NarrowVT = Narrow.getValueType();
2699 EVT WideVT = Op.getValueType();
2700 if (NarrowVT.getScalarType() == MVT::bf16) {
2701 if (WideVT.getScalarType() == MVT::f32 &&
2702 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2703 SDLoc Loc(Op);
2704 return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow);
2705 }
2706 if (WideVT.getScalarType() == MVT::f64 &&
2707 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2708 EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f32)
2709 : MVT::f32;
2710 SDLoc Loc(Op);
2711 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2712 Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow);
2713 } else {
2714 Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow);
2715 }
2716 return DAG.getNode(ISD::FP_EXTEND, Loc, WideVT, Op);
2717 }
2718 }
2719
2720 // Everything else is considered legal.
2721 return Op;
2722}
2723
2724 static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2725 SDLoc DL(Op);
2726 if (Op.getValueType() != MVT::v2i16)
2727 return Op;
2728 EVT EltVT = Op.getValueType().getVectorElementType();
2729 SmallVector<SDValue> VecElements;
2730 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2731 SmallVector<SDValue> ScalarArgs;
2732 llvm::transform(Op->ops(), std::back_inserter(ScalarArgs),
2733 [&](const SDUse &O) {
2734 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
2735 O.get(), DAG.getIntPtrConstant(I, DL));
2736 });
2737 VecElements.push_back(DAG.getNode(Op.getOpcode(), DL, EltVT, ScalarArgs));
2738 }
2739 SDValue V =
2740 DAG.getNode(ISD::BUILD_VECTOR, DL, Op.getValueType(), VecElements);
2741 return V;
2742}
2743
2744SDValue
2745 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2746 switch (Op.getOpcode()) {
2747 case ISD::RETURNADDR:
2748 return SDValue();
2749 case ISD::FRAMEADDR:
2750 return SDValue();
2751 case ISD::GlobalAddress:
2752 return LowerGlobalAddress(Op, DAG);
2753 case ISD::INTRINSIC_W_CHAIN:
2754 return Op;
2755 case ISD::BUILD_VECTOR:
2756 return LowerBUILD_VECTOR(Op, DAG);
2757 case ISD::EXTRACT_SUBVECTOR:
2758 return Op;
2759 case ISD::EXTRACT_VECTOR_ELT:
2760 return LowerEXTRACT_VECTOR_ELT(Op, DAG);
2761 case ISD::INSERT_VECTOR_ELT:
2762 return LowerINSERT_VECTOR_ELT(Op, DAG);
2763 case ISD::VECTOR_SHUFFLE:
2764 return LowerVECTOR_SHUFFLE(Op, DAG);
2765 case ISD::CONCAT_VECTORS:
2766 return LowerCONCAT_VECTORS(Op, DAG);
2767 case ISD::STORE:
2768 return LowerSTORE(Op, DAG);
2769 case ISD::LOAD:
2770 return LowerLOAD(Op, DAG);
2771 case ISD::SHL_PARTS:
2772 return LowerShiftLeftParts(Op, DAG);
2773 case ISD::SRA_PARTS:
2774 case ISD::SRL_PARTS:
2775 return LowerShiftRightParts(Op, DAG);
2776 case ISD::SELECT:
2777 return LowerSelect(Op, DAG);
2778 case ISD::FROUND:
2779 return LowerFROUND(Op, DAG);
2780 case ISD::SINT_TO_FP:
2781 case ISD::UINT_TO_FP:
2782 return LowerINT_TO_FP(Op, DAG);
2783 case ISD::FP_TO_SINT:
2784 case ISD::FP_TO_UINT:
2785 return LowerFP_TO_INT(Op, DAG);
2786 case ISD::FP_ROUND:
2787 return LowerFP_ROUND(Op, DAG);
2788 case ISD::FP_EXTEND:
2789 return LowerFP_EXTEND(Op, DAG);
2790 case ISD::BR_JT:
2791 return LowerBR_JT(Op, DAG);
2792 case ISD::VAARG:
2793 return LowerVAARG(Op, DAG);
2794 case ISD::VASTART:
2795 return LowerVASTART(Op, DAG);
2796 case ISD::ABS:
2797 case ISD::SMIN:
2798 case ISD::SMAX:
2799 case ISD::UMIN:
2800 case ISD::UMAX:
2801 case ISD::ADD:
2802 case ISD::SUB:
2803 case ISD::MUL:
2804 case ISD::SHL:
2805 case ISD::SREM:
2806 case ISD::UREM:
2807 return LowerVectorArith(Op, DAG);
2808 case ISD::DYNAMIC_STACKALLOC:
2809 return LowerDYNAMIC_STACKALLOC(Op, DAG);
2810 case ISD::CopyToReg:
2811 return LowerCopyToReg_128(Op, DAG);
2812 default:
2813 llvm_unreachable("Custom lowering not defined for operation");
2814 }
2815}
2816
2817SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
2818 SDLoc DL(Op);
2819 SDValue Chain = Op.getOperand(0);
2820 const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
2821 SDValue Index = Op.getOperand(2);
2822
2823 unsigned JId = JT->getIndex();
2824 MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
2825 ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
2826
2827 SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
2828
2829 // Generate BrxStart node
2830 SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
2831 Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
2832
2833 // Generate BrxItem nodes
2834 assert(!MBBs.empty());
2835 for (MachineBasicBlock *MBB : MBBs.drop_back())
2836 Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
2837 DAG.getBasicBlock(MBB), Chain.getValue(1));
2838
2839 // Generate BrxEnd nodes
2840 SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
2841 IdV, Chain.getValue(1)};
2842 SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
2843
2844 return BrxEnd;
2845}
2846
2847// This will prevent AsmPrinter from trying to print the jump tables itself.
2848 unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
2849 return MachineJumpTableInfo::EK_Inline;
2850}
2851
2852// This function is almost a copy of SelectionDAG::expandVAArg().
2853// The only difference is that this one produces loads from the local address space.
2854SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
2855 const TargetLowering *TLI = STI.getTargetLowering();
2856 SDLoc DL(Op);
2857
2858 SDNode *Node = Op.getNode();
2859 const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
2860 EVT VT = Node->getValueType(0);
2861 auto *Ty = VT.getTypeForEVT(*DAG.getContext());
2862 SDValue Tmp1 = Node->getOperand(0);
2863 SDValue Tmp2 = Node->getOperand(1);
2864 const MaybeAlign MA(Node->getConstantOperandVal(3));
2865
2866 SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
2867 Tmp1, Tmp2, MachinePointerInfo(V));
2868 SDValue VAList = VAListLoad;
2869
2870 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
2871 VAList = DAG.getNode(
2872 ISD::ADD, DL, VAList.getValueType(), VAList,
2873 DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
2874
2875 VAList = DAG.getNode(
2876 ISD::AND, DL, VAList.getValueType(), VAList,
2877 DAG.getConstant(-(int64_t)MA->value(), DL, VAList.getValueType()));
2878 }
2879
2880 // Increment the pointer, VAList, to the next vaarg
2881 Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
2882 DAG.getConstant(DAG.getDataLayout().getTypeAllocSize(Ty),
2883 DL, VAList.getValueType()));
2884
2885 // Store the incremented VAList to the legalized pointer
2886 Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
2887 MachinePointerInfo(V));
2888
2889 const Value *SrcV =
2890 Constant::getNullValue(PointerType::get(Ty, ADDRESS_SPACE_LOCAL));
2891
2892 // Load the actual argument out of the pointer VAList
2893 return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
2894}
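// Editor's note: a standalone sketch (not part of this file) of the pointer
// alignment step above, assuming only the C++ standard library. The va_list
// cursor is rounded up to the argument's alignment with the usual
// (p + a - 1) & ~(a - 1) trick, which is what the ADD/AND pair on the VAList
// value computes (the DAG form uses & -a, identical for powers of two).
#include <cstdint>
#include <cstdio>

static uint64_t alignUp(uint64_t Ptr, uint64_t Alignment) {
  // Alignment must be a power of two.
  return (Ptr + Alignment - 1) & ~(Alignment - 1);
}

int main() {
  std::printf("0x%llX\n", (unsigned long long)alignUp(0x1003, 8)); // 0x1008
  return 0;
}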
2895
2896SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
2897 const TargetLowering *TLI = STI.getTargetLowering();
2898 SDLoc DL(Op);
2899 EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
2900
2901 // Store the address of unsized array <function>_vararg[] in the ap object.
2902 SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
2903 SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
2904
2905 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
2906 return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
2907 MachinePointerInfo(SV));
2908}
2909
2910SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
2911 SDValue Op0 = Op->getOperand(0);
2912 SDValue Op1 = Op->getOperand(1);
2913 SDValue Op2 = Op->getOperand(2);
2914 SDLoc DL(Op.getNode());
2915
2916 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
2917
2918 Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
2919 Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
2920 SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2);
2921 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
2922
2923 return Trunc;
2924}
2925
2926SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
2927 if (Op.getValueType() == MVT::i1)
2928 return LowerLOADi1(Op, DAG);
2929
2930 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
2931 // unaligned loads and have to handle it here.
2932 EVT VT = Op.getValueType();
2933 if (Isv2x16VT(VT) || VT == MVT::v4i8) {
2934 LoadSDNode *Load = cast<LoadSDNode>(Op);
2935 EVT MemVT = Load->getMemoryVT();
2936 if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
2937 MemVT, *Load->getMemOperand())) {
2938 SDValue Ops[2];
2939 std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
2940 return DAG.getMergeValues(Ops, SDLoc(Op));
2941 }
2942 }
2943
2944 return SDValue();
2945}
2946
2947// v = ld i1* addr
2948// =>
2949// v1 = ld i8* addr (-> i16)
2950// v = trunc i16 to i1
2951SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
2952 SDNode *Node = Op.getNode();
2953 LoadSDNode *LD = cast<LoadSDNode>(Node);
2954 SDLoc dl(Node);
2955 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
2956 assert(Node->getValueType(0) == MVT::i1 &&
2957 "Custom lowering for i1 load only");
2958 SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
2959 LD->getBasePtr(), LD->getPointerInfo(),
2960 MVT::i8, LD->getAlign(),
2961 LD->getMemOperand()->getFlags());
2962 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
2963 // The legalizer (the caller) is expecting two values from the legalized
2964 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
2965 // in LegalizeDAG.cpp which also uses MergeValues.
2966 SDValue Ops[] = { result, LD->getChain() };
2967 return DAG.getMergeValues(Ops, dl);
2968}
2969
2970SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
2971 StoreSDNode *Store = cast<StoreSDNode>(Op);
2972 EVT VT = Store->getMemoryVT();
2973
2974 if (VT == MVT::i1)
2975 return LowerSTOREi1(Op, DAG);
2976
2977 // v2f16 is legal, so we can't rely on legalizer to handle unaligned
2978 // stores and have to handle it here.
2979 if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
2980 !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
2981 VT, *Store->getMemOperand()))
2982 return expandUnalignedStore(Store, DAG);
2983
2984 // v2f16, v2bf16 and v2i16 don't need special handling.
2985 if (Isv2x16VT(VT) || VT == MVT::v4i8)
2986 return SDValue();
2987
2988 if (VT.isVector())
2989 return LowerSTOREVector(Op, DAG);
2990
2991 return SDValue();
2992}
2993
2994SDValue
2995NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2996 SDNode *N = Op.getNode();
2997 SDValue Val = N->getOperand(1);
2998 SDLoc DL(N);
2999 EVT ValVT = Val.getValueType();
3000
3001 if (ValVT.isVector()) {
3002 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
3003 // legal. We can (and should) split that into 2 stores of <2 x double> here
3004 // but I'm leaving that as a TODO for now.
3005 if (!ValVT.isSimple())
3006 return SDValue();
3007 switch (ValVT.getSimpleVT().SimpleTy) {
3008 default:
3009 return SDValue();
3010 case MVT::v2i8:
3011 case MVT::v2i16:
3012 case MVT::v2i32:
3013 case MVT::v2i64:
3014 case MVT::v2f16:
3015 case MVT::v2bf16:
3016 case MVT::v2f32:
3017 case MVT::v2f64:
3018 case MVT::v4i8:
3019 case MVT::v4i16:
3020 case MVT::v4i32:
3021 case MVT::v4f16:
3022 case MVT::v4bf16:
3023 case MVT::v4f32:
3024 case MVT::v8f16: // <4 x f16x2>
3025 case MVT::v8bf16: // <4 x bf16x2>
3026 case MVT::v8i16: // <4 x i16x2>
3027 // This is a "native" vector type
3028 break;
3029 }
3030
3031 MemSDNode *MemSD = cast<MemSDNode>(N);
3032 const DataLayout &TD = DAG.getDataLayout();
3033
3034 Align Alignment = MemSD->getAlign();
3035 Align PrefAlign =
3036 TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
3037 if (Alignment < PrefAlign) {
3038 // This store is not sufficiently aligned, so bail out and let this vector
3039 // store be scalarized. Note that we may still be able to emit smaller
3040 // vector stores. For example, if we are storing a <4 x float> with an
3041 // alignment of 8, this check will fail but the legalizer will try again
3042 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3043 return SDValue();
3044 }
3045
3046 unsigned Opcode = 0;
3047 EVT EltVT = ValVT.getVectorElementType();
3048 unsigned NumElts = ValVT.getVectorNumElements();
3049
3050 // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
3051 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
3052 // stored type to i16 and propagate the "real" type as the memory type.
3053 bool NeedExt = false;
3054 if (EltVT.getSizeInBits() < 16)
3055 NeedExt = true;
3056
3057 bool StoreF16x2 = false;
3058 switch (NumElts) {
3059 default:
3060 return SDValue();
3061 case 2:
3062 Opcode = NVPTXISD::StoreV2;
3063 break;
3064 case 4:
3065 Opcode = NVPTXISD::StoreV4;
3066 break;
3067 case 8:
3068 // v8f16 is a special case. PTX doesn't have st.v8.f16
3069 // instruction. Instead, we split the vector into v2f16 chunks and
3070 // store them with st.v4.b32.
3071 assert(Is16bitsType(EltVT.getSimpleVT()) && "Wrong type for the vector.");
3072 Opcode = NVPTXISD::StoreV4;
3073 StoreF16x2 = true;
3074 break;
3075 }
3076
3077 SmallVector<SDValue, 8> Ops;
3078
3079 // First is the chain
3080 Ops.push_back(N->getOperand(0));
3081
3082 if (StoreF16x2) {
3083 // Combine f16,f16 -> v2f16
3084 NumElts /= 2;
3085 for (unsigned i = 0; i < NumElts; ++i) {
3086 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
3087 DAG.getIntPtrConstant(i * 2, DL));
3088 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
3089 DAG.getIntPtrConstant(i * 2 + 1, DL));
3090 EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2);
3091 SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
3092 Ops.push_back(V2);
3093 }
3094 } else {
3095 // Then the split values
3096 for (unsigned i = 0; i < NumElts; ++i) {
3097 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
3098 DAG.getIntPtrConstant(i, DL));
3099 if (NeedExt)
3100 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
3101 Ops.push_back(ExtVal);
3102 }
3103 }
3104
3105 // Then any remaining arguments
3106 Ops.append(N->op_begin() + 2, N->op_end());
3107
3108 SDValue NewSt =
3109 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
3110 MemSD->getMemoryVT(), MemSD->getMemOperand());
3111
3112 // return DCI.CombineTo(N, NewSt, true);
3113 return NewSt;
3114 }
3115
3116 return SDValue();
3117}
3118
3119// st i1 v, addr
3120// =>
3121// v1 = zxt v to i16
3122// st.u8 i16, addr
3123SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3124 SDNode *Node = Op.getNode();
3125 SDLoc dl(Node);
3126 StoreSDNode *ST = cast<StoreSDNode>(Node);
3127 SDValue Tmp1 = ST->getChain();
3128 SDValue Tmp2 = ST->getBasePtr();
3129 SDValue Tmp3 = ST->getValue();
3130 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3131 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
3132 SDValue Result =
3133 DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i8,
3134 ST->getAlign(), ST->getMemOperand()->getFlags());
3135 return Result;
3136}
3137
3138SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3139 SelectionDAG &DAG) const {
3140 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3141 // operand so that it can pass the legalization.
3142
3143 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3144 "Custom lowering for 128-bit CopyToReg only");
3145
3146 SDNode *Node = Op.getNode();
3147 SDLoc DL(Node);
3148
3149 SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
3150 SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3151 DAG.getIntPtrConstant(0, DL));
3152 SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3153 DAG.getIntPtrConstant(1, DL));
3154
3155 SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
3156 SmallVector<EVT, 3> ResultsType(Node->values());
3157
3158 NewOps[0] = Op->getOperand(0); // Chain
3159 NewOps[1] = Op->getOperand(1); // Dst Reg
3160 NewOps[2] = Lo; // Lower 64-bit
3161 NewOps[3] = Hi; // Higher 64-bit
3162 if (Op.getNumOperands() == 4)
3163 NewOps[4] = Op->getOperand(3); // Glue if exists
3164
3165 return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
3166}
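// Editor's note: a standalone sketch (not part of this file) of the 128-bit
// split performed above. It assumes a compiler that provides the __int128
// extension (GCC/Clang); the i128 value is carried as two i64 halves with the
// low half first, matching the v2i64 bitcast plus the two extracts.
#include <cstdint>
#include <cstdio>

int main() {
  unsigned __int128 V =
      ((unsigned __int128)0x0123456789ABCDEFULL << 64) | 0x1122334455667788ULL;
  uint64_t Lo = (uint64_t)V;         // lower 64 bits, operand "Lo" above
  uint64_t Hi = (uint64_t)(V >> 64); // upper 64 bits, operand "Hi" above
  std::printf("lo=0x%016llX hi=0x%016llX\n", (unsigned long long)Lo,
              (unsigned long long)Hi);
  return 0;
}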
3167
3168unsigned NVPTXTargetLowering::getNumRegisters(
3169 LLVMContext &Context, EVT VT,
3170 std::optional<MVT> RegisterVT = std::nullopt) const {
3171 if (VT == MVT::i128 && RegisterVT == MVT::i128)
3172 return 1;
3173 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
3174}
3175
3176bool NVPTXTargetLowering::splitValueIntoRegisterParts(
3177 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
3178 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
3179 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
3180 Parts[0] = Val;
3181 return true;
3182 }
3183 return false;
3184}
3185
3186// This creates a target external symbol for a function parameter.
3187// The name of the symbol is composed from its index and the function name.
3188// A negative index corresponds to the special parameter (unsized array) used
3189// for passing variable arguments.
3190SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
3191 EVT v) const {
3192 StringRef SavedStr = nvTM->getStrPool().save(
3193 getParamName(&DAG.getMachineFunction().getFunction(), idx));
3194 return DAG.getTargetExternalSymbol(SavedStr.data(), v);
3195}
3196
3197 SDValue NVPTXTargetLowering::LowerFormalArguments(
3198 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
3199 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
3200 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
3201 MachineFunction &MF = DAG.getMachineFunction();
3202 const DataLayout &DL = DAG.getDataLayout();
3203 auto PtrVT = getPointerTy(DAG.getDataLayout());
3204
3205 const Function *F = &MF.getFunction();
3206 const AttributeList &PAL = F->getAttributes();
3207 const TargetLowering *TLI = STI.getTargetLowering();
3208
3209 SDValue Root = DAG.getRoot();
3210 std::vector<SDValue> OutChains;
3211
3212 bool isABI = (STI.getSmVersion() >= 20);
3213 assert(isABI && "Non-ABI compilation is not supported");
3214 if (!isABI)
3215 return Chain;
3216
3217 std::vector<Type *> argTypes;
3218 std::vector<const Argument *> theArgs;
3219 for (const Argument &I : F->args()) {
3220 theArgs.push_back(&I);
3221 argTypes.push_back(I.getType());
3222 }
3223 // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
3224 // Ins.size() will be larger
3225 // * if there is an aggregate argument with multiple fields (each field
3226 // showing up separately in Ins)
3227 // * if there is a vector argument with more than typical vector-length
3228 // elements (generally if more than 4) where each vector element is
3229 // individually present in Ins.
3230 // So a different index should be used for indexing into Ins.
3231 // See similar issue in LowerCall.
3232 unsigned InsIdx = 0;
3233
3234 for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) {
3235 Type *Ty = argTypes[i];
3236
3237 if (theArgs[i]->use_empty()) {
3238 // argument is dead
3239 if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
3240 SmallVector<EVT, 16> vtparts;
3241
3242 ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
3243 if (vtparts.empty())
3244 report_fatal_error("Empty parameter types are not supported");
3245
3246 for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
3247 ++parti) {
3248 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3249 ++InsIdx;
3250 }
3251 if (vtparts.size() > 0)
3252 --InsIdx;
3253 continue;
3254 }
3255 if (Ty->isVectorTy()) {
3256 EVT ObjectVT = getValueType(DL, Ty);
3257 unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
3258 for (unsigned parti = 0; parti < NumRegs; ++parti) {
3259 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3260 ++InsIdx;
3261 }
3262 if (NumRegs > 0)
3263 --InsIdx;
3264 continue;
3265 }
3266 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3267 continue;
3268 }
3269
3270 // In the following cases, assign a node order of "i+1"
3271 // to newly created nodes. The SDNodes for params have to
3272 // appear in the same order as their order of appearance
3273 // in the original function. "i+1" holds that order.
3274 if (!PAL.hasParamAttr(i, Attribute::ByVal)) {
3275 bool aggregateIsPacked = false;
3276 if (StructType *STy = dyn_cast<StructType>(Ty))
3277 aggregateIsPacked = STy->isPacked();
3278
3281 ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
3282 if (VTs.empty())
3283 report_fatal_error("Empty parameter types are not supported");
3284
3285 Align ArgAlign = getFunctionArgumentAlignment(
3286 F, Ty, i + AttributeList::FirstArgIndex, DL);
3287 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
3288
3289 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3290 int VecIdx = -1; // Index of the first element of the current vector.
3291 for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
3292 if (VectorInfo[parti] & PVF_FIRST) {
3293 assert(VecIdx == -1 && "Orphaned vector.");
3294 VecIdx = parti;
3295 }
3296
3297 // That's the last element of this store op.
3298 if (VectorInfo[parti] & PVF_LAST) {
3299 unsigned NumElts = parti - VecIdx + 1;
3300 EVT EltVT = VTs[parti];
3301 // i1 is loaded/stored as i8.
3302 EVT LoadVT = EltVT;
3303 if (EltVT == MVT::i1)
3304 LoadVT = MVT::i8;
3305 else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
3306 // getLoad needs a vector type, but it can't handle
3307 // vectors which contain v2f16 or v2bf16 elements. So we must load
3308 // using i32 here and then bitcast back.
3309 LoadVT = MVT::i32;
3310
3311 EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
3312 SDValue VecAddr =
3313 DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
3314 DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
3315 Value *srcValue = Constant::getNullValue(PointerType::get(
3316 EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
3317
3318 const MaybeAlign PartAlign = [&]() -> MaybeAlign {
3319 if (aggregateIsPacked)
3320 return Align(1);
3321 if (NumElts != 1)
3322 return std::nullopt;
3323 Align PartAlign =
3324 DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
3325 return commonAlignment(PartAlign, Offsets[parti]);
3326 }();
3327 SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
3328 MachinePointerInfo(srcValue), PartAlign,
3329 MachineMemOperand::MODereferenceable |
3330 MachineMemOperand::MOInvariant);
3331 if (P.getNode())
3332 P.getNode()->setIROrder(i + 1);
3333 for (unsigned j = 0; j < NumElts; ++j) {
3334 SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
3335 DAG.getIntPtrConstant(j, dl));
3336 // We've loaded i1 as an i8 and now must truncate it back to i1
3337 if (EltVT == MVT::i1)
3338 Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
3339 // v2f16 was loaded as an i32. Now we must bitcast it back.
3340 else if (EltVT != LoadVT)
3341 Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
3342
3343 // If a promoted integer type is used, truncate down to the original
3344 MVT PromotedVT;
3345 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
3346 Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
3347 }
3348
3349 // Extend the element if necessary (e.g. an i8 is loaded
3350 // into an i16 register)
3351 if (Ins[InsIdx].VT.isInteger() &&
3352 Ins[InsIdx].VT.getFixedSizeInBits() >
3353 LoadVT.getFixedSizeInBits()) {
3354 unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
3355 : ISD::ZERO_EXTEND;
3356 Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
3357 }
3358 InVals.push_back(Elt);
3359 }
3360
3361 // Reset vector tracking state.
3362 VecIdx = -1;
3363 }
3364 ++InsIdx;
3365 }
3366 if (VTs.size() > 0)
3367 --InsIdx;
3368 continue;
3369 }
3370
3371 // Param has the ByVal attribute:
3372 // return MoveParam(param symbol).
3373 // Ideally, the param symbol could be returned directly,
3374 // but when the SDNode builder decides to use it in a CopyToReg(),
3375 // the machine instruction fails because a TargetExternalSymbol
3376 // (not lowered) is target dependent, and CopyToReg assumes
3377 // the source is lowered.
3378 EVT ObjectVT = getValueType(DL, Ty);
3379 assert(ObjectVT == Ins[InsIdx].VT &&
3380 "Ins type did not match function type");
3381 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3382 SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3383 if (p.getNode())
3384 p.getNode()->setIROrder(i + 1);
3385 InVals.push_back(p);
3386 }
3387
3388 if (!OutChains.empty())
3389 DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
3390
3391 return Chain;
3392}
3393
3394// Use byte stores when the param address of the return value is unaligned.
3395// This may happen when the return value is a field of a packed structure.
3396 static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
3397 uint64_t Offset, EVT ElementType,
3398 SDValue RetVal, const SDLoc &dl) {
3399 // Bit logic only works on integer types
3400 if (adjustElementType(ElementType))
3401 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
3402
3403 // Store each byte
3404 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
3405 // Shift the byte to the last byte position
3406 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
3407 DAG.getConstant(i * 8, dl, MVT::i32));
3408 SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
3409 ShiftVal};
3410 // Use a truncating store of only the last byte via
3411 // st.param.b8
3412 // since the register type can be larger than b8.
3413 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
3414 DAG.getVTList(MVT::Other), StoreOperands,
3415 MVT::i8, MachinePointerInfo(), std::nullopt,
3416 MachineMemOperand::MOStore);
3417 }
3418 return Chain;
3419}
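// Editor's note: a standalone sketch (not part of this file) of the byte-wise
// fallback above, assuming only the C++ standard library. An integer return
// value is emitted one byte at a time by shifting right and truncating, so no
// individual store needs more than 1-byte alignment.
#include <cstdint>
#include <cstdio>

static void storeBytewise(unsigned char *Dst, uint32_t Val, unsigned NumBytes) {
  for (unsigned i = 0; i < NumBytes; ++i)
    Dst[i] = static_cast<unsigned char>(Val >> (i * 8)); // st.param.b8 analogue
}

int main() {
  unsigned char Buf[4] = {};
  storeBytewise(Buf, 0x12345678, 4);
  std::printf("%02X %02X %02X %02X\n", (unsigned)Buf[0], (unsigned)Buf[1],
              (unsigned)Buf[2], (unsigned)Buf[3]);
  // prints 78 56 34 12 -- lowest byte at the lowest offset
  return 0;
}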
3420
3421SDValue
3422 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
3423 bool isVarArg,
3424 const SmallVectorImpl<ISD::OutputArg> &Outs,
3425 const SmallVectorImpl<SDValue> &OutVals,
3426 const SDLoc &dl, SelectionDAG &DAG) const {
3427 const MachineFunction &MF = DAG.getMachineFunction();
3428 const Function &F = MF.getFunction();
3429 Type *RetTy = MF.getFunction().getReturnType();
3430
3431 bool isABI = (STI.getSmVersion() >= 20);
3432 assert(isABI && "Non-ABI compilation is not supported");
3433 if (!isABI)
3434 return Chain;
3435
3436 const DataLayout &DL = DAG.getDataLayout();
3437 SmallVector<SDValue, 16> PromotedOutVals;
3438 SmallVector<EVT, 16> VTs;
3439 SmallVector<uint64_t, 16> Offsets;
3440 ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
3441 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
3442
3443 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3444 SDValue PromotedOutVal = OutVals[i];
3445 MVT PromotedVT;
3446 if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) {
3447 VTs[i] = EVT(PromotedVT);
3448 }
3449 if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) {
3450 llvm::ISD::NodeType Ext =
3451 Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
3452 PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal);
3453 }
3454 PromotedOutVals.push_back(PromotedOutVal);
3455 }
3456
3457 auto VectorInfo = VectorizePTXValueVTs(
3458 VTs, Offsets,
3459 RetTy->isSized() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
3460 : Align(1));
3461
3462 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
3463 // 32-bits are sign extended or zero extended, depending on whether
3464 // they are signed or unsigned types.
3465 bool ExtendIntegerRetVal =
3466 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
3467
3468 SmallVector<SDValue, 6> StoreOperands;
3469 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3470 SDValue OutVal = OutVals[i];
3471 SDValue RetVal = PromotedOutVals[i];
3472
3473 if (ExtendIntegerRetVal) {
3474 RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
3476 dl, MVT::i32, RetVal);
3477 } else if (OutVal.getValueSizeInBits() < 16) {
3478 // Use 16-bit registers for small load-stores as it's the
3479 // smallest general purpose register size supported by NVPTX.
3480 RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
3481 }
3482
3483 // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
3484 // for a scalar store. In such cases, fall back to byte stores.
3485 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
3486 EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3487 Align ElementTypeAlign =
3488 DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
3489 Align ElementAlign =
3490 commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
3491 if (ElementAlign < ElementTypeAlign) {
3492 assert(StoreOperands.empty() && "Orphaned operand list.");
3493 Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
3494 RetVal, dl);
3495
3496 // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
3497 // into the graph, so just move on to the next element.
3498 continue;
3499 }
3500 }
3501
3502 // New load/store. Record chain and offset operands.
3503 if (VectorInfo[i] & PVF_FIRST) {
3504 assert(StoreOperands.empty() && "Orphaned operand list.");
3505 StoreOperands.push_back(Chain);
3506 StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
3507 }
3508
3509 // Record the value to return.
3510 StoreOperands.push_back(RetVal);
3511
3512 // That's the last element of this store op.
3513 if (VectorInfo[i] & PVF_LAST) {
3514 NVPTXISD::NodeType Op;
3515 unsigned NumElts = StoreOperands.size() - 2;
3516 switch (NumElts) {
3517 case 1:
3518 Op = NVPTXISD::StoreRetval;
3519 break;
3520 case 2:
3521 Op = NVPTXISD::StoreRetvalV2;
3522 break;
3523 case 4:
3524 Op = NVPTXISD::StoreRetvalV4;
3525 break;
3526 default:
3527 llvm_unreachable("Invalid vector info.");
3528 }
3529
3530 // Adjust type of load/store op if we've extended the scalar
3531 // return value.
3532 EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3533 Chain = DAG.getMemIntrinsicNode(
3534 Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
3535 MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
3536 // Cleanup vector state.
3537 StoreOperands.clear();
3538 }
3539 }
3540
3541 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
3542}
3543
3544 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
3545 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
3546 SelectionDAG &DAG) const {
3547 if (Constraint.size() > 1)
3548 return;
3549 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
3550}
3551
3552static unsigned getOpcForTextureInstr(unsigned Intrinsic) {
3553 switch (Intrinsic) {
3554 default:
3555 return 0;
3556
3557 case Intrinsic::nvvm_tex_1d_v4f32_s32:
3559 case Intrinsic::nvvm_tex_1d_v4f32_f32:
3561 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
3563 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
3565 case Intrinsic::nvvm_tex_1d_v4s32_s32:
3566 return NVPTXISD::Tex1DS32S32;
3567 case Intrinsic::nvvm_tex_1d_v4s32_f32:
3569 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
3571 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
3573 case Intrinsic::nvvm_tex_1d_v4u32_s32:
3574 return NVPTXISD::Tex1DU32S32;
3575 case Intrinsic::nvvm_tex_1d_v4u32_f32:
3577 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
3579 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
3581
3582 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
3584 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
3586 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
3588 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
3590 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
3592 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
3594 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
3596 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
3598 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
3600 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
3602 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
3604 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
3606
3607 case Intrinsic::nvvm_tex_2d_v4f32_s32:
3609 case Intrinsic::nvvm_tex_2d_v4f32_f32:
3611 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
3613 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
3615 case Intrinsic::nvvm_tex_2d_v4s32_s32:
3616 return NVPTXISD::Tex2DS32S32;
3617 case Intrinsic::nvvm_tex_2d_v4s32_f32:
3619 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
3621 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
3623 case Intrinsic::nvvm_tex_2d_v4u32_s32:
3624 return NVPTXISD::Tex2DU32S32;
3625 case Intrinsic::nvvm_tex_2d_v4u32_f32:
3627 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
3629 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
3631
3632 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
3634 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
3636 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
3638 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
3640 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
3642 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
3644 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
3646 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
3648 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
3650 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
3652 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
3654 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
3656
3657 case Intrinsic::nvvm_tex_3d_v4f32_s32:
3659 case Intrinsic::nvvm_tex_3d_v4f32_f32:
3661 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
3663 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
3665 case Intrinsic::nvvm_tex_3d_v4s32_s32:
3666 return NVPTXISD::Tex3DS32S32;
3667 case Intrinsic::nvvm_tex_3d_v4s32_f32:
3669 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
3671 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
3673 case Intrinsic::nvvm_tex_3d_v4u32_s32:
3674 return NVPTXISD::Tex3DU32S32;
3675 case Intrinsic::nvvm_tex_3d_v4u32_f32:
3677 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
3679 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
3681
3682 case Intrinsic::nvvm_tex_cube_v4f32_f32:
3684 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
3686 case Intrinsic::nvvm_tex_cube_v4s32_f32:
3688 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
3690 case Intrinsic::nvvm_tex_cube_v4u32_f32:
3692 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
3694
3695 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
3697 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
3699 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
3701 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
3703 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
3705 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
3707
3708 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
3710 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
3712 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
3714 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
3716 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
3718 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
3720 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
3722 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
3724 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
3726 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
3728 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
3730 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
3732
3733 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
3735 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
3737 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
3739 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
3741 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
3743 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
3745 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
3747 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
3749 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
3751 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
3753 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
3755 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
3757
3758 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
3760 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
3762 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
3764 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
3766 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
3768 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
3770 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
3772 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
3774 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
3776 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
3778 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
3780 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
3782
3783 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
3785 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
3787 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
3789 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
3791 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
3793 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
3795 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
3797 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
3799 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
3801 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
3803 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
3805 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
3807
3808 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
3810 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
3812 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
3814 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
3816 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
3818 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
3820 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
3822 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
3824 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
3826 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
3828 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
3830 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
3832
3833 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
3835 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
3837 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
3839 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
3841 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
3843 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
3845 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
3847 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
3849 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
3851 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
3853 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
3855 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
3857
3858 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
3860 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
3862 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
3864 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
3866 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
3868 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
3870
3871 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
3873 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
3875 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
3877 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
3879 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
3881 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
3883
3884 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
3886 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
3888 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
3890 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
3892 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
3894 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
3896
3897 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
3899 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
3901 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
3903 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
3905 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
3907 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
3909 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
3911 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
3913 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
3915 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
3917 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
3919 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
3921 }
3922}
3923
3924static unsigned getOpcForSurfaceInstr(unsigned Intrinsic) {
3925 switch (Intrinsic) {
3926 default:
3927 return 0;
3928 case Intrinsic::nvvm_suld_1d_i8_clamp:
3930 case Intrinsic::nvvm_suld_1d_i16_clamp:
3932 case Intrinsic::nvvm_suld_1d_i32_clamp:
3934 case Intrinsic::nvvm_suld_1d_i64_clamp:
3936 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
3938 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
3940 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
3942 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
3944 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
3946 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
3948 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
3950 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
3952 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
3954 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
3956 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
3958 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
3960 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
3962 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
3964 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
3966 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
3968 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
3970 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
3972 case Intrinsic::nvvm_suld_2d_i8_clamp:
3974 case Intrinsic::nvvm_suld_2d_i16_clamp:
3976 case Intrinsic::nvvm_suld_2d_i32_clamp:
3978 case Intrinsic::nvvm_suld_2d_i64_clamp:
3980 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
3982 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
3984 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
3986 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
3988 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
3990 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
3992 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
3994 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
3996 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
3998 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
4000 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
4002 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4004 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
4006 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
4008 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
4010 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4012 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
4014 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
4016 case Intrinsic::nvvm_suld_3d_i8_clamp:
4018 case Intrinsic::nvvm_suld_3d_i16_clamp:
4020 case Intrinsic::nvvm_suld_3d_i32_clamp:
4022 case Intrinsic::nvvm_suld_3d_i64_clamp:
4024 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4026 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
4028 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
4030 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
4032 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4034 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
4036 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
4038 case Intrinsic::nvvm_suld_1d_i8_trap:
4040 case Intrinsic::nvvm_suld_1d_i16_trap:
4042 case Intrinsic::nvvm_suld_1d_i32_trap:
4044 case Intrinsic::nvvm_suld_1d_i64_trap:
4046 case Intrinsic::nvvm_suld_1d_v2i8_trap:
4048 case Intrinsic::nvvm_suld_1d_v2i16_trap:
4050 case Intrinsic::nvvm_suld_1d_v2i32_trap:
4052 case Intrinsic::nvvm_suld_1d_v2i64_trap:
4054 case Intrinsic::nvvm_suld_1d_v4i8_trap:
4055 return