1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
16#include "NVPTX.h"
17#include "NVPTXSubtarget.h"
18#include "NVPTXTargetMachine.h"
20#include "NVPTXUtilities.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/StringRef.h"
35#include "llvm/IR/Argument.h"
36#include "llvm/IR/Attributes.h"
37#include "llvm/IR/Constants.h"
38#include "llvm/IR/DataLayout.h"
41#include "llvm/IR/FPEnv.h"
42#include "llvm/IR/Function.h"
43#include "llvm/IR/GlobalValue.h"
44#include "llvm/IR/Instruction.h"
46#include "llvm/IR/IntrinsicsNVPTX.h"
47#include "llvm/IR/Module.h"
48#include "llvm/IR/Type.h"
49#include "llvm/IR/Value.h"
58#include <algorithm>
59#include <cassert>
60#include <cmath>
61#include <cstdint>
62#include <iterator>
63#include <optional>
64#include <sstream>
65#include <string>
66#include <utility>
67#include <vector>
68
69#define DEBUG_TYPE "nvptx-lower"
70
71using namespace llvm;
72
73static std::atomic<unsigned> GlobalUniqueCallSite;
74
75static cl::opt<bool> sched4reg(
76    "nvptx-sched4reg",
77    cl::desc("NVPTX Specific: schedule for register pressure"), cl::init(false));
78
80 "nvptx-fma-level", cl::Hidden,
81    cl::desc("NVPTX Specific: FMA contraction (0: don't do it,"
82             " 1: do it, 2: do it aggressively)"),
83 cl::init(2));
84
85static cl::opt<int> UsePrecDivF32(
86    "nvptx-prec-divf32", cl::Hidden,
87    cl::desc("NVPTX Specific: 0 use div.approx, 1 use div.full, 2 use"
88 " IEEE Compliant F32 div.rnd if available."),
89 cl::init(2));
90
91static cl::opt<bool> UsePrecSqrtF32(
92    "nvptx-prec-sqrtf32", cl::Hidden,
93 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
94 cl::init(true));
95
97 "nvptx-force-min-byval-param-align", cl::Hidden,
98 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
99 " params of device functions."),
100 cl::init(false));
101
102int NVPTXTargetLowering::getDivF32Level() const {
103  if (UsePrecDivF32.getNumOccurrences() > 0) {
104    // If nvptx-prec-divf32=N is used on the command line, always honor it.
105 return UsePrecDivF32;
106 } else {
107 // Otherwise, use div.approx if fast math is enabled
108 if (getTargetMachine().Options.UnsafeFPMath)
109 return 0;
110 else
111 return 2;
112 }
113}
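// Per the option description above: a returned level of 0 selects
// div.approx.f32, 1 selects div.full.f32, and 2 selects the IEEE-compliant
// div.rn.f32 when available; without the flag, fast math drops the level to 0.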
114
115bool NVPTXTargetLowering::usePrecSqrtF32() const {
116  if (UsePrecSqrtF32.getNumOccurrences() > 0) {
117 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
118 return UsePrecSqrtF32;
119 } else {
120 // Otherwise, use sqrt.approx if fast math is enabled
121    return !getTargetMachine().Options.UnsafeFPMath;
122  }
123}
124
128}
129
130static bool IsPTXVectorType(MVT VT) {
131 switch (VT.SimpleTy) {
132 default:
133 return false;
134 case MVT::v2i1:
135 case MVT::v4i1:
136 case MVT::v2i8:
137 case MVT::v4i8:
138 case MVT::v2i16:
139 case MVT::v4i16:
140 case MVT::v8i16: // <4 x i16x2>
141 case MVT::v2i32:
142 case MVT::v4i32:
143 case MVT::v2i64:
144 case MVT::v2f16:
145 case MVT::v4f16:
146 case MVT::v8f16: // <4 x f16x2>
147 case MVT::v2bf16:
148 case MVT::v4bf16:
149 case MVT::v8bf16: // <4 x bf16x2>
150 case MVT::v2f32:
151 case MVT::v4f32:
152 case MVT::v2f64:
153 return true;
154 }
155}
156
157static bool Is16bitsType(MVT VT) {
158 return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
159 VT.SimpleTy == MVT::i16);
160}
161
162/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
163/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
164/// into their primitive components.
165/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
166/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
167/// LowerCall, and LowerReturn.
168static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
169 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
170 SmallVectorImpl<uint64_t> *Offsets = nullptr,
171 uint64_t StartingOffset = 0) {
172 SmallVector<EVT, 16> TempVTs;
173 SmallVector<uint64_t, 16> TempOffsets;
174
175 // Special case for i128 - decompose to (i64, i64)
176 if (Ty->isIntegerTy(128)) {
177 ValueVTs.push_back(EVT(MVT::i64));
178 ValueVTs.push_back(EVT(MVT::i64));
179
180 if (Offsets) {
181 Offsets->push_back(StartingOffset + 0);
182 Offsets->push_back(StartingOffset + 8);
183 }
184
185 return;
186 }
187
188  // Given a struct type, recursively traverse its elements with this custom ComputePTXValueVTs.
189 if (StructType *STy = dyn_cast<StructType>(Ty)) {
190 auto const *SL = DL.getStructLayout(STy);
191 auto ElementNum = 0;
192 for(auto *EI : STy->elements()) {
193 ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
194 StartingOffset + SL->getElementOffset(ElementNum));
195 ++ElementNum;
196 }
197 return;
198 }
199
200 ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
201 for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
202 EVT VT = TempVTs[i];
203 uint64_t Off = TempOffsets[i];
204 // Split vectors into individual elements, except for v2f16, which
205 // we will pass as a single scalar.
206 if (VT.isVector()) {
207 unsigned NumElts = VT.getVectorNumElements();
208 EVT EltVT = VT.getVectorElementType();
209 // Vectors with an even number of f16 elements will be passed to
210 // us as an array of v2f16/v2bf16 elements. We must match this so we
211 // stay in sync with Ins/Outs.
212 if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
213 switch (EltVT.getSimpleVT().SimpleTy) {
214 case MVT::f16:
215 EltVT = MVT::v2f16;
216 break;
217 case MVT::bf16:
218 EltVT = MVT::v2bf16;
219 break;
220 case MVT::i16:
221 EltVT = MVT::v2i16;
222 break;
223 default:
224 llvm_unreachable("Unexpected type");
225 }
226 NumElts /= 2;
227 } else if (EltVT.getSimpleVT() == MVT::i8 &&
228 (NumElts % 4 == 0 || NumElts == 3)) {
229 // v*i8 are formally lowered as v4i8
230 EltVT = MVT::v4i8;
231 NumElts = (NumElts + 3) / 4;
232 }
233 for (unsigned j = 0; j != NumElts; ++j) {
234 ValueVTs.push_back(EltVT);
235 if (Offsets)
236 Offsets->push_back(Off + j * EltVT.getStoreSize());
237 }
238 } else {
239 ValueVTs.push_back(VT);
240 if (Offsets)
241 Offsets->push_back(Off);
242 }
243 }
244}
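// Illustrative examples of the flattening performed above:
//   i128       -> i64 at offset 0, i64 at offset 8
//   <4 x half> -> v2f16 at offset 0, v2f16 at offset 4
//   <3 x i8>   -> a single v4i8 piece at offset 0
// (each line lists the EVTs pushed into ValueVTs and their byte offsets).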
245
246/// PromoteScalarIntegerPTX
247/// Used to make sure the arguments/returns are suitable for passing
248/// and promote them to a larger size if they're not.
249///
250/// The promoted type is placed in \p PromotedVT if the function returns true.
251static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
252 if (VT.isScalarInteger()) {
253 switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
254 default:
255    llvm_unreachable(
256        "Promotion is not suitable for scalars of size larger than 64-bits");
257 case 1:
258 *PromotedVT = MVT::i1;
259 break;
260 case 2:
261 case 4:
262 case 8:
263 *PromotedVT = MVT::i8;
264 break;
265 case 16:
266 *PromotedVT = MVT::i16;
267 break;
268 case 32:
269 *PromotedVT = MVT::i32;
270 break;
271 case 64:
272 *PromotedVT = MVT::i64;
273 break;
274 }
275 return EVT(*PromotedVT) != VT;
276 }
277 return false;
278}
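// For example, an i17 value has PowerOf2Ceil(17) == 32, so it is promoted to
// i32 and the function returns true; an i32 already maps to i32, so no
// promotion is needed and the function returns false.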
279
280// Check whether we can merge loads/stores of some of the pieces of a
281// flattened function parameter or return value into a single vector
282// load/store.
283//
284// The flattened parameter is represented as a list of EVTs and
285// offsets, and the whole structure is aligned to ParamAlignment. This
286// function determines whether we can load/store pieces of the
287// parameter starting at index Idx using a single vectorized op of
288// size AccessSize. If so, it returns the number of param pieces
289// covered by the vector op. Otherwise, it returns 1.
290static unsigned CanMergeParamLoadStoresStartingAt(
291    unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
292 const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
293
294 // Can't vectorize if param alignment is not sufficient.
295 if (ParamAlignment < AccessSize)
296 return 1;
297 // Can't vectorize if offset is not aligned.
298 if (Offsets[Idx] & (AccessSize - 1))
299 return 1;
300
301 EVT EltVT = ValueVTs[Idx];
302 unsigned EltSize = EltVT.getStoreSize();
303
304 // Element is too large to vectorize.
305 if (EltSize >= AccessSize)
306 return 1;
307
308 unsigned NumElts = AccessSize / EltSize;
309 // Can't vectorize if AccessBytes if not a multiple of EltSize.
310 if (AccessSize != EltSize * NumElts)
311 return 1;
312
313 // We don't have enough elements to vectorize.
314 if (Idx + NumElts > ValueVTs.size())
315 return 1;
316
317 // PTX ISA can only deal with 2- and 4-element vector ops.
318 if (NumElts != 4 && NumElts != 2)
319 return 1;
320
321 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
322 // Types do not match.
323 if (ValueVTs[j] != EltVT)
324 return 1;
325
326 // Elements are not contiguous.
327 if (Offsets[j] - Offsets[j - 1] != EltSize)
328 return 1;
329 }
330  // OK. We can vectorize ValueVTs[Idx..Idx+NumElts).
331 return NumElts;
332}
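// For example, four f32 pieces at offsets 0, 4, 8, 12 with a 16-byte param
// alignment can be merged into one 16-byte access (returns 4). With only an
// 8-byte alignment a 16-byte access is rejected, but an 8-byte access still
// merges the first two pieces (returns 2).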
333
334// Flags for tracking per-element vectorization state of loads/stores
335// of a flattened function parameter or return value.
336enum ParamVectorizationFlags {
337  PVF_INNER = 0x0, // Middle elements of a vector.
338 PVF_FIRST = 0x1, // First element of the vector.
339 PVF_LAST = 0x2, // Last element of the vector.
340  // Scalar is effectively a 1-element vector.
341  PVF_SCALAR = PVF_FIRST | PVF_LAST
342};
343
344// Computes whether and how we can vectorize the loads/stores of a
345// flattened function parameter or return value.
346//
347// The flattened parameter is represented as the list of ValueVTs and
348// Offsets, and is aligned to ParamAlignment bytes. We return a vector
349// of the same size as ValueVTs indicating how each piece should be
350// loaded/stored (i.e. as a scalar, or as part of a vector
351// load/store).
352static SmallVector<ParamVectorizationFlags, 16>
353VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
354                     const SmallVectorImpl<uint64_t> &Offsets,
355 Align ParamAlignment, bool IsVAArg = false) {
356 // Set vector size to match ValueVTs and mark all elements as
357 // scalars by default.
358  SmallVector<ParamVectorizationFlags, 16> VectorInfo;
359  VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
360
361 if (IsVAArg)
362 return VectorInfo;
363
364 // Check what we can vectorize using 128/64/32-bit accesses.
365 for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
366 // Skip elements we've already processed.
367 assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
368 for (unsigned AccessSize : {16, 8, 4, 2}) {
369 unsigned NumElts = CanMergeParamLoadStoresStartingAt(
370 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
371 // Mark vectorized elements.
372 switch (NumElts) {
373 default:
374 llvm_unreachable("Unexpected return value");
375 case 1:
376 // Can't vectorize using this size, try next smaller size.
377 continue;
378 case 2:
379 assert(I + 1 < E && "Not enough elements.");
380 VectorInfo[I] = PVF_FIRST;
381 VectorInfo[I + 1] = PVF_LAST;
382 I += 1;
383 break;
384 case 4:
385 assert(I + 3 < E && "Not enough elements.");
386 VectorInfo[I] = PVF_FIRST;
387 VectorInfo[I + 1] = PVF_INNER;
388 VectorInfo[I + 2] = PVF_INNER;
389 VectorInfo[I + 3] = PVF_LAST;
390 I += 3;
391 break;
392 }
393 // Break out of the inner loop because we've already succeeded
394 // using largest possible AccessSize.
395 break;
396 }
397 }
398 return VectorInfo;
399}
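// Continuing the example above: four f32 pieces at offsets 0, 4, 8, 12 with a
// 16-byte alignment come back as {PVF_FIRST, PVF_INNER, PVF_INNER, PVF_LAST},
// i.e. one v4 load/store. Pieces that cannot be merged (and all pieces of a
// va_arg parameter) keep the default PVF_SCALAR flag.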
400
401// NVPTXTargetLowering Constructor.
402NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
403                                         const NVPTXSubtarget &STI)
404 : TargetLowering(TM), nvTM(&TM), STI(STI) {
405  // Always lower memset, memcpy, and memmove intrinsics to load/store
406  // instructions, rather than generating calls to memset, memcpy, or
407  // memmove.
411
414
415 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
416 // condition branches.
417 setJumpIsExpensive(true);
418
419 // Wide divides are _very_ slow. Try to reduce the width of the divide if
420 // possible.
421 addBypassSlowDiv(64, 32);
422
423 // By default, use the Source scheduling
424 if (sched4reg)
426 else
428
429 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
430 LegalizeAction NoF16Action) {
431 setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
432 };
433
434 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
435 LegalizeAction NoBF16Action) {
436 bool IsOpSupported = STI.hasBF16Math();
437    // A few of these instructions are only available on sm_90 or later.
438 switch(Op) {
439 case ISD::FADD:
440 case ISD::FMUL:
441 case ISD::FSUB:
442 case ISD::SELECT:
443 case ISD::SELECT_CC:
444 case ISD::SETCC:
445 case ISD::FEXP2:
446 case ISD::FCEIL:
447 case ISD::FFLOOR:
448 case ISD::FNEARBYINT:
449 case ISD::FRINT:
450 case ISD::FROUNDEVEN:
451 case ISD::FTRUNC:
452 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
453 break;
454 }
456 Op, VT, IsOpSupported ? Action : NoBF16Action);
457 };
458
459 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
460 LegalizeAction NoI16x2Action) {
461 bool IsOpSupported = false;
462    // These instructions are only available on sm_90 or later.
463 switch (Op) {
464 case ISD::ADD:
465 case ISD::SMAX:
466 case ISD::SMIN:
467 case ISD::UMIN:
468 case ISD::UMAX:
469 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
470 break;
471 }
472 setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
473 };
474
475 addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
476 addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
477 addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
478 addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
479 addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
480 addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
481 addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
482 addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
483 addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
484 addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
485 addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
486 addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
487
488 // Conversion to/from FP16/FP16x2 is always legal.
493
495 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
497
498 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
499 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
500
501  // Conversion to/from BF16/BF16x2 is always legal.
506
507 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
508 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
509 if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote)
510 AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32);
511
512 // Conversion to/from i16/i16x2 is always legal.
517
522 // Only logical ops can be done on v4i8 directly, others must be done
523 // elementwise.
540 MVT::v4i8, Expand);
541
542 // Operations not directly supported by NVPTX.
543 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
544 MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
545 MVT::i32, MVT::i64}) {
548 }
549
550 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
551 // For others we will expand to a SHL/SRA pair.
558
565
568
569 // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs
570 // that don't have h/w rotation we lower them to multi-instruction assembly.
571 // See ROT*_sw in NVPTXIntrInfo.td
576
578 setOperationAction(ISD::ROTL, MVT::v2i16, Expand);
580 setOperationAction(ISD::ROTR, MVT::v2i16, Expand);
584
585 // Indirect branch is not supported.
586 // This also disables Jump Table creation.
589
592
593  // We want to legalize constant-related memmove and memcpy
594  // intrinsics.
596
597 // Turn FP extload into load/fpextend
598 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
599 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
600 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
601 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
602 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
603 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
604 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
605 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
606 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
607 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
608 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
609 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
610 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
611 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
612 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
613 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
614 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
615 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
616 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
617 // Turn FP truncstore into trunc + store.
618 // FIXME: vector types should also be expanded
619 setTruncStoreAction(MVT::f32, MVT::f16, Expand);
620 setTruncStoreAction(MVT::f64, MVT::f16, Expand);
621 setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
622 setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
623 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
624
625 // PTX does not support load / store predicate registers
628
629 for (MVT VT : MVT::integer_valuetypes()) {
632 setTruncStoreAction(VT, MVT::i1, Expand);
633 }
634
635 // expand extload of vector of integers.
637 MVT::v2i8, Expand);
638 setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
639
640 // This is legal in NVPTX
645
648
649 // TRAP can be lowered to PTX trap
650 setOperationAction(ISD::TRAP, MVT::Other, Legal);
651
652 // Register custom handling for vector loads/stores
654 if (IsPTXVectorType(VT)) {
658 }
659 }
660
661 // Support varargs.
666
667 // Custom handling for i8 intrinsics
669
670 for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) {
676
679 }
680
681 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
682 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
683 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
684 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
685 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
686 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
687 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
688
689 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
690 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
691 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
692 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
693 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
694 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
695
696 // Other arithmetic and logic ops are unsupported.
700 MVT::v2i16, Expand);
701
706 if (STI.getPTXVersion() >= 43) {
711 }
712
714 setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
717
718 // PTX does not directly support SELP of i1, so promote to i32 first
720
721 // PTX cannot multiply two i64s in a single instruction.
724
725 // We have some custom DAG combine patterns for these nodes
728 ISD::VSELECT});
729
730 // setcc for f16x2 and bf16x2 needs special handling to prevent
731 // legalizer's attempt to scalarize it due to v2i1 not being legal.
732 if (STI.allowFP16Math() || STI.hasBF16Math())
734
735 // Promote fp16 arithmetic if fp16 hardware isn't available or the
736 // user passed --nvptx-no-fp16-math. The flag is useful because,
737 // although sm_53+ GPUs have some sort of FP16 support in
738 // hardware, only sm_53 and sm_60 have full implementation. Others
739  // only have a token amount of hardware and are likely to run faster
740 // by using fp32 units instead.
741 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
742 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
743 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
744 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
745 // bf16 must be promoted to f32.
746 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
747 if (getOperationAction(Op, MVT::bf16) == Promote)
748 AddPromotedToType(Op, MVT::bf16, MVT::f32);
749 }
750
751 // f16/f16x2 neg was introduced in PTX 60, SM_53.
752 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
753 STI.getPTXVersion() >= 60 &&
754 STI.allowFP16Math();
755 for (const auto &VT : {MVT::f16, MVT::v2f16})
757 IsFP16FP16x2NegAvailable ? Legal : Expand);
758
759 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
760 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
761 // (would be) Library functions.
762
763 // These map to conversion instructions for scalar FP types.
764 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
766 setOperationAction(Op, MVT::f16, Legal);
767 setOperationAction(Op, MVT::f32, Legal);
768 setOperationAction(Op, MVT::f64, Legal);
769 setOperationAction(Op, MVT::v2f16, Expand);
770 setOperationAction(Op, MVT::v2bf16, Expand);
771 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
772 if (getOperationAction(Op, MVT::bf16) == Promote)
773 AddPromotedToType(Op, MVT::bf16, MVT::f32);
774 }
775
776 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
778 }
779 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
780 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
783 }
784 }
785
786 // sm_80 only has conversions between f32 and bf16. Custom lower all other
787 // bf16 conversions.
788 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
789 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
792 VT, Custom);
793 }
796 MVT::bf16, Custom);
797 }
798
805 AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32);
806
807 // 'Expand' implements FCOPYSIGN without calling an external library.
814
815 // These map to corresponding instructions for f32/f64. f16 must be
816 // promoted to f32. v2f16 is expanded to f16, which is then promoted
817 // to f32.
818 for (const auto &Op :
820 setOperationAction(Op, MVT::f16, Promote);
821 setOperationAction(Op, MVT::f32, Legal);
822 setOperationAction(Op, MVT::f64, Legal);
823 setOperationAction(Op, MVT::v2f16, Expand);
824 setOperationAction(Op, MVT::v2bf16, Expand);
825 setOperationAction(Op, MVT::bf16, Promote);
826 AddPromotedToType(Op, MVT::bf16, MVT::f32);
827 }
828 for (const auto &Op : {ISD::FABS}) {
829 setOperationAction(Op, MVT::f16, Promote);
830 setOperationAction(Op, MVT::f32, Legal);
831 setOperationAction(Op, MVT::f64, Legal);
832 setOperationAction(Op, MVT::v2f16, Expand);
833 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
834 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
835 if (getOperationAction(Op, MVT::bf16) == Promote)
836 AddPromotedToType(Op, MVT::bf16, MVT::f32);
837 }
838
839 // max.f16, max.f16x2 and max.NaN are supported on sm_80+.
840 auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
841 bool IsAtLeastSm80 = STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
842 return IsAtLeastSm80 ? Legal : NotSm80Action;
843 };
844 for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
845 setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
846 setOperationAction(Op, MVT::f32, Legal);
847 setOperationAction(Op, MVT::f64, Legal);
848 setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
849 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
850 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
851 if (getOperationAction(Op, MVT::bf16) == Promote)
852 AddPromotedToType(Op, MVT::bf16, MVT::f32);
853 }
854 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
855 setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
856 setFP16OperationAction(Op, MVT::bf16, Legal, Expand);
857 setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
858 setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
859 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
860 }
861
862 // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
863 // No FPOW or FREM in PTX.
864
865 // Now deduce the information based on the above mentioned
866 // actions
868
871}
872
873const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
874
875#define MAKE_CASE(V) \
876 case V: \
877 return #V;
878
879 switch ((NVPTXISD::NodeType)Opcode) {
881 break;
882
1026
1117
1129
1141
1153
1165
1177
1189
1201
1213
1225
1237
1249
1261
1273
1285
1297 }
1298 return nullptr;
1299
1300#undef MAKE_CASE
1301}
1302
1305 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1306 VT.getScalarType() == MVT::i1)
1307 return TypeSplitVector;
1308 if (Isv2x16VT(VT))
1309 return TypeLegal;
1311}
1312
1314 int Enabled, int &ExtraSteps,
1315 bool &UseOneConst,
1316 bool Reciprocal) const {
1319 return SDValue();
1320
1321 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1322 ExtraSteps = 0;
1323
1324 SDLoc DL(Operand);
1325 EVT VT = Operand.getValueType();
1326 bool Ftz = useF32FTZ(DAG.getMachineFunction());
1327
1328 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1329 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
1330 DAG.getConstant(IID, DL, MVT::i32), Operand);
1331 };
1332
1333 // The sqrt and rsqrt refinement processes assume we always start out with an
1334 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1335 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1336 // any refinement, we must return a regular sqrt.
1337 if (Reciprocal || ExtraSteps > 0) {
1338 if (VT == MVT::f32)
1339 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1340 : Intrinsic::nvvm_rsqrt_approx_f);
1341 else if (VT == MVT::f64)
1342 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1343 else
1344 return SDValue();
1345 } else {
1346 if (VT == MVT::f32)
1347 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1348 : Intrinsic::nvvm_sqrt_approx_f);
1349 else {
1350 // There's no sqrt.approx.f64 instruction, so we emit
1351 // reciprocal(rsqrt(x)). This is faster than
1352 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1353 // x * rsqrt(x).)
1354 return DAG.getNode(
1356 DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32),
1357 MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1358 }
1359 }
1360}
1361
1362SDValue
1364 SDLoc dl(Op);
1365 const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op);
1366 auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace());
1367 Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT);
1368 return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
1369}
1370
1371static bool IsTypePassedAsArray(const Type *Ty) {
1372 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
1373 Ty->isHalfTy() || Ty->isBFloatTy();
1374}
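// Note: per getPrototype() and the call-lowering code below, types for which
// this returns true (aggregates, vectors, i128, f16, bf16) are passed and
// returned as ".param .align A .b8 _[N]" byte arrays rather than as scalar
// ".param .b<size>" values.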
1375
1376std::string NVPTXTargetLowering::getPrototype(
1377    const DataLayout &DL, Type *retTy, const ArgListTy &Args,
1378 const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
1379 std::optional<std::pair<unsigned, const APInt &>> VAInfo,
1380 const CallBase &CB, unsigned UniqueCallSite) const {
1381 auto PtrVT = getPointerTy(DL);
1382
1383 bool isABI = (STI.getSmVersion() >= 20);
1384 assert(isABI && "Non-ABI compilation is not supported");
1385 if (!isABI)
1386 return "";
1387
1388 std::string Prototype;
1389 raw_string_ostream O(Prototype);
1390 O << "prototype_" << UniqueCallSite << " : .callprototype ";
1391
1392 if (retTy->getTypeID() == Type::VoidTyID) {
1393 O << "()";
1394 } else {
1395 O << "(";
1396 if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
1397 !IsTypePassedAsArray(retTy)) {
1398 unsigned size = 0;
1399 if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
1400 size = ITy->getBitWidth();
1401 } else {
1402 assert(retTy->isFloatingPointTy() &&
1403 "Floating point type expected here");
1404 size = retTy->getPrimitiveSizeInBits();
1405 }
1406 // PTX ABI requires all scalar return values to be at least 32
1407 // bits in size. fp16 normally uses .b16 as its storage type in
1408 // PTX, so its size must be adjusted here, too.
1410
1411 O << ".param .b" << size << " _";
1412 } else if (isa<PointerType>(retTy)) {
1413 O << ".param .b" << PtrVT.getSizeInBits() << " _";
1414 } else if (IsTypePassedAsArray(retTy)) {
1415 O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
1416 << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
1417 } else {
1418 llvm_unreachable("Unknown return type");
1419 }
1420 O << ") ";
1421 }
1422 O << "_ (";
1423
1424 bool first = true;
1425
1426 const Function *F = CB.getFunction();
1427 unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
1428 for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
1429 Type *Ty = Args[i].Ty;
1430 if (!first) {
1431 O << ", ";
1432 }
1433 first = false;
1434
1435 if (!Outs[OIdx].Flags.isByVal()) {
1436 if (IsTypePassedAsArray(Ty)) {
1437 unsigned ParamAlign = 0;
1438 const CallInst *CallI = cast<CallInst>(&CB);
1439 // +1 because index 0 is reserved for return type alignment
1440 if (!getAlign(*CallI, i + 1, ParamAlign))
1441 ParamAlign = getFunctionParamOptimizedAlign(F, Ty, DL).value();
1442 O << ".param .align " << ParamAlign << " .b8 ";
1443 O << "_";
1444 O << "[" << DL.getTypeAllocSize(Ty) << "]";
1445 // update the index for Outs
1446 SmallVector<EVT, 16> vtparts;
1447 ComputeValueVTs(*this, DL, Ty, vtparts);
1448 if (unsigned len = vtparts.size())
1449 OIdx += len - 1;
1450 continue;
1451 }
1452 // i8 types in IR will be i16 types in SDAG
1453 assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
1454 (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
1455 "type mismatch between callee prototype and arguments");
1456 // scalar type
1457 unsigned sz = 0;
1458 if (isa<IntegerType>(Ty)) {
1459 sz = cast<IntegerType>(Ty)->getBitWidth();
1461 } else if (isa<PointerType>(Ty)) {
1462 sz = PtrVT.getSizeInBits();
1463 } else {
1464 sz = Ty->getPrimitiveSizeInBits();
1465 }
1466 O << ".param .b" << sz << " ";
1467 O << "_";
1468 continue;
1469 }
1470
1471 Type *ETy = Args[i].IndirectType;
1472 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1473 Align ParamByValAlign =
1474 getFunctionByValParamAlign(F, ETy, InitialAlign, DL);
1475
1476 O << ".param .align " << ParamByValAlign.value() << " .b8 ";
1477 O << "_";
1478 O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
1479 }
1480
1481 if (VAInfo)
1482 O << (first ? "" : ",") << " .param .align " << VAInfo->second
1483 << " .b8 _[]\n";
1484 O << ")";
1486 O << " .noreturn";
1487 O << ";";
1488
1489 return Prototype;
1490}
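// Illustrative sketch (assuming a 64-bit target, so pointers are .b64): for an
// indirect call with signature i32 (i32, ptr) and call site 1, the string built
// above looks roughly like
//   prototype_1 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .b64 _);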
1491
1492Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
1493 unsigned Idx,
1494 const DataLayout &DL) const {
1495 if (!CB) {
1496    // CallSite is null; fall back to the ABI type alignment.
1497 return DL.getABITypeAlign(Ty);
1498 }
1499
1500 unsigned Alignment = 0;
1501 const Function *DirectCallee = CB->getCalledFunction();
1502
1503 if (!DirectCallee) {
1504 // We don't have a direct function symbol, but that may be because of
1505 // constant cast instructions in the call.
1506
1507 // With bitcast'd call targets, the instruction will be the call
1508 if (const auto *CI = dyn_cast<CallInst>(CB)) {
1509 // Check if we have call alignment metadata
1510 if (getAlign(*CI, Idx, Alignment))
1511 return Align(Alignment);
1512 }
1513 DirectCallee = getMaybeBitcastedCallee(CB);
1514 }
1515
1516 // Check for function alignment information if we found that the
1517 // ultimate target is a Function
1518 if (DirectCallee) {
1519 if (getAlign(*DirectCallee, Idx, Alignment))
1520 return Align(Alignment);
1521 // If alignment information is not available, fall back to the
1522 // default function param optimized type alignment
1523 return getFunctionParamOptimizedAlign(DirectCallee, Ty, DL);
1524 }
1525
1526 // Call is indirect, fall back to the ABI type alignment
1527 return DL.getABITypeAlign(Ty);
1528}
1529
1530static bool adjustElementType(EVT &ElementType) {
1531 switch (ElementType.getSimpleVT().SimpleTy) {
1532 default:
1533 return false;
1534 case MVT::f16:
1535 case MVT::bf16:
1536 ElementType = MVT::i16;
1537 return true;
1538 case MVT::f32:
1539 case MVT::v2f16:
1540 case MVT::v2bf16:
1541 ElementType = MVT::i32;
1542 return true;
1543 case MVT::f64:
1544 ElementType = MVT::i64;
1545 return true;
1546 }
1547}
1548
1549// Use byte-store when the param address of the argument value is unaligned.
1550// This may happen when the return value is a field of a packed structure.
1551//
1552// This is called in LowerCall() when passing the param values.
1553static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
1554                                        uint64_t Offset, EVT ElementType,
1555 SDValue StVal, SDValue &InGlue,
1556 unsigned ArgID, const SDLoc &dl) {
1557 // Bit logic only works on integer types
1558 if (adjustElementType(ElementType))
1559 StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
1560
1561 // Store each byte
1562 SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1563 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1564 // Shift the byte to the last byte position
1565 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
1566 DAG.getConstant(i * 8, dl, MVT::i32));
1567 SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
1568 DAG.getConstant(Offset + i, dl, MVT::i32),
1569 ShiftVal, InGlue};
1570 // Trunc store only the last byte by using
1571 // st.param.b8
1572 // The register type can be larger than b8.
1573 Chain = DAG.getMemIntrinsicNode(
1574 NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
1576 InGlue = Chain.getValue(1);
1577 }
1578 return Chain;
1579}
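// For example, LowerUnalignedStoreParam turns an f32 argument piece at an
// unaligned param offset into a bitcast to i32 followed by four StoreParam
// nodes of MVT::i8 (st.param.b8), shifting the value right by 0, 8, 16 and 24.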
1580
1581// Use byte-load when the param address of the returned value is unaligned.
1582// This may happen when the returned value is a field of a packed structure.
1583static SDValue
1584LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
1585                           EVT ElementType, SDValue &InGlue,
1586 SmallVectorImpl<SDValue> &TempProxyRegOps,
1587 const SDLoc &dl) {
1588 // Bit logic only works on integer types
1589 EVT MergedType = ElementType;
1590 adjustElementType(MergedType);
1591
1592 // Load each byte and construct the whole value. Initial value to 0
1593 SDValue RetVal = DAG.getConstant(0, dl, MergedType);
1594 // LoadParamMemI8 loads into i16 register only
1595 SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
1596 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1597 SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
1598 DAG.getConstant(Offset + i, dl, MVT::i32),
1599 InGlue};
1600 // This will be selected to LoadParamMemI8
1601 SDValue LdVal =
1602 DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
1603 MVT::i8, MachinePointerInfo(), Align(1));
1604 SDValue TmpLdVal = LdVal.getValue(0);
1605 Chain = LdVal.getValue(1);
1606 InGlue = LdVal.getValue(2);
1607
1608 TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
1609 TmpLdVal.getSimpleValueType(), TmpLdVal);
1610 TempProxyRegOps.push_back(TmpLdVal);
1611
1612 SDValue CMask = DAG.getConstant(255, dl, MergedType);
1613 SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
1614 // Need to extend the i16 register to the whole width.
1615 TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
1616    // Mask off the high bits. Leave only the lower 8 bits.
1617 // Do this because we are using loadparam.b8.
1618 TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
1619 // Shift and merge
1620 TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
1621 RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
1622 }
1623 if (ElementType != MergedType)
1624 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
1625
1626 return RetVal;
1627}
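// Mirror of LowerUnalignedStoreParam: e.g. an f32 return piece at an unaligned
// offset is reassembled from four i8 LoadParam nodes (LoadParamMemI8) via
// zero-extend, mask, shift and OR, and is bitcast back to f32 at the end.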
1628
1629SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1630                                       SmallVectorImpl<SDValue> &InVals) const {
1631
1632 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1634 "Support for variadic functions (unsized array parameter) introduced "
1635 "in PTX ISA version 6.0 and requires target sm_30.");
1636
1637 SelectionDAG &DAG = CLI.DAG;
1638 SDLoc dl = CLI.DL;
1640 SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
1642 SDValue Chain = CLI.Chain;
1643 SDValue Callee = CLI.Callee;
1644 bool &isTailCall = CLI.IsTailCall;
1645 ArgListTy &Args = CLI.getArgs();
1646 Type *RetTy = CLI.RetTy;
1647 const CallBase *CB = CLI.CB;
1648 const DataLayout &DL = DAG.getDataLayout();
1649
1650 bool isABI = (STI.getSmVersion() >= 20);
1651 assert(isABI && "Non-ABI compilation is not supported");
1652 if (!isABI)
1653 return Chain;
1654
1655 // Variadic arguments.
1656 //
1657 // Normally, for each argument, we declare a param scalar or a param
1658 // byte array in the .param space, and store the argument value to that
1659 // param scalar or array starting at offset 0.
1660 //
1661 // In the case of the first variadic argument, we declare a vararg byte array
1662 // with size 0. The exact size of this array isn't known at this point, so
1663 // it'll be patched later. All the variadic arguments will be stored to this
1664 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1665 // initially set to 0, so it can be used for non-variadic arguments (which use
1666 // 0 offset) to simplify the code.
1667 //
1668  // After all variadic arguments are processed, 'VAOffset' holds the size of the
1669 // vararg byte array.
1670
1671 SDValue VADeclareParam; // vararg byte array
1672 unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
1673 unsigned VAOffset = 0; // current offset in the param array
1674
1675 unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
1676 SDValue TempChain = Chain;
1677 Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
1678 SDValue InGlue = Chain.getValue(1);
1679
1680 unsigned ParamCount = 0;
1681 // Args.size() and Outs.size() need not match.
1682 // Outs.size() will be larger
1683 // * if there is an aggregate argument with multiple fields (each field
1684 // showing up separately in Outs)
1685 // * if there is a vector argument with more than typical vector-length
1686 // elements (generally if more than 4) where each vector element is
1687 // individually present in Outs.
1688 // So a different index should be used for indexing into Outs/OutVals.
1689 // See similar issue in LowerFormalArguments.
1690 unsigned OIdx = 0;
1691  // Declare the .param spaces or registers needed to pass values
1692  // to the function.
1693 for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
1694 EVT VT = Outs[OIdx].VT;
1695 Type *Ty = Args[i].Ty;
1696 bool IsVAArg = (i >= CLI.NumFixedArgs);
1697 bool IsByVal = Outs[OIdx].Flags.isByVal();
1698
1699    SmallVector<EVT, 16> VTs;
1700    SmallVector<uint64_t, 16> Offsets;
1701
1702 assert((!IsByVal || Args[i].IndirectType) &&
1703 "byval arg must have indirect type");
1704 Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
1705 ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
1706
1707 Align ArgAlign;
1708 if (IsByVal) {
1709 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1710 // so we don't need to worry whether it's naturally aligned or not.
1711 // See TargetLowering::LowerCallTo().
1712 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1713 ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
1714 InitialAlign, DL);
1715 if (IsVAArg)
1716 VAOffset = alignTo(VAOffset, ArgAlign);
1717 } else {
1718 ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
1719 }
1720
1721 unsigned TypeSize =
1722 (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
1723 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1724
1725 bool NeedAlign; // Does argument declaration specify alignment?
1726 bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
1727 if (IsVAArg) {
1728 if (ParamCount == FirstVAArg) {
1729 SDValue DeclareParamOps[] = {
1730 Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
1731 DAG.getConstant(ParamCount, dl, MVT::i32),
1732 DAG.getConstant(1, dl, MVT::i32), InGlue};
1733 VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
1734 DeclareParamVTs, DeclareParamOps);
1735 }
1736 NeedAlign = PassAsArray;
1737 } else if (PassAsArray) {
1738 // declare .param .align <align> .b8 .param<n>[<size>];
1739 SDValue DeclareParamOps[] = {
1740 Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
1741 DAG.getConstant(ParamCount, dl, MVT::i32),
1742 DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
1743 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
1744 DeclareParamOps);
1745 NeedAlign = true;
1746 } else {
1747 // declare .param .b<size> .param<n>;
1748 if (VT.isInteger() || VT.isFloatingPoint()) {
1749 // PTX ABI requires integral types to be at least 32 bits in
1750 // size. FP16 is loaded/stored using i16, so it's handled
1751 // here as well.
1753 }
1754 SDValue DeclareScalarParamOps[] = {
1755 Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
1756 DAG.getConstant(TypeSize * 8, dl, MVT::i32),
1757 DAG.getConstant(0, dl, MVT::i32), InGlue};
1758 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
1759 DeclareScalarParamOps);
1760 NeedAlign = false;
1761 }
1762 InGlue = Chain.getValue(1);
1763
1764 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1765 // than 32-bits are sign extended or zero extended, depending on
1766 // whether they are signed or unsigned types. This case applies
1767 // only to scalar parameters and not to aggregate values.
1768 bool ExtendIntegerParam =
1769 Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
1770
1771 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
1772 SmallVector<SDValue, 6> StoreOperands;
1773 for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
1774 EVT EltVT = VTs[j];
1775 int CurOffset = Offsets[j];
1776 MaybeAlign PartAlign;
1777 if (NeedAlign)
1778 PartAlign = commonAlignment(ArgAlign, CurOffset);
1779
1780 SDValue StVal = OutVals[OIdx];
1781
1782 MVT PromotedVT;
1783 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
1784 EltVT = EVT(PromotedVT);
1785 }
1786 if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
1788 Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1789 StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
1790 }
1791
1792 if (IsByVal) {
1793 auto PtrVT = getPointerTy(DL);
1794 SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
1795 DAG.getConstant(CurOffset, dl, PtrVT));
1796 StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
1797 PartAlign);
1798 } else if (ExtendIntegerParam) {
1799 assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
1800 // zext/sext to i32
1801 StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
1803 dl, MVT::i32, StVal);
1804 }
1805
1806 if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
1807 // Use 16-bit registers for small stores as it's the
1808 // smallest general purpose register size supported by NVPTX.
1809 StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
1810 }
1811
1812 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1813 // scalar store. In such cases, fall back to byte stores.
1814 if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
1815 PartAlign.value() <
1816 DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
1817        assert(StoreOperands.empty() && "Unfinished preceding store.");
1819 DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
1820 StVal, InGlue, ParamCount, dl);
1821
1822 // LowerUnalignedStoreParam took care of inserting the necessary nodes
1823 // into the SDAG, so just move on to the next element.
1824 if (!IsByVal)
1825 ++OIdx;
1826 continue;
1827 }
1828
1829 // New store.
1830 if (VectorInfo[j] & PVF_FIRST) {
1831 assert(StoreOperands.empty() && "Unfinished preceding store.");
1832 StoreOperands.push_back(Chain);
1833 StoreOperands.push_back(
1834 DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
1835
1836 StoreOperands.push_back(DAG.getConstant(
1837 IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
1838 dl, MVT::i32));
1839 }
1840
1841 // Record the value to store.
1842 StoreOperands.push_back(StVal);
1843
1844 if (VectorInfo[j] & PVF_LAST) {
1845 unsigned NumElts = StoreOperands.size() - 3;
1847 switch (NumElts) {
1848 case 1:
1850 break;
1851 case 2:
1853 break;
1854 case 4:
1856 break;
1857 default:
1858 llvm_unreachable("Invalid vector info.");
1859 }
1860
1861 StoreOperands.push_back(InGlue);
1862
1863 // Adjust type of the store op if we've extended the scalar
1864 // return value.
1865 EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1866
1867 Chain = DAG.getMemIntrinsicNode(
1868 Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
1869 TheStoreType, MachinePointerInfo(), PartAlign,
1871 InGlue = Chain.getValue(1);
1872
1873 // Cleanup.
1874 StoreOperands.clear();
1875
1876 // TODO: We may need to support vector types that can be passed
1877 // as scalars in variadic arguments.
1878 if (!IsByVal && IsVAArg) {
1879 assert(NumElts == 1 &&
1880 "Vectorization is expected to be disabled for variadics.");
1881 VAOffset += DL.getTypeAllocSize(
1882 TheStoreType.getTypeForEVT(*DAG.getContext()));
1883 }
1884 }
1885 if (!IsByVal)
1886 ++OIdx;
1887 }
1888 assert(StoreOperands.empty() && "Unfinished parameter store.");
1889 if (!IsByVal && VTs.size() > 0)
1890 --OIdx;
1891 ++ParamCount;
1892 if (IsByVal && IsVAArg)
1893 VAOffset += TypeSize;
1894 }
1895
1896 GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
1897 MaybeAlign retAlignment = std::nullopt;
1898
1899 // Handle Result
1900 if (Ins.size() > 0) {
1901 SmallVector<EVT, 16> resvtparts;
1902 ComputeValueVTs(*this, DL, RetTy, resvtparts);
1903
1904 // Declare
1905 // .param .align N .b8 retval0[<size-in-bytes>], or
1906 // .param .b<size-in-bits> retval0
1907 unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
1908 if (!IsTypePassedAsArray(RetTy)) {
1909 resultsz = promoteScalarArgumentSize(resultsz);
1910 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1911 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
1912 DAG.getConstant(resultsz, dl, MVT::i32),
1913 DAG.getConstant(0, dl, MVT::i32), InGlue };
1914 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
1915 DeclareRetOps);
1916 InGlue = Chain.getValue(1);
1917 } else {
1918 retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
1919 assert(retAlignment && "retAlignment is guaranteed to be set");
1920 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1921 SDValue DeclareRetOps[] = {
1922 Chain, DAG.getConstant(retAlignment->value(), dl, MVT::i32),
1923 DAG.getConstant(resultsz / 8, dl, MVT::i32),
1924 DAG.getConstant(0, dl, MVT::i32), InGlue};
1925 Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
1926 DeclareRetOps);
1927 InGlue = Chain.getValue(1);
1928 }
1929 }
1930
1931 bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1932 // Set the size of the vararg param byte array if the callee is a variadic
1933 // function and the variadic part is not empty.
1934 if (HasVAArgs) {
1935 SDValue DeclareParamOps[] = {
1936 VADeclareParam.getOperand(0), VADeclareParam.getOperand(1),
1937 VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32),
1938 VADeclareParam.getOperand(4)};
1939 DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
1940 VADeclareParam->getVTList(), DeclareParamOps);
1941 }
1942
1943 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1944 // between them we must rely on the call site value which is valid for
1945 // indirect calls but is always null for libcalls.
1946 bool isIndirectCall = !Func && CB;
1947
1948 if (isa<ExternalSymbolSDNode>(Callee)) {
1949 Function* CalleeFunc = nullptr;
1950
1951 // Try to find the callee in the current module.
1952 Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc);
1953 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1954
1955 // Set the "libcall callee" attribute to indicate that the function
1956 // must always have a declaration.
1957 CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
1958 }
1959
1960 if (isIndirectCall) {
1961 // This is indirect function call case : PTX requires a prototype of the
1962 // form
1963 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1964    // to be emitted, and the label has to be used as the last arg of the call
1965 // instruction.
1966 // The prototype is embedded in a string and put as the operand for a
1967 // CallPrototype SDNode which will print out to the value of the string.
1968 SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1969 std::string Proto = getPrototype(
1970 DL, RetTy, Args, Outs, retAlignment,
1971 HasVAArgs
1972 ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
1973 CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
1974 : std::nullopt,
1975 *CB, UniqueCallSite);
1976 const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
1977 SDValue ProtoOps[] = {
1978 Chain,
1979 DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
1980 InGlue,
1981 };
1982 Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps);
1983 InGlue = Chain.getValue(1);
1984 }
1985 // Op to just print "call"
1986 SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1987 SDValue PrintCallOps[] = {
1988 Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InGlue
1989 };
1990 // We model convergent calls as separate opcodes.
1992 if (CLI.IsConvergent)
1995 Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
1996 InGlue = Chain.getValue(1);
1997
1998 // Ops to print out the function name
1999 SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2000 SDValue CallVoidOps[] = { Chain, Callee, InGlue };
2001 Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps);
2002 InGlue = Chain.getValue(1);
2003
2004 // Ops to print out the param list
2005 SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2006 SDValue CallArgBeginOps[] = { Chain, InGlue };
2007 Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
2008 CallArgBeginOps);
2009 InGlue = Chain.getValue(1);
2010
2011 for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e;
2012 ++i) {
2013 unsigned opcode;
2014 if (i == (e - 1))
2015 opcode = NVPTXISD::LastCallArg;
2016 else
2017 opcode = NVPTXISD::CallArg;
2018 SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2019 SDValue CallArgOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
2020 DAG.getConstant(i, dl, MVT::i32), InGlue };
2021 Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps);
2022 InGlue = Chain.getValue(1);
2023 }
2024 SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2025 SDValue CallArgEndOps[] = { Chain,
2026 DAG.getConstant(isIndirectCall ? 0 : 1, dl, MVT::i32),
2027 InGlue };
2028 Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps);
2029 InGlue = Chain.getValue(1);
2030
2031 if (isIndirectCall) {
2032 SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
2033 SDValue PrototypeOps[] = {
2034 Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InGlue};
2035 Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
2036 InGlue = Chain.getValue(1);
2037 }
2038
2039 SmallVector<SDValue, 16> ProxyRegOps;
2040 SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
2041 // An item of the vector is filled if the element does not need a ProxyReg
2042 // operation on it and should be added to InVals as is. ProxyRegOps and
2043 // ProxyRegTruncates contain empty/none items at the same index.
2044  SmallVector<SDValue, 16> RetElts;
2045  // Temporary ProxyReg operations are inserted in `LowerUnalignedLoadRetParam()`
2046  // to use the values of `LoadParam`s; they are replaced later, when
2047  // `CALLSEQ_END` is added.
2048 SmallVector<SDValue, 16> TempProxyRegOps;
2049
2050 // Generate loads from param memory/moves from registers for result
2051 if (Ins.size() > 0) {
2052    SmallVector<EVT, 16> VTs;
2053    SmallVector<uint64_t, 16> Offsets;
2054    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
2055 assert(VTs.size() == Ins.size() && "Bad value decomposition");
2056
2057 Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
2058 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
2059
2060 SmallVector<EVT, 6> LoadVTs;
2061 int VecIdx = -1; // Index of the first element of the vector.
2062
2063 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
2064 // 32-bits are sign extended or zero extended, depending on whether
2065 // they are signed or unsigned types.
2066 bool ExtendIntegerRetVal =
2067 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
2068
2069 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
2070 bool needTruncate = false;
2071 EVT TheLoadType = VTs[i];
2072 EVT EltType = Ins[i].VT;
2073 Align EltAlign = commonAlignment(RetAlign, Offsets[i]);
2074 MVT PromotedVT;
2075
2076 if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) {
2077 TheLoadType = EVT(PromotedVT);
2078 EltType = EVT(PromotedVT);
2079 needTruncate = true;
2080 }
2081
2082 if (ExtendIntegerRetVal) {
2083 TheLoadType = MVT::i32;
2084 EltType = MVT::i32;
2085 needTruncate = true;
2086 } else if (TheLoadType.getSizeInBits() < 16) {
2087 if (VTs[i].isInteger())
2088 needTruncate = true;
2089 EltType = MVT::i16;
2090 }
2091
2092 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
2093 // scalar load. In such cases, fall back to byte loads.
2094 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
2095 EltAlign < DL.getABITypeAlign(
2096 TheLoadType.getTypeForEVT(*DAG.getContext()))) {
2097 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
2099 DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
2100 ProxyRegOps.push_back(SDValue());
2101 ProxyRegTruncates.push_back(std::optional<MVT>());
2102 RetElts.resize(i);
2103 RetElts.push_back(Ret);
2104
2105 continue;
2106 }
2107
2108 // Record index of the very first element of the vector.
2109 if (VectorInfo[i] & PVF_FIRST) {
2110 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
2111 VecIdx = i;
2112 }
2113
2114 LoadVTs.push_back(EltType);
2115
2116 if (VectorInfo[i] & PVF_LAST) {
2117 unsigned NumElts = LoadVTs.size();
2118 LoadVTs.push_back(MVT::Other);
2119 LoadVTs.push_back(MVT::Glue);
2121 switch (NumElts) {
2122 case 1:
2124 break;
2125 case 2:
2127 break;
2128 case 4:
2130 break;
2131 default:
2132 llvm_unreachable("Invalid vector info.");
2133 }
2134
2135 SDValue LoadOperands[] = {
2136 Chain, DAG.getConstant(1, dl, MVT::i32),
2137 DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InGlue};
2138 SDValue RetVal = DAG.getMemIntrinsicNode(
2139 Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
2140 MachinePointerInfo(), EltAlign,
2142
2143 for (unsigned j = 0; j < NumElts; ++j) {
2144 ProxyRegOps.push_back(RetVal.getValue(j));
2145
2146 if (needTruncate)
2147 ProxyRegTruncates.push_back(std::optional<MVT>(Ins[VecIdx + j].VT));
2148 else
2149 ProxyRegTruncates.push_back(std::optional<MVT>());
2150 }
2151
2152 Chain = RetVal.getValue(NumElts);
2153 InGlue = RetVal.getValue(NumElts + 1);
2154
2155 // Cleanup
2156 VecIdx = -1;
2157 LoadVTs.clear();
2158 }
2159 }
2160 }
2161
2162 Chain =
2163 DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl);
2164 InGlue = Chain.getValue(1);
2165
2166 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
2167 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
2168 // dangling.
2169 for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
2170 if (i < RetElts.size() && RetElts[i]) {
2171 InVals.push_back(RetElts[i]);
2172 continue;
2173 }
2174
2175 SDValue Ret = DAG.getNode(
2177 DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
2178 { Chain, ProxyRegOps[i], InGlue }
2179 );
2180
2181 Chain = Ret.getValue(1);
2182 InGlue = Ret.getValue(2);
2183
2184 if (ProxyRegTruncates[i]) {
2185 Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[i], Ret);
2186 }
2187
2188 InVals.push_back(Ret);
2189 }
2190
2191 for (SDValue &T : TempProxyRegOps) {
2192 SDValue Repl = DAG.getNode(
2194 DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
2195 {Chain, T.getOperand(0), InGlue});
2196 DAG.ReplaceAllUsesWith(T, Repl);
2197 DAG.RemoveDeadNode(T.getNode());
2198
2199 Chain = Repl.getValue(1);
2200 InGlue = Repl.getValue(2);
2201 }
2202
2203 // set isTailCall to false for now, until we figure out how to express
2204 // tail call optimization in PTX
2205 isTailCall = false;
2206 return Chain;
2207}
2208
2209SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
2210                                                     SelectionDAG &DAG) const {
2211
2212 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2213 const Function &Fn = DAG.getMachineFunction().getFunction();
2214
2215 DiagnosticInfoUnsupported NoDynamicAlloca(
2216 Fn,
2217      "Support for dynamic alloca was introduced in PTX ISA version 7.3 and "
2218      "requires target sm_52.",
2219 SDLoc(Op).getDebugLoc());
2220 DAG.getContext()->diagnose(NoDynamicAlloca);
2221 auto Ops = {DAG.getConstant(0, SDLoc(), Op.getValueType()),
2222 Op.getOperand(0)};
2223 return DAG.getMergeValues(Ops, SDLoc());
2224 }
2225
2226 SDValue Chain = Op.getOperand(0);
2227 SDValue Size = Op.getOperand(1);
2228 uint64_t Align = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue();
2229 SDLoc DL(Op.getNode());
2230
2231 // The size for ptx alloca instruction is 64-bit for m64 and 32-bit for m32.
2232 if (nvTM->is64Bit())
2233 Size = DAG.getZExtOrTrunc(Size, DL, MVT::i64);
2234 else
2235 Size = DAG.getZExtOrTrunc(Size, DL, MVT::i32);
2236
2237 SDValue AllocOps[] = {Chain, Size,
2238 DAG.getTargetConstant(Align, DL, MVT::i32)};
2239  SDValue Alloca = DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL,
2240                               nvTM->is64Bit() ? MVT::i64 : MVT::i32, AllocOps);
2241
2242 SDValue MergeOps[] = {Alloca, Chain};
2243 return DAG.getMergeValues(MergeOps, DL);
2244}
2245
2246// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
2247// (see LegalizeDAG.cpp). This is slow and uses local memory.
2248// We use extract/insert/build-vector nodes, just as LegalizeOp() did in LLVM 2.5.
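// For example, concatenating two v2f32 operands becomes four
// extract_vector_elt nodes feeding a single v4f32 build_vector.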
2249SDValue
2250NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2251 SDNode *Node = Op.getNode();
2252 SDLoc dl(Node);
2253  SmallVector<SDValue, 8> Ops;
2254  unsigned NumOperands = Node->getNumOperands();
2255 for (unsigned i = 0; i < NumOperands; ++i) {
2256 SDValue SubOp = Node->getOperand(i);
2257 EVT VVT = SubOp.getNode()->getValueType(0);
2258 EVT EltVT = VVT.getVectorElementType();
2259 unsigned NumSubElem = VVT.getVectorNumElements();
2260 for (unsigned j = 0; j < NumSubElem; ++j) {
2261 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
2262 DAG.getIntPtrConstant(j, dl)));
2263 }
2264 }
2265 return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
2266}
2267
2268// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
2269// would get lowered as two constant loads and vector-packing move.
2270// Instead we want just a constant move:
2271// mov.b32 %r2, 0x40003C00
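// For example (values assumed for illustration): a v2f16 build_vector of
// <1.0, 2.0> packs to 0x40003C00 -- 0x3C00 is 1.0 in f16 (low half) and
// 0x4000 is 2.0 (high half) -- which is exactly the immediate shown above.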
2272SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2273 SelectionDAG &DAG) const {
2274 EVT VT = Op->getValueType(0);
2275 if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
2276 return Op;
2277
2278 SDLoc DL(Op);
2279
2280 if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
2281 return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
2282 isa<ConstantFPSDNode>(Operand);
2283 })) {
2284 // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
2285 // to optimize calculation of constant parts.
2286 if (VT == MVT::v4i8) {
2287 SDValue C8 = DAG.getConstant(8, DL, MVT::i32);
2288 SDValue E01 = DAG.getNode(
2289 NVPTXISD::BFI, DL, MVT::i32,
2290 DAG.getAnyExtOrTrunc(Op->getOperand(1), DL, MVT::i32),
2291 DAG.getAnyExtOrTrunc(Op->getOperand(0), DL, MVT::i32), C8, C8);
2292 SDValue E012 =
2293 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2294 DAG.getAnyExtOrTrunc(Op->getOperand(2), DL, MVT::i32),
2295 E01, DAG.getConstant(16, DL, MVT::i32), C8);
2296 SDValue E0123 =
2297 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2298 DAG.getAnyExtOrTrunc(Op->getOperand(3), DL, MVT::i32),
2299 E012, DAG.getConstant(24, DL, MVT::i32), C8);
2300 return DAG.getNode(ISD::BITCAST, DL, VT, E0123);
2301 }
2302 return Op;
2303 }
2304
2305  // Get the Nth operand's value as an APInt(32). Undef values are treated as 0.
2306 auto GetOperand = [](SDValue Op, int N) -> APInt {
2307 const SDValue &Operand = Op->getOperand(N);
2308 EVT VT = Op->getValueType(0);
2309 if (Operand->isUndef())
2310 return APInt(32, 0);
2311 APInt Value;
2312 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2313 Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
2314 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2315 Value = Operand->getAsAPIntVal();
2316 else
2317 llvm_unreachable("Unsupported type");
2318    // i8 values are carried around as i16, so we need to zero out the upper
2319    // bits so they do not get in the way of combining individual byte values.
2320 if (VT == MVT::v4i8)
2321 Value = Value.trunc(8);
2322 return Value.zext(32);
2323 };
2324 APInt Value;
2325 if (Isv2x16VT(VT)) {
2326 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
2327 } else if (VT == MVT::v4i8) {
2328 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
2329 GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
2330 } else {
2331 llvm_unreachable("Unsupported type");
2332 }
2333 SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
2334 return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
2335}
2336
2337SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2338 SelectionDAG &DAG) const {
2339 SDValue Index = Op->getOperand(1);
2340 SDValue Vector = Op->getOperand(0);
2341 SDLoc DL(Op);
2342 EVT VectorVT = Vector.getValueType();
2343
2344 if (VectorVT == MVT::v4i8) {
2345 SDValue BFE =
2346 DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
2347 {Vector,
2348 DAG.getNode(ISD::MUL, DL, MVT::i32,
2349 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2350 DAG.getConstant(8, DL, MVT::i32)),
2351 DAG.getConstant(8, DL, MVT::i32)});
2352 return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
2353 }
2354
2355 // Constant index will be matched by tablegen.
2356 if (isa<ConstantSDNode>(Index.getNode()))
2357 return Op;
2358
2359 // Extract individual elements and select one of them.
2360 assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
2361 EVT EltVT = VectorVT.getVectorElementType();
2362
2363 SDLoc dl(Op.getNode());
2364 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2365 DAG.getIntPtrConstant(0, dl));
2366 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2367 DAG.getIntPtrConstant(1, dl));
2368 return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
2369                         ISD::CondCode::SETEQ);
2370}
2371
2372SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2373 SelectionDAG &DAG) const {
2374 SDValue Vector = Op->getOperand(0);
2375 EVT VectorVT = Vector.getValueType();
2376
2377 if (VectorVT != MVT::v4i8)
2378 return Op;
2379 SDLoc DL(Op);
2380 SDValue Value = Op->getOperand(1);
2381 if (Value->isUndef())
2382 return Vector;
2383
2384 SDValue Index = Op->getOperand(2);
2385
2386 SDValue BFI =
2387 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2388 {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
2389 DAG.getNode(ISD::MUL, DL, MVT::i32,
2390 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2391 DAG.getConstant(8, DL, MVT::i32)),
2392 DAG.getConstant(8, DL, MVT::i32)});
2393 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
2394}
2395
2396SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2397 SelectionDAG &DAG) const {
2398 SDValue V1 = Op.getOperand(0);
2399 EVT VectorVT = V1.getValueType();
2400 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2401 return Op;
2402
2403 // Lower shuffle to PRMT instruction.
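  // Each mask element becomes one 4-bit byte selector: values 0-3 pick bytes
  // from V1 and 4-7 pick bytes from V2. For example (mask assumed for
  // illustration), a mask of <4, 0, 1, 2> produces the selector 0x2104.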
2404 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
2405 SDValue V2 = Op.getOperand(1);
2406 uint32_t Selector = 0;
2407 for (auto I : llvm::enumerate(SVN->getMask())) {
2408 if (I.value() != -1) // -1 is a placeholder for undef.
2409 Selector |= (I.value() << (I.index() * 4));
2410 }
2411
2412 SDLoc DL(Op);
2413 return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
2414 DAG.getConstant(Selector, DL, MVT::i32),
2415 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
2416}
2417/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2418/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
2419/// amount, or
2420/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
2421/// amount.
2422SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2423 SelectionDAG &DAG) const {
2424 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2425 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2426
2427 EVT VT = Op.getValueType();
2428 unsigned VTBits = VT.getSizeInBits();
2429 SDLoc dl(Op);
2430 SDValue ShOpLo = Op.getOperand(0);
2431 SDValue ShOpHi = Op.getOperand(1);
2432 SDValue ShAmt = Op.getOperand(2);
2433 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2434
2435 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2436 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2437 // {dHi, dLo} = {aHi, aLo} >> Amt
2438 // dHi = aHi >> Amt
2439 // dLo = shf.r.clamp aLo, aHi, Amt
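    // shf.r.clamp treats {aHi, aLo} as a single 64-bit value, shifts it right
    // by Amt (clamped to 32), and returns the low 32 bits, so dLo receives the
    // bits shifted in from aHi.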
2440
2441 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2442 SDValue Lo = DAG.getNode(NVPTXISD::FUN_SHFR_CLAMP, dl, VT, ShOpLo, ShOpHi,
2443 ShAmt);
2444
2445 SDValue Ops[2] = { Lo, Hi };
2446 return DAG.getMergeValues(Ops, dl);
2447 }
2448 else {
2449 // {dHi, dLo} = {aHi, aLo} >> Amt
2450 // - if (Amt>=size) then
2451 // dLo = aHi >> (Amt-size)
2452 // dHi = aHi >> Amt (this is either all 0 or all 1)
2453 // else
2454 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2455 // dHi = aHi >> Amt
2456
2457 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2458 DAG.getConstant(VTBits, dl, MVT::i32),
2459 ShAmt);
2460 SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
2461 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2462 DAG.getConstant(VTBits, dl, MVT::i32));
2463 SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
2464 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2465 SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
2466
2467 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2468 DAG.getConstant(VTBits, dl, MVT::i32),
2469 ISD::SETGE);
2470 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2471 SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2472
2473 SDValue Ops[2] = { Lo, Hi };
2474 return DAG.getMergeValues(Ops, dl);
2475 }
2476}
2477
2478/// LowerShiftLeftParts - Lower SHL_PARTS, which
2479/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
2480/// amount, or
2481/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
2482/// amount.
2483SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2484 SelectionDAG &DAG) const {
2485 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2486 assert(Op.getOpcode() == ISD::SHL_PARTS);
2487
2488 EVT VT = Op.getValueType();
2489 unsigned VTBits = VT.getSizeInBits();
2490 SDLoc dl(Op);
2491 SDValue ShOpLo = Op.getOperand(0);
2492 SDValue ShOpHi = Op.getOperand(1);
2493 SDValue ShAmt = Op.getOperand(2);
2494
2495 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2496 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2497 // {dHi, dLo} = {aHi, aLo} << Amt
2498 // dHi = shf.l.clamp aLo, aHi, Amt
2499 // dLo = aLo << Amt
2500
2501 SDValue Hi = DAG.getNode(NVPTXISD::FUN_SHFL_CLAMP, dl, VT, ShOpLo, ShOpHi,
2502 ShAmt);
2503 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2504
2505 SDValue Ops[2] = { Lo, Hi };
2506 return DAG.getMergeValues(Ops, dl);
2507 }
2508 else {
2509 // {dHi, dLo} = {aHi, aLo} << Amt
2510 // - if (Amt>=size) then
2511 // dLo = aLo << Amt (all 0)
2512 // dLo = aLo << (Amt-size)
2513 // else
2514 // dLo = aLo << Amt
2515 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2516
2517 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2518 DAG.getConstant(VTBits, dl, MVT::i32),
2519 ShAmt);
2520 SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
2521 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2522 DAG.getConstant(VTBits, dl, MVT::i32));
2523 SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
2524 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2525 SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
2526
2527 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2528 DAG.getConstant(VTBits, dl, MVT::i32),
2529 ISD::SETGE);
2530 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2531 SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2532
2533 SDValue Ops[2] = { Lo, Hi };
2534 return DAG.getMergeValues(Ops, dl);
2535 }
2536}
2537
2538SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2539 EVT VT = Op.getValueType();
2540
2541 if (VT == MVT::f32)
2542 return LowerFROUND32(Op, DAG);
2543
2544 if (VT == MVT::f64)
2545 return LowerFROUND64(Op, DAG);
2546
2547 llvm_unreachable("unhandled type");
2548}
2549
2550// This is the rounding method used in CUDA libdevice, in C-like code:
2551// float roundf(float A)
2552// {
2553// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2554// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2555// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2556// }
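// Any f32 with a magnitude of at least 0x1.0p23 is already an integer (the
// significand has only 24 bits), so such values are returned unchanged, and
// the abs(A) < 0.5 case truncates toward zero, yielding a correctly signed
// zero.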
2557SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2558 SelectionDAG &DAG) const {
2559 SDLoc SL(Op);
2560 SDValue A = Op.getOperand(0);
2561 EVT VT = Op.getValueType();
2562
2563 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2564
2565 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2566 SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
2567 const int SignBitMask = 0x80000000;
2568 SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
2569 DAG.getConstant(SignBitMask, SL, MVT::i32));
2570 const int PointFiveInBits = 0x3F000000;
2571 SDValue PointFiveWithSignRaw =
2572 DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
2573 DAG.getConstant(PointFiveInBits, SL, MVT::i32));
2574 SDValue PointFiveWithSign =
2575 DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
2576 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
2577 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2578
2579 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2580 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2581 SDValue IsLarge =
2582 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
2583 ISD::SETOGT);
2584 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2585
2586 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2587  SDValue IsSmall = DAG.getSetCC(SL, SetCCVT, AbsA,
2588                                 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2589 SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
2590 return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
2591}
2592
2593// The implementation of round(double) is similar to that of round(float) in
2594// that they both separate the value range into three regions and use a method
2595// specific to the region to round the values. However, round(double) first
2596// calculates the round of the absolute value and then adds the sign back while
2597// round(float) directly rounds the value with sign.
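// For example (input assumed for illustration): round(-2.5) takes |A| = 2.5,
// adds 0.5 to get 3.0, truncates to 3.0, and FCOPYSIGN then restores the sign
// to produce -3.0.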
2598SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2599 SelectionDAG &DAG) const {
2600 SDLoc SL(Op);
2601 SDValue A = Op.getOperand(0);
2602 EVT VT = Op.getValueType();
2603
2604 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2605
2606 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2607 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
2608 DAG.getConstantFP(0.5, SL, VT));
2609 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2610
2611 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2612 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2613  SDValue IsSmall = DAG.getSetCC(SL, SetCCVT, AbsA,
2614                                 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2615 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
2616 DAG.getConstantFP(0, SL, VT),
2617 RoundedA);
2618
2619 // Add sign to rounded_A
2620 RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
2621 DAG.getNode(ISD::FTRUNC, SL, VT, A);
2622
2623 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2624 SDValue IsLarge =
2625 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
2626 ISD::SETOGT);
2627 return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2628}
2629
2630SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2631 SelectionDAG &DAG) const {
2632 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2633
2634 if (Op.getValueType() == MVT::bf16) {
2635 SDLoc Loc(Op);
2636 return DAG.getNode(
2637 ISD::FP_ROUND, Loc, MVT::bf16,
2638 DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
2639 DAG.getIntPtrConstant(0, Loc));
2640 }
2641
2642 // Everything else is considered legal.
2643 return Op;
2644}
2645
2646SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2647 SelectionDAG &DAG) const {
2648 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2649
2650 if (Op.getOperand(0).getValueType() == MVT::bf16) {
2651 SDLoc Loc(Op);
2652 return DAG.getNode(
2653 Op.getOpcode(), Loc, Op.getValueType(),
2654 DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
2655 }
2656
2657 // Everything else is considered legal.
2658 return Op;
2659}
2660
2661SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2662 SelectionDAG &DAG) const {
2663 EVT NarrowVT = Op.getValueType();
2664 SDValue Wide = Op.getOperand(0);
2665 EVT WideVT = Wide.getValueType();
2666 if (NarrowVT.getScalarType() == MVT::bf16) {
2667 const TargetLowering *TLI = STI.getTargetLowering();
2668 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2669 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2670 }
2671 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2672 // This combination was the first to support f32 -> bf16.
2673 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2674 if (WideVT.getScalarType() == MVT::f32) {
2675 return Op;
2676 }
2677 if (WideVT.getScalarType() == MVT::f64) {
2678 SDLoc Loc(Op);
2679 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2680 // the hardware f32 -> bf16 instruction.
2681        SDValue rod = TLI->expandRoundInexactToOdd(
2682            WideVT.isVector() ? WideVT.changeVectorElementType(MVT::f32)
2683 : MVT::f32,
2684 Wide, Loc, DAG);
2685 return DAG.getFPExtendOrRound(rod, Loc, NarrowVT);
2686 }
2687 }
2688 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2689 }
2690 }
2691
2692 // Everything else is considered legal.
2693 return Op;
2694}
2695
2696SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2697 SelectionDAG &DAG) const {
2698 SDValue Narrow = Op.getOperand(0);
2699 EVT NarrowVT = Narrow.getValueType();
2700 EVT WideVT = Op.getValueType();
2701 if (NarrowVT.getScalarType() == MVT::bf16) {
2702 if (WideVT.getScalarType() == MVT::f32 &&
2703 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2704 SDLoc Loc(Op);
2705 return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow);
2706 }
2707 if (WideVT.getScalarType() == MVT::f64 &&
2708 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2709 EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f32)
2710 : MVT::f32;
2711 SDLoc Loc(Op);
2712 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2713 Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow);
2714 } else {
2715 Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow);
2716 }
2717 return DAG.getNode(ISD::FP_EXTEND, Loc, WideVT, Op);
2718 }
2719 }
2720
2721 // Everything else is considered legal.
2722 return Op;
2723}
2724
2725static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2726  SDLoc DL(Op);
2727 if (Op.getValueType() != MVT::v2i16)
2728 return Op;
2729 EVT EltVT = Op.getValueType().getVectorElementType();
2730 SmallVector<SDValue> VecElements;
2731 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2732 SmallVector<SDValue> ScalarArgs;
2733 llvm::transform(Op->ops(), std::back_inserter(ScalarArgs),
2734 [&](const SDUse &O) {
2735 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
2736 O.get(), DAG.getIntPtrConstant(I, DL));
2737 });
2738 VecElements.push_back(DAG.getNode(Op.getOpcode(), DL, EltVT, ScalarArgs));
2739 }
2740 SDValue V =
2741 DAG.getNode(ISD::BUILD_VECTOR, DL, Op.getValueType(), VecElements);
2742 return V;
2743}
2744
2745SDValue
2746NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2747  switch (Op.getOpcode()) {
2748 case ISD::RETURNADDR:
2749 return SDValue();
2750 case ISD::FRAMEADDR:
2751 return SDValue();
2752 case ISD::GlobalAddress:
2753 return LowerGlobalAddress(Op, DAG);
2754  case ISD::INTRINSIC_W_CHAIN:
2755    return Op;
2756  case ISD::BUILD_VECTOR:
2757    return LowerBUILD_VECTOR(Op, DAG);
2758  case ISD::EXTRACT_SUBVECTOR:
2759    return Op;
2760  case ISD::EXTRACT_VECTOR_ELT:
2761    return LowerEXTRACT_VECTOR_ELT(Op, DAG);
2762  case ISD::INSERT_VECTOR_ELT:
2763    return LowerINSERT_VECTOR_ELT(Op, DAG);
2764  case ISD::VECTOR_SHUFFLE:
2765    return LowerVECTOR_SHUFFLE(Op, DAG);
2766  case ISD::CONCAT_VECTORS:
2767    return LowerCONCAT_VECTORS(Op, DAG);
2768 case ISD::STORE:
2769 return LowerSTORE(Op, DAG);
2770 case ISD::LOAD:
2771 return LowerLOAD(Op, DAG);
2772 case ISD::SHL_PARTS:
2773 return LowerShiftLeftParts(Op, DAG);
2774 case ISD::SRA_PARTS:
2775 case ISD::SRL_PARTS:
2776 return LowerShiftRightParts(Op, DAG);
2777 case ISD::SELECT:
2778 return LowerSelect(Op, DAG);
2779 case ISD::FROUND:
2780 return LowerFROUND(Op, DAG);
2781 case ISD::SINT_TO_FP:
2782 case ISD::UINT_TO_FP:
2783 return LowerINT_TO_FP(Op, DAG);
2784 case ISD::FP_TO_SINT:
2785 case ISD::FP_TO_UINT:
2786 return LowerFP_TO_INT(Op, DAG);
2787 case ISD::FP_ROUND:
2788 return LowerFP_ROUND(Op, DAG);
2789 case ISD::FP_EXTEND:
2790 return LowerFP_EXTEND(Op, DAG);
2791 case ISD::VAARG:
2792 return LowerVAARG(Op, DAG);
2793 case ISD::VASTART:
2794 return LowerVASTART(Op, DAG);
2795 case ISD::ABS:
2796 case ISD::SMIN:
2797 case ISD::SMAX:
2798 case ISD::UMIN:
2799 case ISD::UMAX:
2800 case ISD::ADD:
2801 case ISD::SUB:
2802 case ISD::MUL:
2803 case ISD::SHL:
2804 case ISD::SREM:
2805 case ISD::UREM:
2806 return LowerVectorArith(Op, DAG);
2807  case ISD::DYNAMIC_STACKALLOC:
2808    return LowerDYNAMIC_STACKALLOC(Op, DAG);
2809 default:
2810 llvm_unreachable("Custom lowering not defined for operation");
2811 }
2812}
2813
2814// This function is almost a copy of SelectionDAG::expandVAArg().
2815// The only difference is that this one produces loads from the local address space.
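// For example (types assumed for illustration): reading an i32 vararg loads
// the current ap pointer, aligns it up if the argument demands it, loads the
// i32 from that address, and writes ap advanced by four bytes back.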
2816SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
2817 const TargetLowering *TLI = STI.getTargetLowering();
2818 SDLoc DL(Op);
2819
2820 SDNode *Node = Op.getNode();
2821 const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
2822 EVT VT = Node->getValueType(0);
2823 auto *Ty = VT.getTypeForEVT(*DAG.getContext());
2824 SDValue Tmp1 = Node->getOperand(0);
2825 SDValue Tmp2 = Node->getOperand(1);
2826 const MaybeAlign MA(Node->getConstantOperandVal(3));
2827
2828 SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
2829 Tmp1, Tmp2, MachinePointerInfo(V));
2830 SDValue VAList = VAListLoad;
2831
2832 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
2833 VAList = DAG.getNode(
2834 ISD::ADD, DL, VAList.getValueType(), VAList,
2835 DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
2836
2837 VAList = DAG.getNode(
2838 ISD::AND, DL, VAList.getValueType(), VAList,
2839 DAG.getConstant(-(int64_t)MA->value(), DL, VAList.getValueType()));
2840 }
2841
2842 // Increment the pointer, VAList, to the next vaarg
2843 Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
2844                     DAG.getConstant(DAG.getDataLayout().getTypeAllocSize(Ty),
2845                                     DL, VAList.getValueType()));
2846
2847 // Store the incremented VAList to the legalized pointer
2848 Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
2849                      MachinePointerInfo(V));
2850
2851 const Value *SrcV =
2852      cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
2853
2854 // Load the actual argument out of the pointer VAList
2855 return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
2856}
2857
2858SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
2859 const TargetLowering *TLI = STI.getTargetLowering();
2860 SDLoc DL(Op);
2861 EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
2862
2863 // Store the address of unsized array <function>_vararg[] in the ap object.
2864 SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
2865 SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
2866
2867 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
2868 return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
2869 MachinePointerInfo(SV));
2870}
2871
2872SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
2873 SDValue Op0 = Op->getOperand(0);
2874 SDValue Op1 = Op->getOperand(1);
2875 SDValue Op2 = Op->getOperand(2);
2876 SDLoc DL(Op.getNode());
2877
2878 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
2879
2880 Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
2881 Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
2882 SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2);
2883 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
2884
2885 return Trunc;
2886}
2887
2888SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
2889 if (Op.getValueType() == MVT::i1)
2890 return LowerLOADi1(Op, DAG);
2891
2892 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
2893 // unaligned loads and have to handle it here.
2894 EVT VT = Op.getValueType();
2895 if (Isv2x16VT(VT) || VT == MVT::v4i8) {
2896 LoadSDNode *Load = cast<LoadSDNode>(Op);
2897 EVT MemVT = Load->getMemoryVT();
2898    if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
2899                                        MemVT, *Load->getMemOperand())) {
2900 SDValue Ops[2];
2901 std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
2902 return DAG.getMergeValues(Ops, SDLoc(Op));
2903 }
2904 }
2905
2906 return SDValue();
2907}
2908
2909// v = ld i1* addr
2910// =>
2911// v1 = ld i8* addr (-> i16)
2912// v = trunc i16 to i1
2913SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
2914 SDNode *Node = Op.getNode();
2915 LoadSDNode *LD = cast<LoadSDNode>(Node);
2916 SDLoc dl(Node);
2917 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
2918 assert(Node->getValueType(0) == MVT::i1 &&
2919 "Custom lowering for i1 load only");
2920 SDValue newLD = DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
2921 LD->getPointerInfo(), LD->getAlign(),
2922 LD->getMemOperand()->getFlags());
2923 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
2924 // The legalizer (the caller) is expecting two values from the legalized
2925 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
2926 // in LegalizeDAG.cpp which also uses MergeValues.
2927 SDValue Ops[] = { result, LD->getChain() };
2928 return DAG.getMergeValues(Ops, dl);
2929}
2930
2931SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
2932 StoreSDNode *Store = cast<StoreSDNode>(Op);
2933 EVT VT = Store->getMemoryVT();
2934
2935 if (VT == MVT::i1)
2936 return LowerSTOREi1(Op, DAG);
2937
2938 // v2f16 is legal, so we can't rely on legalizer to handle unaligned
2939 // stores and have to handle it here.
2940 if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
2941      !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
2942                                      VT, *Store->getMemOperand()))
2943 return expandUnalignedStore(Store, DAG);
2944
2945 // v2f16, v2bf16 and v2i16 don't need special handling.
2946 if (Isv2x16VT(VT) || VT == MVT::v4i8)
2947 return SDValue();
2948
2949 if (VT.isVector())
2950 return LowerSTOREVector(Op, DAG);
2951
2952 return SDValue();
2953}
2954
2955SDValue
2956NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2957 SDNode *N = Op.getNode();
2958 SDValue Val = N->getOperand(1);
2959 SDLoc DL(N);
2960 EVT ValVT = Val.getValueType();
2961
2962 if (ValVT.isVector()) {
2963 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
2964 // legal. We can (and should) split that into 2 stores of <2 x double> here
2965 // but I'm leaving that as a TODO for now.
2966 if (!ValVT.isSimple())
2967 return SDValue();
2968 switch (ValVT.getSimpleVT().SimpleTy) {
2969 default:
2970 return SDValue();
2971 case MVT::v2i8:
2972 case MVT::v2i16:
2973 case MVT::v2i32:
2974 case MVT::v2i64:
2975 case MVT::v2f16:
2976 case MVT::v2bf16:
2977 case MVT::v2f32:
2978 case MVT::v2f64:
2979 case MVT::v4i8:
2980 case MVT::v4i16:
2981 case MVT::v4i32:
2982 case MVT::v4f16:
2983 case MVT::v4bf16:
2984 case MVT::v4f32:
2985 case MVT::v8f16: // <4 x f16x2>
2986 case MVT::v8bf16: // <4 x bf16x2>
2987 case MVT::v8i16: // <4 x i16x2>
2988 // This is a "native" vector type
2989 break;
2990 }
2991
2992 MemSDNode *MemSD = cast<MemSDNode>(N);
2993 const DataLayout &TD = DAG.getDataLayout();
2994
2995 Align Alignment = MemSD->getAlign();
2996 Align PrefAlign =
2997 TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
2998 if (Alignment < PrefAlign) {
2999 // This store is not sufficiently aligned, so bail out and let this vector
3000 // store be scalarized. Note that we may still be able to emit smaller
3001 // vector stores. For example, if we are storing a <4 x float> with an
3002 // alignment of 8, this check will fail but the legalizer will try again
3003 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3004 return SDValue();
3005 }
3006
3007 unsigned Opcode = 0;
3008 EVT EltVT = ValVT.getVectorElementType();
3009 unsigned NumElts = ValVT.getVectorNumElements();
3010
3011 // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
3012 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
3013 // stored type to i16 and propagate the "real" type as the memory type.
3014 bool NeedExt = false;
3015 if (EltVT.getSizeInBits() < 16)
3016 NeedExt = true;
3017
3018 bool StoreF16x2 = false;
3019 switch (NumElts) {
3020 default:
3021 return SDValue();
3022 case 2:
3023 Opcode = NVPTXISD::StoreV2;
3024 break;
3025 case 4:
3026 Opcode = NVPTXISD::StoreV4;
3027 break;
3028 case 8:
3029 // v8f16 is a special case. PTX doesn't have st.v8.f16
3030 // instruction. Instead, we split the vector into v2f16 chunks and
3031 // store them with st.v4.b32.
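    // For example, storing an <8 x half> value becomes one st.v4.b32 whose
    // four operands are f16x2 pairs built from consecutive elements.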
3032 assert(Is16bitsType(EltVT.getSimpleVT()) && "Wrong type for the vector.");
3033 Opcode = NVPTXISD::StoreV4;
3034 StoreF16x2 = true;
3035 break;
3036 }
3037
3038    SmallVector<SDValue, 8> Ops;
3039
3040 // First is the chain
3041 Ops.push_back(N->getOperand(0));
3042
3043 if (StoreF16x2) {
3044 // Combine f16,f16 -> v2f16
3045 NumElts /= 2;
3046 for (unsigned i = 0; i < NumElts; ++i) {
3047 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
3048 DAG.getIntPtrConstant(i * 2, DL));
3049 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
3050 DAG.getIntPtrConstant(i * 2 + 1, DL));
3051 EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2);
3052 SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
3053 Ops.push_back(V2);
3054 }
3055 } else {
3056 // Then the split values
3057 for (unsigned i = 0; i < NumElts; ++i) {
3058 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
3059 DAG.getIntPtrConstant(i, DL));
3060 if (NeedExt)
3061 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
3062 Ops.push_back(ExtVal);
3063 }
3064 }
3065
3066 // Then any remaining arguments
3067 Ops.append(N->op_begin() + 2, N->op_end());
3068
3069 SDValue NewSt =
3070 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
3071 MemSD->getMemoryVT(), MemSD->getMemOperand());
3072
3073 // return DCI.CombineTo(N, NewSt, true);
3074 return NewSt;
3075 }
3076
3077 return SDValue();
3078}
3079
3080// st i1 v, addr
3081// =>
3082// v1 = zxt v to i16
3083// st.u8 i16, addr
3084SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3085 SDNode *Node = Op.getNode();
3086 SDLoc dl(Node);
3087 StoreSDNode *ST = cast<StoreSDNode>(Node);
3088 SDValue Tmp1 = ST->getChain();
3089 SDValue Tmp2 = ST->getBasePtr();
3090 SDValue Tmp3 = ST->getValue();
3091 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3092 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
3093 SDValue Result =
3094 DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i8,
3095 ST->getAlign(), ST->getMemOperand()->getFlags());
3096 return Result;
3097}
3098
3099// This creates target external symbol for a function parameter.
3100// Name of the symbol is composed from its index and the function name.
3101// Negative index corresponds to special parameter (unsized array) used for
3102// passing variable arguments.
3103SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
3104 EVT v) const {
3105 StringRef SavedStr = nvTM->getStrPool().save(
3107 return DAG.getTargetExternalSymbol(SavedStr.data(), v);
3108}
3109
3110SDValue NVPTXTargetLowering::LowerFormalArguments(
3111    SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
3112 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
3113 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
3114  MachineFunction &MF = DAG.getMachineFunction();
3115  const DataLayout &DL = DAG.getDataLayout();
3116 auto PtrVT = getPointerTy(DAG.getDataLayout());
3117
3118 const Function *F = &MF.getFunction();
3119 const AttributeList &PAL = F->getAttributes();
3120 const TargetLowering *TLI = STI.getTargetLowering();
3121
3122 SDValue Root = DAG.getRoot();
3123 std::vector<SDValue> OutChains;
3124
3125 bool isABI = (STI.getSmVersion() >= 20);
3126 assert(isABI && "Non-ABI compilation is not supported");
3127 if (!isABI)
3128 return Chain;
3129
3130 std::vector<Type *> argTypes;
3131 std::vector<const Argument *> theArgs;
3132 for (const Argument &I : F->args()) {
3133 theArgs.push_back(&I);
3134 argTypes.push_back(I.getType());
3135 }
3136 // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
3137 // Ins.size() will be larger
3138 // * if there is an aggregate argument with multiple fields (each field
3139 // showing up separately in Ins)
3140 // * if there is a vector argument with more than typical vector-length
3141 // elements (generally if more than 4) where each vector element is
3142 // individually present in Ins.
3143 // So a different index should be used for indexing into Ins.
3144 // See similar issue in LowerCall.
3145 unsigned InsIdx = 0;
3146
3147 for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) {
3148 Type *Ty = argTypes[i];
3149
3150 if (theArgs[i]->use_empty()) {
3151 // argument is dead
3152 if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
3153 SmallVector<EVT, 16> vtparts;
3154
3155 ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
3156 if (vtparts.empty())
3157 report_fatal_error("Empty parameter types are not supported");
3158
3159 for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
3160 ++parti) {
3161 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3162 ++InsIdx;
3163 }
3164 if (vtparts.size() > 0)
3165 --InsIdx;
3166 continue;
3167 }
3168 if (Ty->isVectorTy()) {
3169 EVT ObjectVT = getValueType(DL, Ty);
3170 unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
3171 for (unsigned parti = 0; parti < NumRegs; ++parti) {
3172 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3173 ++InsIdx;
3174 }
3175 if (NumRegs > 0)
3176 --InsIdx;
3177 continue;
3178 }
3179 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3180 continue;
3181 }
3182
3183 // In the following cases, assign a node order of "i+1"
3184 // to newly created nodes. The SDNodes for params have to
3185 // appear in the same order as their order of appearance
3186 // in the original function. "i+1" holds that order.
3187 if (!PAL.hasParamAttr(i, Attribute::ByVal)) {
3188 bool aggregateIsPacked = false;
3189 if (StructType *STy = dyn_cast<StructType>(Ty))
3190 aggregateIsPacked = STy->isPacked();
3191
3192      SmallVector<EVT, 16> VTs;
3193      SmallVector<uint64_t, 16> Offsets;
3194      ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
3195 if (VTs.empty())
3196 report_fatal_error("Empty parameter types are not supported");
3197
3198 auto VectorInfo =
3199 VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlign(Ty));
3200
3201 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3202 int VecIdx = -1; // Index of the first element of the current vector.
3203 for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
3204 if (VectorInfo[parti] & PVF_FIRST) {
3205 assert(VecIdx == -1 && "Orphaned vector.");
3206 VecIdx = parti;
3207 }
3208
3209 // That's the last element of this store op.
3210 if (VectorInfo[parti] & PVF_LAST) {
3211 unsigned NumElts = parti - VecIdx + 1;
3212 EVT EltVT = VTs[parti];
3213 // i1 is loaded/stored as i8.
3214 EVT LoadVT = EltVT;
3215 if (EltVT == MVT::i1)
3216 LoadVT = MVT::i8;
3217 else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
3218 // getLoad needs a vector type, but it can't handle
3219 // vectors which contain v2f16 or v2bf16 elements. So we must load
3220 // using i32 here and then bitcast back.
3221 LoadVT = MVT::i32;
3222
3223 EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
3224 SDValue VecAddr =
3225 DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
3226 DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
3227          Value *srcValue = Constant::getNullValue(PointerType::get(
3228              EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
3229
3230 const MaybeAlign PartAlign = [&]() -> MaybeAlign {
3231 if (aggregateIsPacked)
3232 return Align(1);
3233 if (NumElts != 1)
3234 return std::nullopt;
3235 Align PartAlign =
3236 (Offsets[parti] == 0 && PAL.getParamAlignment(i))
3237 ? PAL.getParamAlignment(i).value()
3238 : DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
3239 return commonAlignment(PartAlign, Offsets[parti]);
3240 }();
3241 SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
3242 MachinePointerInfo(srcValue), PartAlign,
3243                                  MachineMemOperand::MODereferenceable |
3244                                      MachineMemOperand::MOInvariant);
3245          if (P.getNode())
3246 P.getNode()->setIROrder(i + 1);
3247 for (unsigned j = 0; j < NumElts; ++j) {
3248 SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
3249 DAG.getIntPtrConstant(j, dl));
3250 // We've loaded i1 as an i8 and now must truncate it back to i1
3251 if (EltVT == MVT::i1)
3252 Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
3253 // v2f16 was loaded as an i32. Now we must bitcast it back.
3254 else if (EltVT != LoadVT)
3255 Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
3256
3257 // If a promoted integer type is used, truncate down to the original
3258 MVT PromotedVT;
3259 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
3260 Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
3261 }
3262
3263 // Extend the element if necessary (e.g. an i8 is loaded
3264 // into an i16 register)
3265 if (Ins[InsIdx].VT.isInteger() &&
3266 Ins[InsIdx].VT.getFixedSizeInBits() >
3267 LoadVT.getFixedSizeInBits()) {
3268 unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
3269                                                         : ISD::ZERO_EXTEND;
3270              Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
3271 }
3272 InVals.push_back(Elt);
3273 }
3274
3275 // Reset vector tracking state.
3276 VecIdx = -1;
3277 }
3278 ++InsIdx;
3279 }
3280 if (VTs.size() > 0)
3281 --InsIdx;
3282 continue;
3283 }
3284
3285 // Param has ByVal attribute
3286 // Return MoveParam(param symbol).
3287 // Ideally, the param symbol can be returned directly,
3288 // but when SDNode builder decides to use it in a CopyToReg(),
3289 // machine instruction fails because TargetExternalSymbol
3290 // (not lowered) is target dependent, and CopyToReg assumes
3291 // the source is lowered.
3292 EVT ObjectVT = getValueType(DL, Ty);
3293 assert(ObjectVT == Ins[InsIdx].VT &&
3294 "Ins type did not match function type");
3295 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3296 SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3297 if (p.getNode())
3298 p.getNode()->setIROrder(i + 1);
3299 InVals.push_back(p);
3300 }
3301
3302 if (!OutChains.empty())
3303 DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
3304
3305 return Chain;
3306}
3307
3308// Use byte-store when the param address of the return value is unaligned.
3309// This may happen when the return value is a field of a packed structure.
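// For example (layout assumed for illustration): returning a packed struct
// <{ i8, i32 }> places the i32 field at offset 1, so its st.param would be
// only 1-byte aligned and is emitted as four single-byte stores instead.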
3310static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
3311                                      uint64_t Offset, EVT ElementType,
3312 SDValue RetVal, const SDLoc &dl) {
3313 // Bit logic only works on integer types
3314 if (adjustElementType(ElementType))
3315 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
3316
3317 // Store each byte
3318 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
3319 // Shift the byte to the last byte position
3320 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
3321 DAG.getConstant(i * 8, dl, MVT::i32));
3322 SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
3323 ShiftVal};
3324 // Trunc store only the last byte by using
3325 // st.param.b8
3326 // The register type can be larger than b8.
3327    Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalU8, dl,
3328                                    DAG.getVTList(MVT::Other), StoreOperands,
3329                                    MVT::i8, MachinePointerInfo(), std::nullopt,
3330                                    MachineMemOperand::MOStore);
3331  }
3332 return Chain;
3333}
3334
3335SDValue
3336NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
3337                                 bool isVarArg,
3338                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
3339                                 const SmallVectorImpl<SDValue> &OutVals,
3340 const SDLoc &dl, SelectionDAG &DAG) const {
3341 const MachineFunction &MF = DAG.getMachineFunction();
3342 const Function &F = MF.getFunction();
3343  Type *RetTy = MF.getFunction().getReturnType();
3344
3345 bool isABI = (STI.getSmVersion() >= 20);
3346 assert(isABI && "Non-ABI compilation is not supported");
3347 if (!isABI)
3348 return Chain;
3349
3350 const DataLayout &DL = DAG.getDataLayout();
3351 SmallVector<SDValue, 16> PromotedOutVals;
3352  SmallVector<EVT, 16> VTs;
3353  SmallVector<uint64_t, 16> Offsets;
3354  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
3355 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
3356
3357 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3358 SDValue PromotedOutVal = OutVals[i];
3359 MVT PromotedVT;
3360 if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) {
3361 VTs[i] = EVT(PromotedVT);
3362 }
3363 if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) {
3364      unsigned Ext =
3365          Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
3366 PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal);
3367 }
3368 PromotedOutVals.push_back(PromotedOutVal);
3369 }
3370
3371 auto VectorInfo = VectorizePTXValueVTs(
3372 VTs, Offsets,
3374 : Align(1));
3375
3376 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
3377 // 32-bits are sign extended or zero extended, depending on whether
3378 // they are signed or unsigned types.
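  // For example, an i8 or i16 return value is widened to i32 here and written
  // back with a 32-bit st.param.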
3379 bool ExtendIntegerRetVal =
3380 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
3381
3382 SmallVector<SDValue, 6> StoreOperands;
3383 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3384 SDValue OutVal = OutVals[i];
3385 SDValue RetVal = PromotedOutVals[i];
3386
3387 if (ExtendIntegerRetVal) {
3388 RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
3390 dl, MVT::i32, RetVal);
3391 } else if (OutVal.getValueSizeInBits() < 16) {
3392 // Use 16-bit registers for small load-stores as it's the
3393 // smallest general purpose register size supported by NVPTX.
3394 RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
3395 }
3396
3397 // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
3398 // for a scalar store. In such cases, fall back to byte stores.
3399 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
3400 EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3401 Align ElementTypeAlign =
3402 DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
3403 Align ElementAlign =
3404 commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
3405 if (ElementAlign < ElementTypeAlign) {
3406 assert(StoreOperands.empty() && "Orphaned operand list.");
3407 Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
3408 RetVal, dl);
3409
3410 // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
3411 // into the graph, so just move on to the next element.
3412 continue;
3413 }
3414 }
3415
3416 // New load/store. Record chain and offset operands.
3417 if (VectorInfo[i] & PVF_FIRST) {
3418 assert(StoreOperands.empty() && "Orphaned operand list.");
3419 StoreOperands.push_back(Chain);
3420 StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
3421 }
3422
3423 // Record the value to return.
3424 StoreOperands.push_back(RetVal);
3425
3426 // That's the last element of this store op.
3427 if (VectorInfo[i] & PVF_LAST) {
3428      NVPTXISD::NodeType Op;
3429      unsigned NumElts = StoreOperands.size() - 2;
3430      switch (NumElts) {
3431      case 1:
3432        Op = NVPTXISD::StoreRetval;
3433        break;
3434      case 2:
3435        Op = NVPTXISD::StoreRetvalV2;
3436        break;
3437      case 4:
3438        Op = NVPTXISD::StoreRetvalV4;
3439        break;
3440 default:
3441 llvm_unreachable("Invalid vector info.");
3442 }
3443
3444 // Adjust type of load/store op if we've extended the scalar
3445 // return value.
3446 EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3447 Chain = DAG.getMemIntrinsicNode(
3448 Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
3449          MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
3450      // Cleanup vector state.
3451 StoreOperands.clear();
3452 }
3453 }
3454
3455 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
3456}
3457
3458void NVPTXTargetLowering::LowerAsmOperandForConstraint(
3459    SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
3460 SelectionDAG &DAG) const {
3461 if (Constraint.size() > 1)
3462 return;
3463 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
3464}
3465
3466static unsigned getOpcForTextureInstr(unsigned Intrinsic) {
3467 switch (Intrinsic) {
3468 default:
3469 return 0;
3470
3471 case Intrinsic::nvvm_tex_1d_v4f32_s32:
3473 case Intrinsic::nvvm_tex_1d_v4f32_f32:
3475 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
3477 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
3479 case Intrinsic::nvvm_tex_1d_v4s32_s32:
3480 return NVPTXISD::Tex1DS32S32;
3481 case Intrinsic::nvvm_tex_1d_v4s32_f32:
3483 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
3485 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
3487 case Intrinsic::nvvm_tex_1d_v4u32_s32:
3488 return NVPTXISD::Tex1DU32S32;
3489 case Intrinsic::nvvm_tex_1d_v4u32_f32:
3491 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
3493 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
3495
3496 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
3498 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
3500 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
3502 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
3504 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
3506 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
3508 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
3510 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
3512 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
3514 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
3516 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
3518 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
3520
3521 case Intrinsic::nvvm_tex_2d_v4f32_s32:
3523 case Intrinsic::nvvm_tex_2d_v4f32_f32:
3525 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
3527 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
3529 case Intrinsic::nvvm_tex_2d_v4s32_s32:
3530 return NVPTXISD::Tex2DS32S32;
3531 case Intrinsic::nvvm_tex_2d_v4s32_f32:
3533 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
3535 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
3537 case Intrinsic::nvvm_tex_2d_v4u32_s32:
3538 return NVPTXISD::Tex2DU32S32;
3539 case Intrinsic::nvvm_tex_2d_v4u32_f32:
3541 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
3543 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
3545
3546 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
3548 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
3550 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
3552 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
3554 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
3556 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
3558 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
3560 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
3562 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
3564 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
3566 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
3568 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
3570
3571 case Intrinsic::nvvm_tex_3d_v4f32_s32:
3573 case Intrinsic::nvvm_tex_3d_v4f32_f32:
3575 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
3577 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
3579 case Intrinsic::nvvm_tex_3d_v4s32_s32:
3580 return NVPTXISD::Tex3DS32S32;
3581 case Intrinsic::nvvm_tex_3d_v4s32_f32:
3583 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
3585 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
3587 case Intrinsic::nvvm_tex_3d_v4u32_s32:
3588 return NVPTXISD::Tex3DU32S32;
3589 case Intrinsic::nvvm_tex_3d_v4u32_f32:
3591 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
3593 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
3595
3596 case Intrinsic::nvvm_tex_cube_v4f32_f32:
3598 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
3600 case Intrinsic::nvvm_tex_cube_v4s32_f32:
3602 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
3604 case Intrinsic::nvvm_tex_cube_v4u32_f32:
3606 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
3608
3609 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
3611 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
3613 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
3615 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
3617 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
3619 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
3621
3622 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
3624 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
3626 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
3628 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
3630 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
3632 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
3634 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
3636 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
3638 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
3640 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
3642 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
3644 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
3646
3647 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
3649 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
3651 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
3653 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
3655 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
3657 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
3659 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
3661 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
3663 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
3665 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
3667 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
3669 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
3671
3672 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
3674 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
3676 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
3678 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
3680 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
3682 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
3684 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
3686 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
3688 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
3690 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
3692 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
3694 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
3696
3697 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
3699 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
3701 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
3703 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
3705 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
3707 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
3709 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
3711 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
3713 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
3715 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
3717 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
3719 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
3721
3722 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
3724 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
3726 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
3728 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
3730 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
3732 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
3734 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
3736 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
3738 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
3740 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
3742 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
3744 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
3746
3747 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
3749 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
3751 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
3753 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
3755 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
3757 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
3759 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
3761 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
3763 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
3765 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
3767 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
3769 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
3771
3772 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
3774 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
3776 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
3778 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
3780 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
3782 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
3784
3785 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
3787 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
3789 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
3791 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
3793 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
3795 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
3797
3798 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
3800 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
3802 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
3804 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
3806 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
3808 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
3810
3811 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
3813 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
3815 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
3817 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
3819 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
3821 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
3823 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
3825 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
3827 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
3829 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
3831 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
3833 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
3835 }
3836}
3837
3838static unsigned getOpcForSurfaceInstr(unsigned Intrinsic) {
3839 switch (Intrinsic) {
3840 default:
3841 return 0;
3842 case Intrinsic::nvvm_suld_1d_i8_clamp:
3844 case Intrinsic::nvvm_suld_1d_i16_clamp:
3846 case Intrinsic::nvvm_suld_1d_i32_clamp:
3848 case Intrinsic::nvvm_suld_1d_i64_clamp:
3850 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
3852 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
3854 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
3856 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
3858 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
3860 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
3862 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
3864 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
3866 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
3868 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
3870 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
3872 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
3874 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
3876 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
3878 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
3880 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
3882 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
3884 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
3886 case Intrinsic::nvvm_suld_2d_i8_clamp:
3888 case Intrinsic::nvvm_suld_2d_i16_clamp:
3890 case Intrinsic::nvvm_suld_2d_i32_clamp:
3892 case Intrinsic::nvvm_suld_2d_i64_clamp:
3894 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
3896 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
3898 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
3900 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
3902 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
3904 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
3906 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
3908 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
3910 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
3912 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
3914 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
3916 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
3918 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
3920 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
3922 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
3924 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
3926 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
3928 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
3930 case Intrinsic::nvvm_suld_3d_i8_clamp:
3932 case Intrinsic::nvvm_suld_3d_i16_clamp:
3934 case Intrinsic::nvvm_suld_3d_i32_clamp:
3936 case Intrinsic::nvvm_suld_3d_i64_clamp:
3938 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
3940 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
3942 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
3944 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
3946 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
3948 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
3950 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
3952 case Intrinsic::nvvm_suld_1d_i8_trap:
3954 case Intrinsic::nvvm_suld_1d_i16_trap:
3956 case Intrinsic::nvvm_suld_1d_i32_trap:
3958 case Intrinsic::nvvm_suld_1d_i64_trap:
3960 case Intrinsic::nvvm_suld_1d_v2i8_trap:
3962 case Intrinsic::nvvm_suld_1d_v2i16_trap:
3964 case Intrinsic::nvvm_suld_1d_v2i32_trap:
3966 case Intrinsic::nvvm_suld_1d_v2i64_trap:
3968 case Intrinsic::nvvm_suld_1d_v4i8_trap:
3970 case Intrinsic::nvvm_suld_1d_v4i16_trap:
3972 case Intrinsic::nvvm_suld_1d_v4i32_trap:
3974 case Intrinsic::nvvm_suld_1d_array_i8_trap:
3976 case Intrinsic::nvvm_suld_1d_array_i16_trap:
3978 case Intrinsic::nvvm_suld_1d_array_i32_trap:
3980 case Intrinsic::nvvm_suld_1d_array_i64_trap:
3982 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
3984 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
3986 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
3988 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
3990 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
3992 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
3994 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
3996 case Intrinsic::nvvm_suld_2d_i8_trap:
3998 case Intrinsic::nvvm_suld_2d_i16_trap:
4000 case Intrinsic::nvvm_suld_2d_i32_trap:
4002 case Intrinsic::nvvm_suld_2d_i64_trap:
4004 case Intrinsic::nvvm_suld_2d_v2i8_trap:
4006 case Intrinsic::nvvm_suld_2d_v2i16_trap:
4008 case Intrinsic::nvvm_suld_2d_v2i32_trap:
4010 case Intrinsic::nvvm_suld_2d_v2i64_trap:
4012 case Intrinsic::nvvm_suld_2d_v4i8_trap:
4014 case Intrinsic::nvvm_suld_2d_v4i16_trap:
4016 case Intrinsic::nvvm_suld_2d_v4i32_trap:
4018 case Intrinsic::nvvm_suld_2d_array_i8_trap:
4020 case Intrinsic::nvvm_suld_2d_array_i16_trap:
4022 case Intrinsic::nvvm_suld_2d_array_i32_trap:
4024 case Intrinsic::nvvm_suld_2d_array_i64_trap:
4026 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
4028 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
4030 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
4032 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
4034 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
4036 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
4038 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
4040 case Intrinsic::nvvm_suld_3d_i8_trap:
4042 case Intrinsic::nvvm_suld_3d_i16_trap:
4044 case Intrinsic::nvvm_suld_3d_i32_trap:
4046 case Intrinsic::nvvm_suld_3d_i64_trap:
4048 case Intrinsic::nvvm_suld_3d_v2i8_trap:
4050 case Intrinsic::nvvm_suld_3d_v2i16_trap:
4052 case Intrinsic::nvvm_suld_3d_v2i32_trap:
4054 case Intrinsic::nvvm_suld_3d_v2i64_trap:
4056 case Intrinsic::nvvm_suld_3d_v4i8_trap:
4058 case Intrinsic::nvvm_suld_3d_v4i16_trap:
4060 case Intrinsic::nvvm_suld_3d_v4i32_trap:
4062 case Intrinsic::nvvm_suld_1d_i8_zero:
4064 case Intrinsic::nvvm_suld_1d_i16_zero:
4066 case Intrinsic::nvvm_suld_1d_i32_zero:
4068 case Intrinsic::nvvm_suld_1d_i64_zero:
4070 case Intrinsic::nvvm_suld_1d_v2i8_zero:
4072 case Intrinsic::nvvm_suld_1d_v2i16_zero: