LLVM 20.0.0git
NVPTXISelLowering.cpp
1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
16#include "NVPTX.h"
17#include "NVPTXSubtarget.h"
18#include "NVPTXTargetMachine.h"
20#include "NVPTXUtilities.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/StringRef.h"
36#include "llvm/IR/Argument.h"
37#include "llvm/IR/Attributes.h"
38#include "llvm/IR/Constants.h"
39#include "llvm/IR/DataLayout.h"
42#include "llvm/IR/FPEnv.h"
43#include "llvm/IR/Function.h"
44#include "llvm/IR/GlobalValue.h"
45#include "llvm/IR/Instruction.h"
47#include "llvm/IR/IntrinsicsNVPTX.h"
48#include "llvm/IR/Module.h"
49#include "llvm/IR/Type.h"
50#include "llvm/IR/Value.h"
60#include <algorithm>
61#include <cassert>
62#include <cmath>
63#include <cstdint>
64#include <iterator>
65#include <optional>
66#include <string>
67#include <utility>
68#include <vector>
69
70#define DEBUG_TYPE "nvptx-lower"
71
72using namespace llvm;
73
74static std::atomic<unsigned> GlobalUniqueCallSite;
75
77 "nvptx-sched4reg",
78 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
79
81 "nvptx-fma-level", cl::Hidden,
82 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
83 " 1: do it 2: do it aggressively"),
84 cl::init(2));
85
87 "nvptx-prec-divf32", cl::Hidden,
88 cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
89 " IEEE Compliant F32 div.rnd if available."),
90 cl::init(2));
91
93 "nvptx-prec-sqrtf32", cl::Hidden,
94 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
95 cl::init(true));
96
98 "nvptx-force-min-byval-param-align", cl::Hidden,
99 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
100 " params of device functions."),
101 cl::init(false));
102
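// Note (illustrative, not part of the original source): the cl::opt flags
// above are ordinary LLVM command-line options, so a hypothetical invocation
// such as
//   llc -march=nvptx64 -mcpu=sm_80 -nvptx-prec-divf32=0 -nvptx-prec-sqrtf32=0 in.ll
// would steer the division/sqrt lowering decisions made by the helpers below.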
103int NVPTXTargetLowering::getDivF32Level() const {
 104 if (UsePrecDivF32.getNumOccurrences() > 0) {
 105 // If nvptx-prec-divf32=N is used on the command-line, always honor it
106 return UsePrecDivF32;
107 } else {
108 // Otherwise, use div.approx if fast math is enabled
109 if (getTargetMachine().Options.UnsafeFPMath)
110 return 0;
111 else
112 return 2;
113 }
114}
115
116bool NVPTXTargetLowering::usePrecSqrtF32() const {
 117 if (UsePrecSqrtF32.getNumOccurrences() > 0) {
118 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
119 return UsePrecSqrtF32;
120 } else {
121 // Otherwise, use sqrt.approx if fast math is enabled
 122 return !getTargetMachine().Options.UnsafeFPMath;
 123 }
124}
125
126bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
 127 return MF.getDenormalMode(APFloat::IEEEsingle()).Output ==
 128 DenormalMode::PreserveSign;
129}
130
131static bool IsPTXVectorType(MVT VT) {
132 switch (VT.SimpleTy) {
133 default:
134 return false;
135 case MVT::v2i1:
136 case MVT::v4i1:
137 case MVT::v2i8:
138 case MVT::v4i8:
139 case MVT::v8i8: // <2 x i8x4>
140 case MVT::v16i8: // <4 x i8x4>
141 case MVT::v2i16:
142 case MVT::v4i16:
143 case MVT::v8i16: // <4 x i16x2>
144 case MVT::v2i32:
145 case MVT::v4i32:
146 case MVT::v2i64:
147 case MVT::v2f16:
148 case MVT::v4f16:
149 case MVT::v8f16: // <4 x f16x2>
150 case MVT::v2bf16:
151 case MVT::v4bf16:
152 case MVT::v8bf16: // <4 x bf16x2>
153 case MVT::v2f32:
154 case MVT::v4f32:
155 case MVT::v2f64:
156 return true;
157 }
158}
159
160static bool Is16bitsType(MVT VT) {
161 return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
162 VT.SimpleTy == MVT::i16);
163}
164
 165// This function is called when legalizing vector loads/stores. It does two
 166// things:
 167// 1. Determines whether the vector is something we want to custom lower;
168// std::nullopt is returned if we do not want to custom lower it.
169// 2. If we do want to handle it, returns two parameters:
170// - unsigned int NumElts - The number of elements in the final vector
171// - EVT EltVT - The type of the elements in the final vector
172static std::optional<std::pair<unsigned int, EVT>>
174 if (!VectorVT.isVector() || !VectorVT.isSimple())
175 return std::nullopt;
176
177 EVT EltVT = VectorVT.getVectorElementType();
178 unsigned NumElts = VectorVT.getVectorNumElements();
179
180 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
181 // legal. We can (and should) split that into 2 stores of <2 x double> here
182 // but I'm leaving that as a TODO for now.
183 switch (VectorVT.getSimpleVT().SimpleTy) {
184 default:
185 return std::nullopt;
186 case MVT::v2i8:
187 case MVT::v2i16:
188 case MVT::v2i32:
189 case MVT::v2i64:
190 case MVT::v2f16:
191 case MVT::v2bf16:
192 case MVT::v2f32:
193 case MVT::v2f64:
194 case MVT::v4i8:
195 case MVT::v4i16:
196 case MVT::v4i32:
197 case MVT::v4f16:
198 case MVT::v4bf16:
199 case MVT::v4f32:
200 // This is a "native" vector type
201 return std::pair(NumElts, EltVT);
202 case MVT::v8i8: // <2 x i8x4>
203 case MVT::v8f16: // <4 x f16x2>
204 case MVT::v8bf16: // <4 x bf16x2>
205 case MVT::v8i16: // <4 x i16x2>
206 case MVT::v16i8: // <4 x i8x4>
207 // This can be upsized into a "native" vector type.
208 // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
209 // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
210 // vectorized loads/stores with the actual element type for i8/i16 as that
211 // would require v8/v16 variants that do not exist.
212 // In order to load/store such vectors efficiently, here in Type
213 // Legalization, we split the vector into word-sized chunks (v2x16/v4i8).
214 // Later, we will lower to PTX as vectors of b32.
215
216 // Number of elements to pack in one word.
217 unsigned NPerWord = 32 / EltVT.getSizeInBits();
218
219 return std::pair(NumElts / NPerWord,
220 MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord));
221 }
222
223 llvm_unreachable("All cases in switch should return.");
224}
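// Illustrative examples of the mapping implemented above (a sketch, not part
// of the original source):
//   v4f32  -> (4, f32)      // already a "native" PTX vector shape
//   v8f16  -> (4, v2f16)    // packed into four 32-bit words
//   v16i8  -> (4, v4i8)     // packed into four 32-bit words
//   v8f64  -> std::nullopt  // not custom lowered here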
225
226/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
227/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
228/// into their primitive components.
229/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
230/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
231/// LowerCall, and LowerReturn.
232static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
233 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
234 SmallVectorImpl<uint64_t> *Offsets = nullptr,
235 uint64_t StartingOffset = 0) {
236 SmallVector<EVT, 16> TempVTs;
237 SmallVector<uint64_t, 16> TempOffsets;
238
239 // Special case for i128 - decompose to (i64, i64)
240 if (Ty->isIntegerTy(128)) {
241 ValueVTs.push_back(EVT(MVT::i64));
242 ValueVTs.push_back(EVT(MVT::i64));
243
244 if (Offsets) {
245 Offsets->push_back(StartingOffset + 0);
246 Offsets->push_back(StartingOffset + 8);
247 }
248
249 return;
250 }
251
252 // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
253 if (StructType *STy = dyn_cast<StructType>(Ty)) {
254 auto const *SL = DL.getStructLayout(STy);
255 auto ElementNum = 0;
256 for(auto *EI : STy->elements()) {
257 ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
258 StartingOffset + SL->getElementOffset(ElementNum));
259 ++ElementNum;
260 }
261 return;
262 }
263
264 ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
265 for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
266 EVT VT = TempVTs[i];
267 uint64_t Off = TempOffsets[i];
268 // Split vectors into individual elements, except for v2f16, which
269 // we will pass as a single scalar.
270 if (VT.isVector()) {
271 unsigned NumElts = VT.getVectorNumElements();
272 EVT EltVT = VT.getVectorElementType();
 273 // We require power-of-2 sized vectors because
 274 // TargetLoweringBase::getVectorTypeBreakdown(), which is invoked in
 275 // ComputePTXValueVTs(), cannot currently break down non-power-of-2 sized
276 // vectors.
277 if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
278 isPowerOf2_32(NumElts)) {
279 // Vectors with an even number of f16 elements will be passed to
280 // us as an array of v2f16/v2bf16 elements. We must match this so we
281 // stay in sync with Ins/Outs.
282 switch (EltVT.getSimpleVT().SimpleTy) {
283 case MVT::f16:
284 EltVT = MVT::v2f16;
285 break;
286 case MVT::bf16:
287 EltVT = MVT::v2bf16;
288 break;
289 case MVT::i16:
290 EltVT = MVT::v2i16;
291 break;
292 default:
293 llvm_unreachable("Unexpected type");
294 }
295 NumElts /= 2;
296 } else if (EltVT.getSimpleVT() == MVT::i8 &&
297 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
298 NumElts == 3)) {
299 // v*i8 are formally lowered as v4i8
300 EltVT = MVT::v4i8;
301 NumElts = (NumElts + 3) / 4;
302 } else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) {
303 // v2i8 is promoted to v2i16
304 NumElts = 1;
305 EltVT = MVT::v2i16;
306 }
307 for (unsigned j = 0; j != NumElts; ++j) {
308 ValueVTs.push_back(EltVT);
309 if (Offsets)
310 Offsets->push_back(Off + j * EltVT.getStoreSize());
311 }
312 } else {
313 ValueVTs.push_back(VT);
314 if (Offsets)
315 Offsets->push_back(Off);
316 }
317 }
318}
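// Example of the flattening performed above (illustrative sketch only):
//   struct { i128 a; <4 x half> b; float c; }
// produces
//   ValueVTs = { i64, i64, v2f16, v2f16, f32 }
//   Offsets  = {   0,   8,    16,    20,  24 }
// i.e. i128 splits into two i64 halves and the <4 x half> piece is kept as a
// pair of packed v2f16 values, matching what Ins/Outs will contain.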
319
320/// PromoteScalarIntegerPTX
321/// Used to make sure the arguments/returns are suitable for passing
322/// and promote them to a larger size if they're not.
323///
 324/// The promoted type is placed in \p PromotedVT if the function returns true.
325static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
326 if (VT.isScalarInteger()) {
327 switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
328 default:
330 "Promotion is not suitable for scalars of size larger than 64-bits");
331 case 1:
332 *PromotedVT = MVT::i1;
333 break;
334 case 2:
335 case 4:
336 case 8:
337 *PromotedVT = MVT::i8;
338 break;
339 case 16:
340 *PromotedVT = MVT::i16;
341 break;
342 case 32:
343 *PromotedVT = MVT::i32;
344 break;
345 case 64:
346 *PromotedVT = MVT::i64;
347 break;
348 }
349 return EVT(*PromotedVT) != VT;
350 }
351 return false;
352}
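// Worked example for the helper above (illustrative): an i24 value yields
// PowerOf2Ceil(24) == 32, so *PromotedVT is set to MVT::i32 and the function
// returns true; an i32 value maps to itself, so the function returns false.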
353
354// Check whether we can merge loads/stores of some of the pieces of a
355// flattened function parameter or return value into a single vector
356// load/store.
357//
358// The flattened parameter is represented as a list of EVTs and
359// offsets, and the whole structure is aligned to ParamAlignment. This
360// function determines whether we can load/store pieces of the
361// parameter starting at index Idx using a single vectorized op of
362// size AccessSize. If so, it returns the number of param pieces
363// covered by the vector op. Otherwise, it returns 1.
364static unsigned CanMergeParamLoadStoresStartingAt(
 365 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
366 const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
367
368 // Can't vectorize if param alignment is not sufficient.
369 if (ParamAlignment < AccessSize)
370 return 1;
371 // Can't vectorize if offset is not aligned.
372 if (Offsets[Idx] & (AccessSize - 1))
373 return 1;
374
375 EVT EltVT = ValueVTs[Idx];
376 unsigned EltSize = EltVT.getStoreSize();
377
378 // Element is too large to vectorize.
379 if (EltSize >= AccessSize)
380 return 1;
381
382 unsigned NumElts = AccessSize / EltSize;
 383 // Can't vectorize if AccessSize is not a multiple of EltSize.
384 if (AccessSize != EltSize * NumElts)
385 return 1;
386
387 // We don't have enough elements to vectorize.
388 if (Idx + NumElts > ValueVTs.size())
389 return 1;
390
391 // PTX ISA can only deal with 2- and 4-element vector ops.
392 if (NumElts != 4 && NumElts != 2)
393 return 1;
394
395 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
396 // Types do not match.
397 if (ValueVTs[j] != EltVT)
398 return 1;
399
400 // Elements are not contiguous.
401 if (Offsets[j] - Offsets[j - 1] != EltSize)
402 return 1;
403 }
 404 // OK. We can vectorize ValueVTs[Idx..Idx+NumElts)
405 return NumElts;
406}
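// Example (illustrative): for ValueVTs = {f32, f32, f32, f32} with
// Offsets = {0, 4, 8, 12} and ParamAlignment = 16, a query at Idx = 0 with
// AccessSize = 16 returns 4 (one v4.f32 access covers all four pieces). With
// ParamAlignment = 8 that query returns 1, but AccessSize = 8 still returns 2.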
407
408// Flags for tracking per-element vectorization state of loads/stores
409// of a flattened function parameter or return value.
410enum ParamVectorizationFlags {
 411 PVF_INNER = 0x0, // Middle elements of a vector.
412 PVF_FIRST = 0x1, // First element of the vector.
413 PVF_LAST = 0x2, // Last element of the vector.
414 // Scalar is effectively a 1-element vector.
 415 PVF_SCALAR = PVF_FIRST | PVF_LAST,
 416};
 417
418// Computes whether and how we can vectorize the loads/stores of a
419// flattened function parameter or return value.
420//
421// The flattened parameter is represented as the list of ValueVTs and
422// Offsets, and is aligned to ParamAlignment bytes. We return a vector
423// of the same size as ValueVTs indicating how each piece should be
424// loaded/stored (i.e. as a scalar, or as part of a vector
425// load/store).
426static SmallVector<ParamVectorizationFlags, 16>
427VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
 428 const SmallVectorImpl<uint64_t> &Offsets,
429 Align ParamAlignment, bool IsVAArg = false) {
430 // Set vector size to match ValueVTs and mark all elements as
431 // scalars by default.
 432 SmallVector<ParamVectorizationFlags, 16> VectorInfo;
 433 VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
434
435 if (IsVAArg)
436 return VectorInfo;
437
438 // Check what we can vectorize using 128/64/32-bit accesses.
439 for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
440 // Skip elements we've already processed.
441 assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
442 for (unsigned AccessSize : {16, 8, 4, 2}) {
443 unsigned NumElts = CanMergeParamLoadStoresStartingAt(
444 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
445 // Mark vectorized elements.
446 switch (NumElts) {
447 default:
448 llvm_unreachable("Unexpected return value");
449 case 1:
450 // Can't vectorize using this size, try next smaller size.
451 continue;
452 case 2:
453 assert(I + 1 < E && "Not enough elements.");
454 VectorInfo[I] = PVF_FIRST;
455 VectorInfo[I + 1] = PVF_LAST;
456 I += 1;
457 break;
458 case 4:
459 assert(I + 3 < E && "Not enough elements.");
460 VectorInfo[I] = PVF_FIRST;
461 VectorInfo[I + 1] = PVF_INNER;
462 VectorInfo[I + 2] = PVF_INNER;
463 VectorInfo[I + 3] = PVF_LAST;
464 I += 3;
465 break;
466 }
467 // Break out of the inner loop because we've already succeeded
 469 // using the largest possible AccessSize.
469 break;
470 }
471 }
472 return VectorInfo;
473}
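// Example (illustrative): for a flattened parameter {f32, f32, f32, f32} at
// offsets {0, 4, 8, 12}, ParamAlignment = 16 yields {PVF_FIRST, PVF_INNER,
// PVF_INNER, PVF_LAST}, i.e. a single v4 access, while ParamAlignment = 4
// leaves every element as PVF_SCALAR.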
474
476 SDValue Value) {
477 if (Value->getValueType(0) == VT)
478 return Value;
479 return DAG.getNode(ISD::BITCAST, DL, VT, Value);
480}
481
482// NVPTXTargetLowering Constructor.
483NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 484 const NVPTXSubtarget &STI)
485 : TargetLowering(TM), nvTM(&TM), STI(STI) {
 486 // Always lower memset, memcpy, and memmove intrinsics to load/store
 487 // instructions, rather than generating calls to memset, memcpy, or
 488 // memmove.
492
495
496 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
497 // condition branches.
498 setJumpIsExpensive(true);
499
500 // Wide divides are _very_ slow. Try to reduce the width of the divide if
501 // possible.
502 addBypassSlowDiv(64, 32);
503
504 // By default, use the Source scheduling
505 if (sched4reg)
507 else
509
510 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
511 LegalizeAction NoF16Action) {
512 bool IsOpSupported = STI.allowFP16Math();
513 switch (Op) {
514 // Several FP16 instructions are available on sm_80 only.
515 case ISD::FMINNUM:
516 case ISD::FMAXNUM:
519 case ISD::FMAXIMUM:
520 case ISD::FMINIMUM:
521 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
522 break;
523 }
524 setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action);
525 };
526
527 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
528 LegalizeAction NoBF16Action) {
529 bool IsOpSupported = STI.hasBF16Math();
530 switch (Op) {
531 // Several BF16 instructions are available on sm_90 only.
532 case ISD::FADD:
533 case ISD::FMUL:
534 case ISD::FSUB:
535 case ISD::SELECT:
536 case ISD::SELECT_CC:
537 case ISD::SETCC:
538 case ISD::FEXP2:
539 case ISD::FCEIL:
540 case ISD::FFLOOR:
541 case ISD::FNEARBYINT:
542 case ISD::FRINT:
543 case ISD::FROUNDEVEN:
544 case ISD::FTRUNC:
545 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
546 break;
547 // Several BF16 instructions are available on sm_80 only.
548 case ISD::FMINNUM:
549 case ISD::FMAXNUM:
552 case ISD::FMAXIMUM:
553 case ISD::FMINIMUM:
554 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
555 break;
556 }
558 Op, VT, IsOpSupported ? Action : NoBF16Action);
559 };
560
561 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
562 LegalizeAction NoI16x2Action) {
563 bool IsOpSupported = false;
564 // instructions are available on sm_90 only
565 switch (Op) {
566 case ISD::ADD:
567 case ISD::SMAX:
568 case ISD::SMIN:
569 case ISD::UMIN:
570 case ISD::UMAX:
571 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
572 break;
573 }
574 setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
575 };
576
577 addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
578 addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
579 addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
580 addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
581 addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
582 addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
583 addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
584 addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
585 addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
586 addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
587 addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
588 addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
589
590 // Conversion to/from FP16/FP16x2 is always legal.
595
597 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
599
600 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
601 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
602
 603 // Conversion to/from BF16/BF16x2 is always legal.
608
609 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
610 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
611 if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote)
612 AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32);
613
614 // Conversion to/from i16/i16x2 is always legal.
619
624
625 // Custom conversions to/from v2i8.
627
628 // Only logical ops can be done on v4i8 directly, others must be done
629 // elementwise.
646 MVT::v4i8, Expand);
647
648 // Operations not directly supported by NVPTX.
649 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
650 MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
651 MVT::i32, MVT::i64}) {
654 }
655
656 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
657 // For others we will expand to a SHL/SRA pair.
664
671
674
676 {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
677 Expand);
678
679 if (STI.hasHWROT32())
681
683
686
689
 690 // We want to legalize constant-related memmove and memcpy
 691 // intrinsics.
693
694 // Turn FP extload into load/fpextend
695 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
696 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
697 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
698 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
699 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
700 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
701 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
702 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
703 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
704 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
705 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
706 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
707 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
708 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
709 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
710 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
711 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
712 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
713 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
714 // Turn FP truncstore into trunc + store.
715 // FIXME: vector types should also be expanded
716 setTruncStoreAction(MVT::f32, MVT::f16, Expand);
717 setTruncStoreAction(MVT::f64, MVT::f16, Expand);
718 setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
719 setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
720 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
721
722 // PTX does not support load / store predicate registers
725
726 for (MVT VT : MVT::integer_valuetypes()) {
730 setTruncStoreAction(VT, MVT::i1, Expand);
731 }
732
736 MVT::i1, Expand);
737
738 // expand extload of vector of integers.
740 MVT::v2i8, Expand);
741 setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
742
743 // This is legal in NVPTX
748
749 setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom);
751
752 // TRAP can be lowered to PTX trap
753 setOperationAction(ISD::TRAP, MVT::Other, Legal);
754 // DEBUGTRAP can be lowered to PTX brkpt
756
757 // Register custom handling for vector loads/stores
759 if (IsPTXVectorType(VT)) {
763 }
764 }
765
766 // Support varargs.
771
772 // Custom handling for i8 intrinsics
774
775 for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) {
781
784 }
785
786 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
787 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
788 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
789 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
790 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
791 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
792 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
793
794 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
795 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
796 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
797 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
798 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
799 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
800
801 // Other arithmetic and logic ops are unsupported.
805 MVT::v2i16, Expand);
806
811 if (STI.getPTXVersion() >= 43) {
816 }
817
819 setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
822
823 // PTX does not directly support SELP of i1, so promote to i32 first
825
826 // PTX cannot multiply two i64s in a single instruction.
829
830 // We have some custom DAG combine patterns for these nodes
834
835 // setcc for f16x2 and bf16x2 needs special handling to prevent
836 // legalizer's attempt to scalarize it due to v2i1 not being legal.
837 if (STI.allowFP16Math() || STI.hasBF16Math())
839
840 // Promote fp16 arithmetic if fp16 hardware isn't available or the
841 // user passed --nvptx-no-fp16-math. The flag is useful because,
842 // although sm_53+ GPUs have some sort of FP16 support in
843 // hardware, only sm_53 and sm_60 have full implementation. Others
 844 // only have a token amount of hardware and are likely to run faster
845 // by using fp32 units instead.
846 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
847 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
848 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
849 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
850 // bf16 must be promoted to f32.
851 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
852 if (getOperationAction(Op, MVT::bf16) == Promote)
853 AddPromotedToType(Op, MVT::bf16, MVT::f32);
854 }
855
856 // f16/f16x2 neg was introduced in PTX 60, SM_53.
857 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
858 STI.getPTXVersion() >= 60 &&
859 STI.allowFP16Math();
860 for (const auto &VT : {MVT::f16, MVT::v2f16})
862 IsFP16FP16x2NegAvailable ? Legal : Expand);
863
864 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
865 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
866 // (would be) Library functions.
867
868 // These map to conversion instructions for scalar FP types.
869 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
871 setOperationAction(Op, MVT::f16, Legal);
872 setOperationAction(Op, MVT::f32, Legal);
873 setOperationAction(Op, MVT::f64, Legal);
874 setOperationAction(Op, MVT::v2f16, Expand);
875 setOperationAction(Op, MVT::v2bf16, Expand);
876 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
877 if (getOperationAction(Op, MVT::bf16) == Promote)
878 AddPromotedToType(Op, MVT::bf16, MVT::f32);
879 }
880
881 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
883 }
884 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
885 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
888 }
889 }
890
891 // sm_80 only has conversions between f32 and bf16. Custom lower all other
892 // bf16 conversions.
893 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
894 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
897 VT, Custom);
898 }
901 MVT::bf16, Custom);
902 }
903
910 AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32);
911
912 // 'Expand' implements FCOPYSIGN without calling an external library.
919
920 // These map to corresponding instructions for f32/f64. f16 must be
921 // promoted to f32. v2f16 is expanded to f16, which is then promoted
922 // to f32.
923 for (const auto &Op :
925 setOperationAction(Op, MVT::f16, Promote);
926 setOperationAction(Op, MVT::f32, Legal);
927 setOperationAction(Op, MVT::f64, Legal);
928 setOperationAction(Op, MVT::v2f16, Expand);
929 setOperationAction(Op, MVT::v2bf16, Expand);
930 setOperationAction(Op, MVT::bf16, Promote);
931 AddPromotedToType(Op, MVT::bf16, MVT::f32);
932 }
933
934 setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
935 if (STI.getPTXVersion() >= 65) {
936 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
937 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
938 } else {
940 setOperationAction(ISD::FABS, MVT::v2f16, Expand);
941 }
942 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
943 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
944 if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
945 AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);
946
947 for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
948 setOperationAction(Op, MVT::f32, Legal);
949 setOperationAction(Op, MVT::f64, Legal);
950 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
951 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
952 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
953 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
954 if (getOperationAction(Op, MVT::bf16) == Promote)
955 AddPromotedToType(Op, MVT::bf16, MVT::f32);
956 }
957 bool SupportsF32MinMaxNaN =
958 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
959 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
960 setOperationAction(Op, MVT::f32, SupportsF32MinMaxNaN ? Legal : Expand);
961 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
962 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
963 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
964 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
965 }
966
967 // Custom lowering for inline asm with 128-bit operands
970
971 // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
972 // No FPOW or FREM in PTX.
973
974 // Now deduce the information based on the above mentioned
975 // actions
977
978 setMinCmpXchgSizeInBits(STI.hasAtomCas16() ? 16 : 32);
981}
982
983const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
984
985#define MAKE_CASE(V) \
986 case V: \
987 return #V;
988
989 switch ((NVPTXISD::NodeType)Opcode) {
991 break;
992
1056 }
1057 return nullptr;
1058
1059#undef MAKE_CASE
1060}
1061
1062TargetLoweringBase::LegalizeTypeAction
1063NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
 1064 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1065 VT.getScalarType() == MVT::i1)
1066 return TypeSplitVector;
 1067 return TargetLoweringBase::getPreferredVectorAction(VT);
1068}
1069
1070SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
 1071 int Enabled, int &ExtraSteps,
1072 bool &UseOneConst,
1073 bool Reciprocal) const {
1076 return SDValue();
1077
1078 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1079 ExtraSteps = 0;
1080
1081 SDLoc DL(Operand);
1082 EVT VT = Operand.getValueType();
1083 bool Ftz = useF32FTZ(DAG.getMachineFunction());
1084
1085 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1086 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
1087 DAG.getConstant(IID, DL, MVT::i32), Operand);
1088 };
1089
1090 // The sqrt and rsqrt refinement processes assume we always start out with an
1091 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1092 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1093 // any refinement, we must return a regular sqrt.
1094 if (Reciprocal || ExtraSteps > 0) {
1095 if (VT == MVT::f32)
1096 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1097 : Intrinsic::nvvm_rsqrt_approx_f);
1098 else if (VT == MVT::f64)
1099 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1100 else
1101 return SDValue();
1102 } else {
1103 if (VT == MVT::f32)
1104 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1105 : Intrinsic::nvvm_sqrt_approx_f);
1106 else {
1107 // There's no sqrt.approx.f64 instruction, so we emit
1108 // reciprocal(rsqrt(x)). This is faster than
1109 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1110 // x * rsqrt(x).)
1111 return DAG.getNode(
1113 DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32),
1114 MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1115 }
1116 }
1117}
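// Rough sketch of what the code above selects (assumptions: fast math, FTZ
// enabled for f32; intrinsic names abbreviated, added for illustration):
//   f32 sqrt,  no refinement -> nvvm.sqrt.approx.ftz.f(x)
//   f32 rsqrt, or refined    -> nvvm.rsqrt.approx.ftz.f(x)
//   f64 sqrt                 -> nvvm.rcp.approx.ftz.d(nvvm.rsqrt.approx.d(x))
//   f64 rsqrt                -> nvvm.rsqrt.approx.d(x)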
1118
1119SDValue
1120NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
 1121 SDLoc dl(Op);
1122 const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op);
1123 auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace());
1124 Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT);
1125 return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
1126}
1127
1128static bool IsTypePassedAsArray(const Type *Ty) {
1129 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
1130 Ty->isHalfTy() || Ty->isBFloatTy();
1131}
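// In other words (illustrative restatement): half, bfloat, i128, vectors and
// aggregates are emitted as ".param .align N .b8 _[size]" byte arrays by the
// code below, while i32, i64, float, double and pointers stay scalar
// ".param .b<N>" parameters.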
1132
1133std::string NVPTXTargetLowering::getPrototype(
 1134 const DataLayout &DL, Type *retTy, const ArgListTy &Args,
1135 const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
1136 std::optional<std::pair<unsigned, const APInt &>> VAInfo,
1137 const CallBase &CB, unsigned UniqueCallSite) const {
1138 auto PtrVT = getPointerTy(DL);
1139
1140 bool isABI = (STI.getSmVersion() >= 20);
1141 assert(isABI && "Non-ABI compilation is not supported");
1142 if (!isABI)
1143 return "";
1144
1145 std::string Prototype;
1146 raw_string_ostream O(Prototype);
1147 O << "prototype_" << UniqueCallSite << " : .callprototype ";
1148
1149 if (retTy->getTypeID() == Type::VoidTyID) {
1150 O << "()";
1151 } else {
1152 O << "(";
1153 if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
1154 !IsTypePassedAsArray(retTy)) {
1155 unsigned size = 0;
1156 if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
1157 size = ITy->getBitWidth();
1158 } else {
1159 assert(retTy->isFloatingPointTy() &&
1160 "Floating point type expected here");
1161 size = retTy->getPrimitiveSizeInBits();
1162 }
1163 // PTX ABI requires all scalar return values to be at least 32
1164 // bits in size. fp16 normally uses .b16 as its storage type in
1165 // PTX, so its size must be adjusted here, too.
1167
1168 O << ".param .b" << size << " _";
1169 } else if (isa<PointerType>(retTy)) {
1170 O << ".param .b" << PtrVT.getSizeInBits() << " _";
1171 } else if (IsTypePassedAsArray(retTy)) {
1172 O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
1173 << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
1174 } else {
1175 llvm_unreachable("Unknown return type");
1176 }
1177 O << ") ";
1178 }
1179 O << "_ (";
1180
1181 bool first = true;
1182
1183 unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
1184 for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
1185 Type *Ty = Args[i].Ty;
1186 if (!first) {
1187 O << ", ";
1188 }
1189 first = false;
1190
1191 if (!Outs[OIdx].Flags.isByVal()) {
1192 if (IsTypePassedAsArray(Ty)) {
1193 Align ParamAlign =
1194 getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
1195 O << ".param .align " << ParamAlign.value() << " .b8 ";
1196 O << "_";
1197 O << "[" << DL.getTypeAllocSize(Ty) << "]";
1198 // update the index for Outs
1199 SmallVector<EVT, 16> vtparts;
1200 ComputeValueVTs(*this, DL, Ty, vtparts);
1201 if (unsigned len = vtparts.size())
1202 OIdx += len - 1;
1203 continue;
1204 }
1205 // i8 types in IR will be i16 types in SDAG
1206 assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
1207 (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
1208 "type mismatch between callee prototype and arguments");
1209 // scalar type
1210 unsigned sz = 0;
1211 if (isa<IntegerType>(Ty)) {
1212 sz = cast<IntegerType>(Ty)->getBitWidth();
1214 } else if (isa<PointerType>(Ty)) {
1215 sz = PtrVT.getSizeInBits();
1216 } else {
1217 sz = Ty->getPrimitiveSizeInBits();
1218 }
1219 O << ".param .b" << sz << " ";
1220 O << "_";
1221 continue;
1222 }
1223
1224 // Indirect calls need strict ABI alignment so we disable optimizations by
1225 // not providing a function to optimize.
1226 Type *ETy = Args[i].IndirectType;
1227 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1228 Align ParamByValAlign =
1229 getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
1230
1231 O << ".param .align " << ParamByValAlign.value() << " .b8 ";
1232 O << "_";
1233 O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
1234 }
1235
1236 if (VAInfo)
1237 O << (first ? "" : ",") << " .param .align " << VAInfo->second
1238 << " .b8 _[]\n";
1239 O << ")";
1241 O << " .noreturn";
1242 O << ";";
1243
1244 return Prototype;
1245}
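// Example of a prototype string built above (illustrative; the numbers are
// hypothetical):
//   prototype_2 : .callprototype (.param .b32 _) _ (.param .b32 _,
//       .param .align 4 .b8 _[12]);
// i.e. an i32 return value, one i32 scalar argument, and a 12-byte byval
// aggregate aligned to 4 bytes.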
1246
1247Align NVPTXTargetLowering::getFunctionArgumentAlignment(
 1248 const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
1249 return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
1250}
1251
1252Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
1253 unsigned Idx,
1254 const DataLayout &DL) const {
1255 if (!CB) {
1256 // CallSite is zero, fallback to ABI type alignment
1257 return DL.getABITypeAlign(Ty);
1258 }
1259
1260 const Function *DirectCallee = CB->getCalledFunction();
1261
1262 if (!DirectCallee) {
1263 // We don't have a direct function symbol, but that may be because of
1264 // constant cast instructions in the call.
1265
1266 // With bitcast'd call targets, the instruction will be the call
1267 if (const auto *CI = dyn_cast<CallInst>(CB)) {
1268 // Check if we have call alignment metadata
1269 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1270 return StackAlign.value();
1271 }
1272 DirectCallee = getMaybeBitcastedCallee(CB);
1273 }
1274
1275 // Check for function alignment information if we found that the
1276 // ultimate target is a Function
1277 if (DirectCallee)
1278 return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
1279
1280 // Call is indirect, fall back to the ABI type alignment
1281 return DL.getABITypeAlign(Ty);
1282}
1283
1284static bool adjustElementType(EVT &ElementType) {
1285 switch (ElementType.getSimpleVT().SimpleTy) {
1286 default:
1287 return false;
1288 case MVT::f16:
1289 case MVT::bf16:
1290 ElementType = MVT::i16;
1291 return true;
1292 case MVT::f32:
1293 case MVT::v2f16:
1294 case MVT::v2bf16:
1295 ElementType = MVT::i32;
1296 return true;
1297 case MVT::f64:
1298 ElementType = MVT::i64;
1299 return true;
1300 }
1301}
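// The re-typing above exists so that the byte-wise shift/mask logic in the
// helpers below can operate on plain integers: f32 and v2f16 are both handled
// as i32, and f64 as i64 (illustrative restatement, not new behavior).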
1302
1303// Use byte-store when the param address of the argument value is unaligned.
 1304 // This may happen when the argument value is a field of a packed structure.
1305//
1306// This is called in LowerCall() when passing the param values.
1307static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
 1308 uint64_t Offset, EVT ElementType,
1309 SDValue StVal, SDValue &InGlue,
1310 unsigned ArgID, const SDLoc &dl) {
1311 // Bit logic only works on integer types
1312 if (adjustElementType(ElementType))
1313 StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
1314
1315 // Store each byte
1316 SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1317 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1318 // Shift the byte to the last byte position
1319 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
1320 DAG.getConstant(i * 8, dl, MVT::i32));
1321 SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
1322 DAG.getConstant(Offset + i, dl, MVT::i32),
1323 ShiftVal, InGlue};
1324 // Trunc store only the last byte by using
1325 // st.param.b8
1326 // The register type can be larger than b8.
1327 Chain = DAG.getMemIntrinsicNode(
1328 NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
1330 InGlue = Chain.getValue(1);
1331 }
1332 return Chain;
1333}
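// Illustrative PTX for the expansion above, storing a 4-byte value at a
// misaligned parameter offset (register names and offsets are made up):
//   st.param.b8 [param2+5], %rs0;  // bits  7..0
//   st.param.b8 [param2+6], %rs1;  // bits 15..8  (value >> 8)
//   st.param.b8 [param2+7], %rs2;  // bits 23..16 (value >> 16)
//   st.param.b8 [param2+8], %rs3;  // bits 31..24 (value >> 24)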
1334
 1335 // Use byte-load when the param address of the returned value is unaligned.
1336// This may happen when the returned value is a field of a packed structure.
1337static SDValue
1338LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
 1339 EVT ElementType, SDValue &InGlue,
1340 SmallVectorImpl<SDValue> &TempProxyRegOps,
1341 const SDLoc &dl) {
1342 // Bit logic only works on integer types
1343 EVT MergedType = ElementType;
1344 adjustElementType(MergedType);
1345
 1346 // Load each byte and construct the whole value, starting from 0.
1347 SDValue RetVal = DAG.getConstant(0, dl, MergedType);
1348 // LoadParamMemI8 loads into i16 register only
1349 SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
1350 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1351 SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
1352 DAG.getConstant(Offset + i, dl, MVT::i32),
1353 InGlue};
1354 // This will be selected to LoadParamMemI8
1355 SDValue LdVal =
1356 DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
1357 MVT::i8, MachinePointerInfo(), Align(1));
1358 SDValue TmpLdVal = LdVal.getValue(0);
1359 Chain = LdVal.getValue(1);
1360 InGlue = LdVal.getValue(2);
1361
1362 TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
1363 TmpLdVal.getSimpleValueType(), TmpLdVal);
1364 TempProxyRegOps.push_back(TmpLdVal);
1365
1366 SDValue CMask = DAG.getConstant(255, dl, MergedType);
1367 SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
1368 // Need to extend the i16 register to the whole width.
1369 TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
 1370 // Mask off the high bits. Leave only the lower 8 bits.
1371 // Do this because we are using loadparam.b8.
1372 TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
1373 // Shift and merge
1374 TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
1375 RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
1376 }
1377 if (ElementType != MergedType)
1378 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
1379
1380 return RetVal;
1381}
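// This is the inverse of the unaligned store case above (illustrative): each
// byte is fetched with ld.param.b8, zero-extended, masked to 8 bits, shifted
// into place and OR-ed in, so for a 32-bit piece the reassembled value is
// effectively RetVal = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24).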
1382
1383static bool shouldConvertToIndirectCall(const CallBase *CB,
 1384 const GlobalAddressSDNode *Func) {
1385 if (!Func)
1386 return false;
1387 if (auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal()))
1388 return CB->getFunctionType() != CalleeFunc->getFunctionType();
1389 return false;
1390}
1391
1392SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 1393 SmallVectorImpl<SDValue> &InVals) const {
1394
1395 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1397 "Support for variadic functions (unsized array parameter) introduced "
1398 "in PTX ISA version 6.0 and requires target sm_30.");
1399
1400 SelectionDAG &DAG = CLI.DAG;
1401 SDLoc dl = CLI.DL;
1403 SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
1405 SDValue Chain = CLI.Chain;
1406 SDValue Callee = CLI.Callee;
1407 bool &isTailCall = CLI.IsTailCall;
1408 ArgListTy &Args = CLI.getArgs();
1409 Type *RetTy = CLI.RetTy;
1410 const CallBase *CB = CLI.CB;
1411 const DataLayout &DL = DAG.getDataLayout();
1412
1413 bool isABI = (STI.getSmVersion() >= 20);
1414 assert(isABI && "Non-ABI compilation is not supported");
1415 if (!isABI)
1416 return Chain;
1417
1418 // Variadic arguments.
1419 //
1420 // Normally, for each argument, we declare a param scalar or a param
1421 // byte array in the .param space, and store the argument value to that
1422 // param scalar or array starting at offset 0.
1423 //
1424 // In the case of the first variadic argument, we declare a vararg byte array
1425 // with size 0. The exact size of this array isn't known at this point, so
1426 // it'll be patched later. All the variadic arguments will be stored to this
1427 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1428 // initially set to 0, so it can be used for non-variadic arguments (which use
1429 // 0 offset) to simplify the code.
1430 //
 1431 // After all varargs are processed, 'VAOffset' holds the size of the
1432 // vararg byte array.
1433
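 // For example (illustrative), a call to a variadic callee ends up declaring
 // a byte array along the lines of
 //   .param .align 8 .b8 param1[N];
 // where N is only filled in (via the MorphNodeTo call further down) once all
 // variadic arguments have been laid out and VAOffset holds the final size.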
1434 SDValue VADeclareParam; // vararg byte array
1435 unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
1436 unsigned VAOffset = 0; // current offset in the param array
1437
1438 unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
1439 SDValue TempChain = Chain;
1440 Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
1441 SDValue InGlue = Chain.getValue(1);
1442
1443 unsigned ParamCount = 0;
1444 // Args.size() and Outs.size() need not match.
1445 // Outs.size() will be larger
1446 // * if there is an aggregate argument with multiple fields (each field
1447 // showing up separately in Outs)
1448 // * if there is a vector argument with more than typical vector-length
1449 // elements (generally if more than 4) where each vector element is
1450 // individually present in Outs.
1451 // So a different index should be used for indexing into Outs/OutVals.
1452 // See similar issue in LowerFormalArguments.
1453 unsigned OIdx = 0;
 1454 // Declare the .param or .reg variables needed to pass values
 1455 // to the function
1456 for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
1457 EVT VT = Outs[OIdx].VT;
1458 Type *Ty = Args[i].Ty;
1459 bool IsVAArg = (i >= CLI.NumFixedArgs);
1460 bool IsByVal = Outs[OIdx].Flags.isByVal();
1461
 1461
 1462 SmallVector<EVT, 16> VTs;
 1463 SmallVector<uint64_t, 16> Offsets;
 1464
1465 assert((!IsByVal || Args[i].IndirectType) &&
1466 "byval arg must have indirect type");
1467 Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
1468 ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
1469
1470 Align ArgAlign;
1471 if (IsByVal) {
1472 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1473 // so we don't need to worry whether it's naturally aligned or not.
1474 // See TargetLowering::LowerCallTo().
1475 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1476 ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
1477 InitialAlign, DL);
1478 if (IsVAArg)
1479 VAOffset = alignTo(VAOffset, ArgAlign);
1480 } else {
1481 ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
1482 }
1483
1484 unsigned TypeSize =
1485 (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
1486 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1487
1488 bool NeedAlign; // Does argument declaration specify alignment?
1489 bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
1490 if (IsVAArg) {
1491 if (ParamCount == FirstVAArg) {
1492 SDValue DeclareParamOps[] = {
1493 Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
1494 DAG.getConstant(ParamCount, dl, MVT::i32),
1495 DAG.getConstant(1, dl, MVT::i32), InGlue};
1496 VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
1497 DeclareParamVTs, DeclareParamOps);
1498 }
1499 NeedAlign = PassAsArray;
1500 } else if (PassAsArray) {
1501 // declare .param .align <align> .b8 .param<n>[<size>];
1502 SDValue DeclareParamOps[] = {
1503 Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
1504 DAG.getConstant(ParamCount, dl, MVT::i32),
1505 DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
1506 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
1507 DeclareParamOps);
1508 NeedAlign = true;
1509 } else {
1510 // declare .param .b<size> .param<n>;
1511 if (VT.isInteger() || VT.isFloatingPoint()) {
1512 // PTX ABI requires integral types to be at least 32 bits in
1513 // size. FP16 is loaded/stored using i16, so it's handled
1514 // here as well.
1516 }
1517 SDValue DeclareScalarParamOps[] = {
1518 Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
1519 DAG.getConstant(TypeSize * 8, dl, MVT::i32),
1520 DAG.getConstant(0, dl, MVT::i32), InGlue};
1521 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
1522 DeclareScalarParamOps);
1523 NeedAlign = false;
1524 }
1525 InGlue = Chain.getValue(1);
1526
1527 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1528 // than 32-bits are sign extended or zero extended, depending on
1529 // whether they are signed or unsigned types. This case applies
1530 // only to scalar parameters and not to aggregate values.
1531 bool ExtendIntegerParam =
1532 Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
1533
1534 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
1535 SmallVector<SDValue, 6> StoreOperands;
1536 for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
1537 EVT EltVT = VTs[j];
1538 int CurOffset = Offsets[j];
1539 MaybeAlign PartAlign;
1540 if (NeedAlign)
1541 PartAlign = commonAlignment(ArgAlign, CurOffset);
1542
1543 SDValue StVal = OutVals[OIdx];
1544
1545 MVT PromotedVT;
1546 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
1547 EltVT = EVT(PromotedVT);
1548 }
1549 if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
1551 Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1552 StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
1553 }
1554
1555 if (IsByVal) {
1556 auto PtrVT = getPointerTy(DL);
1557 SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
1558 DAG.getConstant(CurOffset, dl, PtrVT));
1559 StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
1560 PartAlign);
1561 } else if (ExtendIntegerParam) {
1562 assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
1563 // zext/sext to i32
1564 StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
1566 dl, MVT::i32, StVal);
1567 }
1568
1569 if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
1570 // Use 16-bit registers for small stores as it's the
1571 // smallest general purpose register size supported by NVPTX.
1572 StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
1573 }
1574
1575 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1576 // scalar store. In such cases, fall back to byte stores.
1577 if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
1578 PartAlign.value() <
1579 DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
 1580 assert(StoreOperands.empty() && "Unfinished preceding store.");
1582 DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
1583 StVal, InGlue, ParamCount, dl);
1584
1585 // LowerUnalignedStoreParam took care of inserting the necessary nodes
1586 // into the SDAG, so just move on to the next element.
1587 if (!IsByVal)
1588 ++OIdx;
1589 continue;
1590 }
1591
1592 // New store.
1593 if (VectorInfo[j] & PVF_FIRST) {
1594 assert(StoreOperands.empty() && "Unfinished preceding store.");
1595 StoreOperands.push_back(Chain);
1596 StoreOperands.push_back(
1597 DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
1598
1599 StoreOperands.push_back(DAG.getConstant(
1600 IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
1601 dl, MVT::i32));
1602 }
1603
1604 // Record the value to store.
1605 StoreOperands.push_back(StVal);
1606
1607 if (VectorInfo[j] & PVF_LAST) {
1608 unsigned NumElts = StoreOperands.size() - 3;
1610 switch (NumElts) {
1611 case 1:
1613 break;
1614 case 2:
1616 break;
1617 case 4:
1619 break;
1620 default:
1621 llvm_unreachable("Invalid vector info.");
1622 }
1623
1624 StoreOperands.push_back(InGlue);
1625
1626 // Adjust type of the store op if we've extended the scalar
1627 // return value.
1628 EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1629
1630 Chain = DAG.getMemIntrinsicNode(
1631 Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
1632 TheStoreType, MachinePointerInfo(), PartAlign,
1634 InGlue = Chain.getValue(1);
1635
1636 // Cleanup.
1637 StoreOperands.clear();
1638
1639 // TODO: We may need to support vector types that can be passed
1640 // as scalars in variadic arguments.
1641 if (!IsByVal && IsVAArg) {
1642 assert(NumElts == 1 &&
1643 "Vectorization is expected to be disabled for variadics.");
1644 VAOffset += DL.getTypeAllocSize(
1645 TheStoreType.getTypeForEVT(*DAG.getContext()));
1646 }
1647 }
1648 if (!IsByVal)
1649 ++OIdx;
1650 }
1651 assert(StoreOperands.empty() && "Unfinished parameter store.");
1652 if (!IsByVal && VTs.size() > 0)
1653 --OIdx;
1654 ++ParamCount;
1655 if (IsByVal && IsVAArg)
1656 VAOffset += TypeSize;
1657 }
1658
1659 GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
1660 MaybeAlign retAlignment = std::nullopt;
1661
1662 // Handle Result
1663 if (Ins.size() > 0) {
1664 SmallVector<EVT, 16> resvtparts;
1665 ComputeValueVTs(*this, DL, RetTy, resvtparts);
1666
1667 // Declare
1668 // .param .align N .b8 retval0[<size-in-bytes>], or
1669 // .param .b<size-in-bits> retval0
1670 unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
1671 if (!IsTypePassedAsArray(RetTy)) {
1672 resultsz = promoteScalarArgumentSize(resultsz);
1673 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1674 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
1675 DAG.getConstant(resultsz, dl, MVT::i32),
1676 DAG.getConstant(0, dl, MVT::i32), InGlue };
1677 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
1678 DeclareRetOps);
1679 InGlue = Chain.getValue(1);
1680 } else {
1681 retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
1682 assert(retAlignment && "retAlignment is guaranteed to be set");
1683 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1684 SDValue DeclareRetOps[] = {
1685 Chain, DAG.getConstant(retAlignment->value(), dl, MVT::i32),
1686 DAG.getConstant(resultsz / 8, dl, MVT::i32),
1687 DAG.getConstant(0, dl, MVT::i32), InGlue};
1688 Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
1689 DeclareRetOps);
1690 InGlue = Chain.getValue(1);
1691 }
1692 }
1693
1694 bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1695 // Set the size of the vararg param byte array if the callee is a variadic
1696 // function and the variadic part is not empty.
1697 if (HasVAArgs) {
1698 SDValue DeclareParamOps[] = {
1699 VADeclareParam.getOperand(0), VADeclareParam.getOperand(1),
1700 VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32),
1701 VADeclareParam.getOperand(4)};
1702 DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
1703 VADeclareParam->getVTList(), DeclareParamOps);
1704 }
1705
1706 // If the type of the callsite does not match that of the function, convert
1707 // the callsite to an indirect call.
1708 bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1709
1710 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1711 // between them we must rely on the call site value which is valid for
1712 // indirect calls but is always null for libcalls.
1713 bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1714
1715 if (isa<ExternalSymbolSDNode>(Callee)) {
1716 Function* CalleeFunc = nullptr;
1717
1718 // Try to find the callee in the current module.
1719 Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc);
1720 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1721
1722 // Set the "libcall callee" attribute to indicate that the function
1723 // must always have a declaration.
1724 CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
1725 }
1726
1727 if (isIndirectCall) {
 1728 // This is the indirect function call case: PTX requires a prototype of the
 1729 // form
 1730 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
 1731 // to be emitted, and the label has to be used as the last arg of the call
 1732 // instruction.
1733 // The prototype is embedded in a string and put as the operand for a
1734 // CallPrototype SDNode which will print out to the value of the string.
1735 SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1736 std::string Proto = getPrototype(
1737 DL, RetTy, Args, Outs, retAlignment,
1738 HasVAArgs
1739 ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
1740 CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
1741 : std::nullopt,
1742 *CB, UniqueCallSite);
1743 const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
1744 SDValue ProtoOps[] = {
1745 Chain,
1746 DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
1747 InGlue,
1748 };
1749 Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps);
1750 InGlue = Chain.getValue(1);
1751 }
1752 // Op to just print "call"
1753 SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1754 SDValue PrintCallOps[] = {
1755 Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InGlue
1756 };
1757 // We model convergent calls as separate opcodes.
1759 if (CLI.IsConvergent)
1762 Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
1763 InGlue = Chain.getValue(1);
1764
1765 if (ConvertToIndirectCall) {
1766 // Copy the function ptr to a ptx register and use the register to call the
1767 // function.
1768 EVT DestVT = Callee.getValueType();
1770 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
1771 unsigned DestReg =
1772 RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
1773 auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
1774 Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
1775 }
1776
1777 // Ops to print out the function name
1778 SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1779 SDValue CallVoidOps[] = { Chain, Callee, InGlue };
1780 Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps);
1781 InGlue = Chain.getValue(1);
1782
1783 // Ops to print out the param list
1784 SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1785 SDValue CallArgBeginOps[] = { Chain, InGlue };
1786 Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
1787 CallArgBeginOps);
1788 InGlue = Chain.getValue(1);
1789
1790 for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e;
1791 ++i) {
1792 unsigned opcode;
1793 if (i == (e - 1))
1794 opcode = NVPTXISD::LastCallArg;
1795 else
1796 opcode = NVPTXISD::CallArg;
1797 SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1798 SDValue CallArgOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
1799 DAG.getConstant(i, dl, MVT::i32), InGlue };
1800 Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps);
1801 InGlue = Chain.getValue(1);
1802 }
1803 SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1804 SDValue CallArgEndOps[] = { Chain,
1805 DAG.getConstant(isIndirectCall ? 0 : 1, dl, MVT::i32),
1806 InGlue };
1807 Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps);
1808 InGlue = Chain.getValue(1);
1809
1810 if (isIndirectCall) {
1811 SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1812 SDValue PrototypeOps[] = {
1813 Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InGlue};
1814 Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
1815 InGlue = Chain.getValue(1);
1816 }
1817
1818 SmallVector<SDValue, 16> ProxyRegOps;
1819 SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
1820 // An item of the vector is filled if the element does not need a ProxyReg
1821 // operation on it and should be added to InVals as is. ProxyRegOps and
 1822 // ProxyRegTruncates contain empty/none items at the same index.
 1823 SmallVector<SDValue, 16> RetElts;
 1824 // Temporary ProxyReg operations are inserted in
 1825 // `LowerUnalignedLoadRetParam()` to use the values of `LoadParam`s; they
 1826 // are replaced later, when `CALLSEQ_END` is added.
1827 SmallVector<SDValue, 16> TempProxyRegOps;
1828
1829 // Generate loads from param memory/moves from registers for result
1830 if (Ins.size() > 0) {
 1831 SmallVector<EVT, 16> VTs;
 1832 SmallVector<uint64_t, 16> Offsets;
 1833 ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
1834 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1835
1836 Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1837 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
1838
1839 SmallVector<EVT, 6> LoadVTs;
1840 int VecIdx = -1; // Index of the first element of the vector.
1841
1842 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1843 // 32-bits are sign extended or zero extended, depending on whether
1844 // they are signed or unsigned types.
1845 bool ExtendIntegerRetVal =
1846 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
1847
1848 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
1849 bool needTruncate = false;
1850 EVT TheLoadType = VTs[i];
1851 EVT EltType = Ins[i].VT;
1852 Align EltAlign = commonAlignment(RetAlign, Offsets[i]);
1853 MVT PromotedVT;
1854
1855 if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) {
1856 TheLoadType = EVT(PromotedVT);
1857 EltType = EVT(PromotedVT);
1858 needTruncate = true;
1859 }
1860
1861 if (ExtendIntegerRetVal) {
1862 TheLoadType = MVT::i32;
1863 EltType = MVT::i32;
1864 needTruncate = true;
1865 } else if (TheLoadType.getSizeInBits() < 16) {
1866 if (VTs[i].isInteger())
1867 needTruncate = true;
1868 EltType = MVT::i16;
1869 }
1870
1871 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1872 // scalar load. In such cases, fall back to byte loads.
1873 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
1874 EltAlign < DL.getABITypeAlign(
1875 TheLoadType.getTypeForEVT(*DAG.getContext()))) {
1876 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
1877 SDValue Ret = LowerUnalignedLoadRetParam(
1878 DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
1879 ProxyRegOps.push_back(SDValue());
1880 ProxyRegTruncates.push_back(std::optional<MVT>());
1881 RetElts.resize(i);
1882 RetElts.push_back(Ret);
1883
1884 continue;
1885 }
1886
1887 // Record index of the very first element of the vector.
1888 if (VectorInfo[i] & PVF_FIRST) {
1889 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
1890 VecIdx = i;
1891 }
1892
1893 LoadVTs.push_back(EltType);
1894
1895 if (VectorInfo[i] & PVF_LAST) {
1896 unsigned NumElts = LoadVTs.size();
1897 LoadVTs.push_back(MVT::Other);
1898 LoadVTs.push_back(MVT::Glue);
1899 NVPTXISD::NodeType Op;
1900 switch (NumElts) {
1901 case 1:
1902 Op = NVPTXISD::LoadParam;
1903 break;
1904 case 2:
1905 Op = NVPTXISD::LoadParamV2;
1906 break;
1907 case 4:
1908 Op = NVPTXISD::LoadParamV4;
1909 break;
1910 default:
1911 llvm_unreachable("Invalid vector info.");
1912 }
1913
1914 SDValue LoadOperands[] = {
1915 Chain, DAG.getConstant(1, dl, MVT::i32),
1916 DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InGlue};
1917 SDValue RetVal = DAG.getMemIntrinsicNode(
1918 Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
1919 MachinePointerInfo(), EltAlign,
1920 MachineMemOperand::MOLoad);
1921
1922 for (unsigned j = 0; j < NumElts; ++j) {
1923 ProxyRegOps.push_back(RetVal.getValue(j));
1924
1925 if (needTruncate)
1926 ProxyRegTruncates.push_back(std::optional<MVT>(Ins[VecIdx + j].VT));
1927 else
1928 ProxyRegTruncates.push_back(std::optional<MVT>());
1929 }
1930
1931 Chain = RetVal.getValue(NumElts);
1932 InGlue = RetVal.getValue(NumElts + 1);
1933
1934 // Cleanup
1935 VecIdx = -1;
1936 LoadVTs.clear();
1937 }
1938 }
1939 }
1940
1941 Chain =
1942 DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl);
1943 InGlue = Chain.getValue(1);
1944
1945 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1946 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1947 // dangling.
1948 for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
1949 if (i < RetElts.size() && RetElts[i]) {
1950 InVals.push_back(RetElts[i]);
1951 continue;
1952 }
1953
1954 SDValue Ret = DAG.getNode(
1955 NVPTXISD::ProxyReg, dl,
1956 DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
1957 { Chain, ProxyRegOps[i], InGlue }
1958 );
1959
1960 Chain = Ret.getValue(1);
1961 InGlue = Ret.getValue(2);
1962
1963 if (ProxyRegTruncates[i]) {
1964 Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[i], Ret);
1965 }
1966
1967 InVals.push_back(Ret);
1968 }
1969
1970 for (SDValue &T : TempProxyRegOps) {
1971 SDValue Repl = DAG.getNode(
1972 NVPTXISD::ProxyReg, dl,
1973 DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
1974 {Chain, T.getOperand(0), InGlue});
1975 DAG.ReplaceAllUsesWith(T, Repl);
1976 DAG.RemoveDeadNode(T.getNode());
1977
1978 Chain = Repl.getValue(1);
1979 InGlue = Repl.getValue(2);
1980 }
1981
1982 // set isTailCall to false for now, until we figure out how to express
1983 // tail call optimization in PTX
1984 isTailCall = false;
1985 return Chain;
1986}
1987
1988SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
1989 SelectionDAG &DAG) const {
1990
1991 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1992 const Function &Fn = DAG.getMachineFunction().getFunction();
1993
1994 DiagnosticInfoUnsupported NoDynamicAlloca(
1995 Fn,
1996 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1997 "requires target sm_52.",
1998 SDLoc(Op).getDebugLoc());
1999 DAG.getContext()->diagnose(NoDynamicAlloca);
2000 auto Ops = {DAG.getConstant(0, SDLoc(), Op.getValueType()),
2001 Op.getOperand(0)};
2002 return DAG.getMergeValues(Ops, SDLoc());
2003 }
2004
2005 SDValue Chain = Op.getOperand(0);
2006 SDValue Size = Op.getOperand(1);
2007 uint64_t Align = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue();
2008 SDLoc DL(Op.getNode());
2009
2010 // The size for ptx alloca instruction is 64-bit for m64 and 32-bit for m32.
2011 MVT ValueSizeTy = nvTM->is64Bit() ? MVT::i64 : MVT::i32;
2012
2013 SDValue AllocOps[] = {Chain, DAG.getZExtOrTrunc(Size, DL, ValueSizeTy),
2014 DAG.getTargetConstant(Align, DL, MVT::i32)};
2015 EVT RetTypes[] = {ValueSizeTy, MVT::Other};
2016 return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps);
2017}
2018
2019SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
2020 SelectionDAG &DAG) const {
2021 SDLoc DL(Op.getNode());
2022 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2023 const Function &Fn = DAG.getMachineFunction().getFunction();
2024
2025 DiagnosticInfoUnsupported NoStackRestore(
2026 Fn,
2027 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
2028 ">= sm_52.",
2029 DL.getDebugLoc());
2030 DAG.getContext()->diagnose(NoStackRestore);
2031 return Op.getOperand(0);
2032 }
2033
2034 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2035 SDValue Chain = Op.getOperand(0);
2036 SDValue Ptr = Op.getOperand(1);
2037 SDValue ASC = DAG.getAddrSpaceCast(
2038 DL, LocalVT, Ptr, ADDRESS_SPACE_GENERIC, ADDRESS_SPACE_LOCAL);
2039 return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC});
2040}
2041
2042SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
2043 SelectionDAG &DAG) const {
2044 SDLoc DL(Op.getNode());
2045 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2046 const Function &Fn = DAG.getMachineFunction().getFunction();
2047
2048 DiagnosticInfoUnsupported NoStackSave(
2049 Fn,
2050 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
2051 "sm_52.",
2052 DL.getDebugLoc());
2053 DAG.getContext()->diagnose(NoStackSave);
2054 auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)};
2055 return DAG.getMergeValues(Ops, DL);
2056 }
2057
2058 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2059 SDValue Chain = Op.getOperand(0);
2060 SDValue SS =
2061 DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain);
2062 SDValue ASC = DAG.getAddrSpaceCast(
2063 DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC);
2064 return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL);
2065}
2066
2067// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
2068// (see LegalizeDAG.cpp). This is slow and uses local memory.
2069// Instead, we use extract/insert/build vector, just as LegalizeOp() did in LLVM 2.5.
2070SDValue
2071NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2072 SDNode *Node = Op.getNode();
2073 SDLoc dl(Node);
2074 SmallVector<SDValue, 8> Ops;
2075 unsigned NumOperands = Node->getNumOperands();
2076 for (unsigned i = 0; i < NumOperands; ++i) {
2077 SDValue SubOp = Node->getOperand(i);
2078 EVT VVT = SubOp.getNode()->getValueType(0);
2079 EVT EltVT = VVT.getVectorElementType();
2080 unsigned NumSubElem = VVT.getVectorNumElements();
2081 for (unsigned j = 0; j < NumSubElem; ++j) {
2082 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
2083 DAG.getIntPtrConstant(j, dl)));
2084 }
2085 }
2086 return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
2087}
2088
2089SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2090 // Handle bitcasting from v2i8 without hitting the default promotion
2091 // strategy which goes through stack memory.
2092 EVT FromVT = Op->getOperand(0)->getValueType(0);
2093 if (FromVT != MVT::v2i8) {
2094 return Op;
2095 }
2096
2097 // Pack vector elements into i16 and bitcast to final type
2098 SDLoc DL(Op);
2099 SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2100 Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
2101 SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2102 Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
2103 SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
2104 SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
2105 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
2106 SDValue AsInt = DAG.getNode(
2107 ISD::OR, DL, MVT::i16,
2108 {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
2109 EVT ToVT = Op->getValueType(0);
2110 return MaybeBitcast(DAG, DL, ToVT, AsInt);
2111}
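// Editor's note: the sketch below is illustrative only and not part of the
// original file; the helper name is hypothetical and it assumes <cstdint> is
// available. It models the packing done by LowerBITCAST above: element 0 of
// the v2i8 ends up in the low byte of the i16 and element 1 in the high byte.
static inline uint16_t modelPackV2I8(uint8_t Elt0, uint8_t Elt1) {
  return static_cast<uint16_t>(Elt0 | (static_cast<unsigned>(Elt1) << 8));
}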
2112
2113// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
2114// would get lowered as two constant loads and vector-packing move.
2115// Instead we want just a constant move:
2116// mov.b32 %r2, 0x40003C00
2117SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2118 SelectionDAG &DAG) const {
2119 EVT VT = Op->getValueType(0);
2120 if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
2121 return Op;
2122 SDLoc DL(Op);
2123
2124 if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
2125 return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
2126 isa<ConstantFPSDNode>(Operand);
2127 })) {
2128 if (VT != MVT::v4i8)
2129 return Op;
2130 // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
2131 // to optimize calculation of constant parts.
2132 auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2133 uint64_t SelectionValue) -> SDValue {
2134 SDValue L = Left;
2135 SDValue R = Right;
2136 if (Cast) {
2137 L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
2138 R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
2139 }
2140 return DAG.getNode(
2141 NVPTXISD::PRMT, DL, MVT::v4i8,
2142 {L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
2143 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2144 };
2145 auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
2146 auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
2147 auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2148 return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
2149 }
2150
2151 // Get the value of the Nth operand as an APInt(32). Undef values are treated as 0.
2152 auto GetOperand = [](SDValue Op, int N) -> APInt {
2153 const SDValue &Operand = Op->getOperand(N);
2154 EVT VT = Op->getValueType(0);
2155 if (Operand->isUndef())
2156 return APInt(32, 0);
2157 APInt Value;
2158 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2159 Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
2160 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2161 Value = Operand->getAsAPIntVal();
2162 else
2163 llvm_unreachable("Unsupported type");
2164 // i8 values are carried around as i16, so we need to zero out the upper bits
2165 // so they do not get in the way of combining individual byte values.
2166 if (VT == MVT::v4i8)
2167 Value = Value.trunc(8);
2168 return Value.zext(32);
2169 };
2170 APInt Value;
2171 if (Isv2x16VT(VT)) {
2172 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
2173 } else if (VT == MVT::v4i8) {
2174 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
2175 GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
2176 } else {
2177 llvm_unreachable("Unsupported type");
2178 }
2179 SDValue Const = DAG.getConstant(Value, DL, MVT::i32);
2180 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Const);
2181}
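// Editor's note: the sketch below is illustrative only and not part of the
// original file; the helper name is hypothetical and it assumes <cstdint> is
// available. It mirrors the constant packing above: the low element occupies
// bits [15:0] of the .b32 immediate. For a v2f16 of {1.0h, 2.0h} the f16 bit
// patterns are 0x3C00 and 0x4000, giving 0x40003C00, the immediate used in
// the mov.b32 example in the comment before LowerBUILD_VECTOR.
static inline uint32_t modelPackV2x16(uint16_t LoBits, uint16_t HiBits) {
  return static_cast<uint32_t>(LoBits) | (static_cast<uint32_t>(HiBits) << 16);
}
// modelPackV2x16(/*1.0h=*/0x3C00, /*2.0h=*/0x4000) == 0x40003C00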
2182
2183SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2184 SelectionDAG &DAG) const {
2185 SDValue Index = Op->getOperand(1);
2186 SDValue Vector = Op->getOperand(0);
2187 SDLoc DL(Op);
2188 EVT VectorVT = Vector.getValueType();
2189
2190 if (VectorVT == MVT::v4i8) {
2191 SDValue BFE =
2192 DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
2193 {Vector,
2194 DAG.getNode(ISD::MUL, DL, MVT::i32,
2195 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2196 DAG.getConstant(8, DL, MVT::i32)),
2197 DAG.getConstant(8, DL, MVT::i32)});
2198 return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
2199 }
2200
2201 // Constant index will be matched by tablegen.
2202 if (isa<ConstantSDNode>(Index.getNode()))
2203 return Op;
2204
2205 // Extract individual elements and select one of them.
2206 assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
2207 EVT EltVT = VectorVT.getVectorElementType();
2208
2209 SDLoc dl(Op.getNode());
2210 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2211 DAG.getIntPtrConstant(0, dl));
2212 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2213 DAG.getIntPtrConstant(1, dl));
2214 return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
2216}
2217
2218SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2219 SelectionDAG &DAG) const {
2220 SDValue Vector = Op->getOperand(0);
2221 EVT VectorVT = Vector.getValueType();
2222
2223 if (VectorVT != MVT::v4i8)
2224 return Op;
2225 SDLoc DL(Op);
2226 SDValue Value = Op->getOperand(1);
2227 if (Value->isUndef())
2228 return Vector;
2229
2230 SDValue Index = Op->getOperand(2);
2231
2232 SDValue BFI =
2233 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2234 {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
2235 DAG.getNode(ISD::MUL, DL, MVT::i32,
2236 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2237 DAG.getConstant(8, DL, MVT::i32)),
2238 DAG.getConstant(8, DL, MVT::i32)});
2239 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
2240}
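// Editor's note: the sketch below is illustrative only and not part of the
// original file; the helper name is hypothetical and it assumes <cstdint> is
// available. It models what the bfi.b32 emitted above computes: the byte at
// position Idx (assumed 0..3) of the packed v4i8 word is replaced by the low
// 8 bits of Val, i.e. a bit-field insert at offset Idx * 8 with width 8.
static inline uint32_t modelInsertV4I8(uint32_t Vec, uint8_t Val, unsigned Idx) {
  const unsigned Shift = Idx * 8;
  const uint32_t Mask = 0xFFu << Shift;
  return (Vec & ~Mask) | (static_cast<uint32_t>(Val) << Shift);
}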
2241
2242SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2243 SelectionDAG &DAG) const {
2244 SDValue V1 = Op.getOperand(0);
2245 EVT VectorVT = V1.getValueType();
2246 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2247 return Op;
2248
2249 // Lower shuffle to PRMT instruction.
2250 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
2251 SDValue V2 = Op.getOperand(1);
2252 uint32_t Selector = 0;
2253 for (auto I : llvm::enumerate(SVN->getMask())) {
2254 if (I.value() != -1) // -1 is a placeholder for undef.
2255 Selector |= (I.value() << (I.index() * 4));
2256 }
2257
2258 SDLoc DL(Op);
2259 return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
2260 DAG.getConstant(Selector, DL, MVT::i32),
2261 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
2262}
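// Editor's note: the sketch below is illustrative only and not part of the
// original file; the helper name is hypothetical and it assumes <cstdint> is
// available. It mirrors the selector construction above: shuffle-mask element
// I becomes nibble I of the PRMT control word, and undef (-1) entries leave
// their nibble as 0. A byte-reversal mask {3, 2, 1, 0} yields 0x0123.
static inline uint32_t modelPrmtSelector(const int (&Mask)[4]) {
  uint32_t Selector = 0;
  for (unsigned I = 0; I < 4; ++I)
    if (Mask[I] != -1)
      Selector |= static_cast<uint32_t>(Mask[I]) << (I * 4);
  return Selector;
}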
2263/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2264/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
2265/// amount, or
2266/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
2267/// amount.
2268SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2269 SelectionDAG &DAG) const {
2270 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2271 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2272
2273 EVT VT = Op.getValueType();
2274 unsigned VTBits = VT.getSizeInBits();
2275 SDLoc dl(Op);
2276 SDValue ShOpLo = Op.getOperand(0);
2277 SDValue ShOpHi = Op.getOperand(1);
2278 SDValue ShAmt = Op.getOperand(2);
2279 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2280
2281 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2282 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2283 // {dHi, dLo} = {aHi, aLo} >> Amt
2284 // dHi = aHi >> Amt
2285 // dLo = shf.r.clamp aLo, aHi, Amt
2286
2287 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2288 SDValue Lo =
2289 DAG.getNode(NVPTXISD::FSHR_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2290
2291 SDValue Ops[2] = { Lo, Hi };
2292 return DAG.getMergeValues(Ops, dl);
2293 }
2294 else {
2295 // {dHi, dLo} = {aHi, aLo} >> Amt
2296 // - if (Amt>=size) then
2297 // dLo = aHi >> (Amt-size)
2298 // dHi = aHi >> Amt (this is either all 0 or all 1)
2299 // else
2300 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2301 // dHi = aHi >> Amt
2302
2303 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2304 DAG.getConstant(VTBits, dl, MVT::i32),
2305 ShAmt);
2306 SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
2307 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2308 DAG.getConstant(VTBits, dl, MVT::i32));
2309 SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
2310 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2311 SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
2312
2313 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2314 DAG.getConstant(VTBits, dl, MVT::i32),
2315 ISD::SETGE);
2316 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2317 SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2318
2319 SDValue Ops[2] = { Lo, Hi };
2320 return DAG.getMergeValues(Ops, dl);
2321 }
2322}
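// Editor's note: the sketch below is illustrative only and not part of the
// original file; the helper name is hypothetical and it assumes <cstdint> and
// <utility> are available. It models the generic (non funnel-shift) SRL_PARTS
// sequence above for 32-bit halves and logical shifts; the SRA_PARTS variant
// differs only in that dHi sign-fills. The Amt == 0 early-out is added purely
// to avoid undefined behaviour of a 32-bit shift in C++, and Amt is assumed
// to be less than 64.
static inline std::pair<uint32_t, uint32_t> // {Lo, Hi}
modelSrlParts(uint32_t Lo, uint32_t Hi, unsigned Amt) {
  if (Amt == 0)
    return {Lo, Hi};
  if (Amt >= 32)                            // dLo = aHi >> (Amt - size), dHi = 0.
    return {Hi >> (Amt - 32), 0u};
  return {(Lo >> Amt) | (Hi << (32 - Amt)), // dLo = (aLo >> Amt) | (aHi << (size - Amt))
          Hi >> Amt};                       // dHi = aHi >> Amt
}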
2323
2324/// LowerShiftLeftParts - Lower SHL_PARTS, which
2325/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
2326/// amount, or
2327/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
2328/// amount.
2329SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2330 SelectionDAG &DAG) const {
2331 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2332 assert(Op.getOpcode() == ISD::SHL_PARTS);
2333
2334 EVT VT = Op.getValueType();
2335 unsigned VTBits = VT.getSizeInBits();
2336 SDLoc dl(Op);
2337 SDValue ShOpLo = Op.getOperand(0);
2338 SDValue ShOpHi = Op.getOperand(1);
2339 SDValue ShAmt = Op.getOperand(2);
2340
2341 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2342 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2343 // {dHi, dLo} = {aHi, aLo} << Amt
2344 // dHi = shf.l.clamp aLo, aHi, Amt
2345 // dLo = aLo << Amt
2346
2347 SDValue Hi =
2348 DAG.getNode(NVPTXISD::FSHL_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2349 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2350
2351 SDValue Ops[2] = { Lo, Hi };
2352 return DAG.getMergeValues(Ops, dl);
2353 }
2354 else {
2355 // {dHi, dLo} = {aHi, aLo} << Amt
2356 // - if (Amt>=size) then
2357 // dLo = aLo << Amt (all 0)
2358 // dLo = aLo << (Amt-size)
2359 // else
2360 // dLo = aLo << Amt
2361 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2362
2363 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2364 DAG.getConstant(VTBits, dl, MVT::i32),
2365 ShAmt);
2366 SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
2367 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2368 DAG.getConstant(VTBits, dl, MVT::i32));
2369 SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
2370 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2371 SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
2372
2373 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2374 DAG.getConstant(VTBits, dl, MVT::i32),
2375 ISD::SETGE);
2376 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2377 SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2378
2379 SDValue Ops[2] = { Lo, Hi };
2380 return DAG.getMergeValues(Ops, dl);
2381 }
2382}
2383
2384/// If the types match, convert the generic copysign to the NVPTXISD version,
2385/// otherwise bail, ensuring that mismatched cases are properly expanded.
2386SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2387 SelectionDAG &DAG) const {
2388 EVT VT = Op.getValueType();
2389 SDLoc DL(Op);
2390
2391 SDValue In1 = Op.getOperand(0);
2392 SDValue In2 = Op.getOperand(1);
2393 EVT SrcVT = In2.getValueType();
2394
2395 if (!SrcVT.bitsEq(VT))
2396 return SDValue();
2397
2398 return DAG.getNode(NVPTXISD::FCOPYSIGN, DL, VT, In1, In2);
2399}
2400
2401SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2402 EVT VT = Op.getValueType();
2403
2404 if (VT == MVT::f32)
2405 return LowerFROUND32(Op, DAG);
2406
2407 if (VT == MVT::f64)
2408 return LowerFROUND64(Op, DAG);
2409
2410 llvm_unreachable("unhandled type");
2411}
2412
2413// This is the rounding method used in CUDA libdevice, expressed in C-like code:
2414// float roundf(float A)
2415// {
2416// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2417// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2418// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2419// }
2420SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2421 SelectionDAG &DAG) const {
2422 SDLoc SL(Op);
2423 SDValue A = Op.getOperand(0);
2424 EVT VT = Op.getValueType();
2425
2426 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2427
2428 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2429 SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
2430 const unsigned SignBitMask = 0x80000000;
2431 SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
2432 DAG.getConstant(SignBitMask, SL, MVT::i32));
2433 const unsigned PointFiveInBits = 0x3F000000;
2434 SDValue PointFiveWithSignRaw =
2435 DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
2436 DAG.getConstant(PointFiveInBits, SL, MVT::i32));
2437 SDValue PointFiveWithSign =
2438 DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
2439 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
2440 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2441
2442 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2443 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2444 SDValue IsLarge =
2445 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
2446 ISD::SETOGT);
2447 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2448
2449 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2450 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2451 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2452 SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
2453 return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
2454}
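// Editor's note: the sketch below is illustrative only and not part of the
// original file; the helper name is hypothetical and it assumes <cmath> is
// available. It is a host-side model of the libdevice-style rounding lowered
// above, with FTRUNC modelled by std::trunc. The sign-aware 0.5 adjustment
// rounds halfway cases away from zero, e.g. modelRoundf(2.5f) == 3.0f and
// modelRoundf(-2.5f) == -3.0f; the final |A| < 0.5 check keeps such values at
// +/-0 even if A + 0.5f happens to round up to 1.0 in float arithmetic.
static inline float modelRoundf(float A) {
  float RoundedA = std::trunc(A > 0.0f ? A + 0.5f : A - 0.5f);
  RoundedA = std::fabs(A) > 0x1.0p23f ? A : RoundedA; // Already integral.
  return std::fabs(A) < 0.5f ? std::trunc(A) : RoundedA;
}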
2455
2456// The implementation of round(double) is similar to that of round(float) in
2457// that they both separate the value range into three regions and use a method
2458// specific to the region to round the values. However, round(double) first
2459// calculates the round of the absolute value and then adds the sign back while
2460// round(float) directly rounds the value with sign.
2461SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2462 SelectionDAG &DAG) const {
2463 SDLoc SL(Op);
2464 SDValue A = Op.getOperand(0);
2465 EVT VT = Op.getValueType();
2466
2467 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2468
2469 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2470 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
2471 DAG.getConstantFP(0.5, SL, VT));
2472 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2473
2474 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2475 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2476 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2477 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2478 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
2479 DAG.getConstantFP(0, SL, VT),
2480 RoundedA);
2481
2482 // Add sign to rounded_A
2483 RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
2484 DAG.getNode(ISD::FTRUNC, SL, VT, A);
2485
2486 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2487 SDValue IsLarge =
2488 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
2489 ISD::SETOGT);
2490 return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2491}
2492
2493SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2494 SelectionDAG &DAG) const {
2495 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2496
2497 if (Op.getValueType() == MVT::bf16) {
2498 SDLoc Loc(Op);
2499 return DAG.getNode(
2500 ISD::FP_ROUND, Loc, MVT::bf16,
2501 DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
2502 DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true));
2503 }
2504
2505 // Everything else is considered legal.
2506 return Op;
2507}
2508
2509SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2510 SelectionDAG &DAG) const {
2511 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2512
2513 if (Op.getOperand(0).getValueType() == MVT::bf16) {
2514 SDLoc Loc(Op);
2515 return DAG.getNode(
2516 Op.getOpcode(), Loc, Op.getValueType(),
2517 DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
2518 }
2519
2520 // Everything else is considered legal.
2521 return Op;
2522}
2523
2524SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2525 SelectionDAG &DAG) const {
2526 EVT NarrowVT = Op.getValueType();
2527 SDValue Wide = Op.getOperand(0);
2528 EVT WideVT = Wide.getValueType();
2529 if (NarrowVT.getScalarType() == MVT::bf16) {
2530 const TargetLowering *TLI = STI.getTargetLowering();
2531 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2532 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2533 }
2534 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2535 // This combination was the first to support f32 -> bf16.
2536 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2537 if (WideVT.getScalarType() == MVT::f32) {
2538 return Op;
2539 }
2540 if (WideVT.getScalarType() == MVT::f64) {
2541 SDLoc Loc(Op);
2542 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2543 // the hardware f32 -> bf16 instruction.
2545 WideVT.isVector() ? WideVT.changeVectorElementType(MVT::f32)
2546 : MVT::f32,
2547 Wide, Loc, DAG);
2548 return DAG.getFPExtendOrRound(rod, Loc, NarrowVT);
2549 }
2550 }
2551 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2552 }
2553 }
2554
2555 // Everything else is considered legal.
2556 return Op;
2557}
2558
2559SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2560 SelectionDAG &DAG) const {
2561 SDValue Narrow = Op.getOperand(0);
2562 EVT NarrowVT = Narrow.getValueType();
2563 EVT WideVT = Op.getValueType();
2564 if (NarrowVT.getScalarType() == MVT::bf16) {
2565 if (WideVT.getScalarType() == MVT::f32 &&
2566 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2567 SDLoc Loc(Op);
2568 return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow);
2569 }
2570 if (WideVT.getScalarType() == MVT::f64 &&
2571 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2572 EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f32)
2573 : MVT::f32;
2574 SDLoc Loc(Op);
2575 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2576 Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow);
2577 } else {
2578 Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow);
2579 }
2580 return DAG.getNode(ISD::FP_EXTEND, Loc, WideVT, Op);
2581 }
2582 }
2583
2584 // Everything else is considered legal.
2585 return Op;
2586}
2587
2588static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2589 SDLoc DL(Op);
2590 if (Op.getValueType() != MVT::v2i16)
2591 return Op;
2592 EVT EltVT = Op.getValueType().getVectorElementType();
2593 SmallVector<SDValue> VecElements;
2594 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2595 SmallVector<SDValue> ScalarArgs;
2596 llvm::transform(Op->ops(), std::back_inserter(ScalarArgs),
2597 [&](const SDUse &O) {
2598 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
2599 O.get(), DAG.getIntPtrConstant(I, DL));
2600 });
2601 VecElements.push_back(DAG.getNode(Op.getOpcode(), DL, EltVT, ScalarArgs));
2602 }
2603 SDValue V =
2604 DAG.getNode(ISD::BUILD_VECTOR, DL, Op.getValueType(), VecElements);
2605 return V;
2606}
2607
2608SDValue
2609NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2610 switch (Op.getOpcode()) {
2611 case ISD::RETURNADDR:
2612 return SDValue();
2613 case ISD::FRAMEADDR:
2614 return SDValue();
2615 case ISD::GlobalAddress:
2616 return LowerGlobalAddress(Op, DAG);
2618 return Op;
2619 case ISD::BUILD_VECTOR:
2620 return LowerBUILD_VECTOR(Op, DAG);
2621 case ISD::BITCAST:
2622 return LowerBITCAST(Op, DAG);
2624 return Op;
2625 case ISD::EXTRACT_VECTOR_ELT:
2626 return LowerEXTRACT_VECTOR_ELT(Op, DAG);
2627 case ISD::INSERT_VECTOR_ELT:
2628 return LowerINSERT_VECTOR_ELT(Op, DAG);
2629 case ISD::VECTOR_SHUFFLE:
2630 return LowerVECTOR_SHUFFLE(Op, DAG);
2631 case ISD::CONCAT_VECTORS:
2632 return LowerCONCAT_VECTORS(Op, DAG);
2633 case ISD::STORE:
2634 return LowerSTORE(Op, DAG);
2635 case ISD::LOAD:
2636 return LowerLOAD(Op, DAG);
2637 case ISD::SHL_PARTS:
2638 return LowerShiftLeftParts(Op, DAG);
2639 case ISD::SRA_PARTS:
2640 case ISD::SRL_PARTS:
2641 return LowerShiftRightParts(Op, DAG);
2642 case ISD::SELECT:
2643 return LowerSelect(Op, DAG);
2644 case ISD::FROUND:
2645 return LowerFROUND(Op, DAG);
2646 case ISD::FCOPYSIGN:
2647 return LowerFCOPYSIGN(Op, DAG);
2648 case ISD::SINT_TO_FP:
2649 case ISD::UINT_TO_FP:
2650 return LowerINT_TO_FP(Op, DAG);
2651 case ISD::FP_TO_SINT:
2652 case ISD::FP_TO_UINT:
2653 return LowerFP_TO_INT(Op, DAG);
2654 case ISD::FP_ROUND:
2655 return LowerFP_ROUND(Op, DAG);
2656 case ISD::FP_EXTEND:
2657 return LowerFP_EXTEND(Op, DAG);
2658 case ISD::BR_JT:
2659 return LowerBR_JT(Op, DAG);
2660 case ISD::VAARG:
2661 return LowerVAARG(Op, DAG);
2662 case ISD::VASTART:
2663 return LowerVASTART(Op, DAG);
2664 case ISD::ABS:
2665 case ISD::SMIN:
2666 case ISD::SMAX:
2667 case ISD::UMIN:
2668 case ISD::UMAX:
2669 case ISD::ADD:
2670 case ISD::SUB:
2671 case ISD::MUL:
2672 case ISD::SHL:
2673 case ISD::SREM:
2674 case ISD::UREM:
2675 return LowerVectorArith(Op, DAG);
2676 case ISD::DYNAMIC_STACKALLOC:
2677 return LowerDYNAMIC_STACKALLOC(Op, DAG);
2678 case ISD::STACKRESTORE:
2679 return LowerSTACKRESTORE(Op, DAG);
2680 case ISD::STACKSAVE:
2681 return LowerSTACKSAVE(Op, DAG);
2682 case ISD::CopyToReg:
2683 return LowerCopyToReg_128(Op, DAG);
2684 default:
2685 llvm_unreachable("Custom lowering not defined for operation");
2686 }
2687}
2688
2689SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
2690 SDLoc DL(Op);
2691 SDValue Chain = Op.getOperand(0);
2692 const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
2693 SDValue Index = Op.getOperand(2);
2694
2695 unsigned JId = JT->getIndex();
2697 ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
2698
2699 SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
2700
2701 // Generate BrxStart node
2702 SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
2703 Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
2704
2705 // Generate BrxItem nodes
2706 assert(!MBBs.empty());
2707 for (MachineBasicBlock *MBB : MBBs.drop_back())
2708 Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
2709 DAG.getBasicBlock(MBB), Chain.getValue(1));
2710
2711 // Generate BrxEnd nodes
2712 SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
2713 IdV, Chain.getValue(1)};
2714 SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
2715
2716 return BrxEnd;
2717}
2718
2719// This will prevent AsmPrinter from trying to print the jump tables itself.
2720unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
2721 return MachineJumpTableInfo::EK_Inline;
2722}
2723
2724// This function is almost a copy of SelectionDAG::expandVAArg().
2725// The only diff is that this one produces loads from local address space.
2726SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
2727 const TargetLowering *TLI = STI.getTargetLowering();
2728 SDLoc DL(Op);
2729
2730 SDNode *Node = Op.getNode();
2731 const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
2732 EVT VT = Node->getValueType(0);
2733 auto *Ty = VT.getTypeForEVT(*DAG.getContext());
2734 SDValue Tmp1 = Node->getOperand(0);
2735 SDValue Tmp2 = Node->getOperand(1);
2736 const MaybeAlign MA(Node->getConstantOperandVal(3));
2737
2738 SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
2739 Tmp1, Tmp2, MachinePointerInfo(V));
2740 SDValue VAList = VAListLoad;
2741
2742 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
2743 VAList = DAG.getNode(
2744 ISD::ADD, DL, VAList.getValueType(), VAList,
2745 DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
2746
2747 VAList = DAG.getNode(ISD::AND, DL, VAList.getValueType(), VAList,
2748 DAG.getSignedConstant(-(int64_t)MA->value(), DL,
2749 VAList.getValueType()));
2750 }
2751
2752 // Increment the pointer, VAList, to the next vaarg
2753 Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
2755 DL, VAList.getValueType()));
2756
2757 // Store the incremented VAList to the legalized pointer
2758 Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
2760
2761 const Value *SrcV =
2763
2764 // Load the actual argument out of the pointer VAList
2765 return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
2766}
2767
2768SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
2769 const TargetLowering *TLI = STI.getTargetLowering();
2770 SDLoc DL(Op);
2771 EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
2772
2773 // Store the address of unsized array <function>_vararg[] in the ap object.
2774 SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
2775 SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
2776
2777 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
2778 return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
2779 MachinePointerInfo(SV));
2780}
2781
2782SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
2783 SDValue Op0 = Op->getOperand(0);
2784 SDValue Op1 = Op->getOperand(1);
2785 SDValue Op2 = Op->getOperand(2);
2786 SDLoc DL(Op.getNode());
2787
2788 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
2789
2790 Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
2791 Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
2792 SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2);
2793 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
2794
2795 return Trunc;
2796}
2797
2798SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
2799 if (Op.getValueType() == MVT::i1)
2800 return LowerLOADi1(Op, DAG);
2801
2802 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
2803 // unaligned loads and have to handle it here.
2804 EVT VT = Op.getValueType();
2805 if (Isv2x16VT(VT) || VT == MVT::v4i8) {
2806 LoadSDNode *Load = cast<LoadSDNode>(Op);
2807 EVT MemVT = Load->getMemoryVT();
2809 MemVT, *Load->getMemOperand())) {
2810 SDValue Ops[2];
2811 std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
2812 return DAG.getMergeValues(Ops, SDLoc(Op));
2813 }
2814 }
2815
2816 return SDValue();
2817}
2818
2819// v = ld i1* addr
2820// =>
2821// v1 = ld i8* addr (-> i16)
2822// v = trunc i16 to i1
2823SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
2824 SDNode *Node = Op.getNode();
2825 LoadSDNode *LD = cast<LoadSDNode>(Node);
2826 SDLoc dl(Node);
2827 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
2828 assert(Node->getValueType(0) == MVT::i1 &&
2829 "Custom lowering for i1 load only");
2830 SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
2831 LD->getBasePtr(), LD->getPointerInfo(),
2832 MVT::i8, LD->getAlign(),
2833 LD->getMemOperand()->getFlags());
2834 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
2835 // The legalizer (the caller) is expecting two values from the legalized
2836 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
2837 // in LegalizeDAG.cpp which also uses MergeValues.
2838 SDValue Ops[] = { result, LD->getChain() };
2839 return DAG.getMergeValues(Ops, dl);
2840}
2841
2842SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
2843 StoreSDNode *Store = cast<StoreSDNode>(Op);
2844 EVT VT = Store->getMemoryVT();
2845
2846 if (VT == MVT::i1)
2847 return LowerSTOREi1(Op, DAG);
2848
2849 // v2f16 is legal, so we can't rely on legalizer to handle unaligned
2850 // stores and have to handle it here.
2851 if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
2853 VT, *Store->getMemOperand()))
2854 return expandUnalignedStore(Store, DAG);
2855
2856 // v2f16, v2bf16 and v2i16 don't need special handling.
2857 if (Isv2x16VT(VT) || VT == MVT::v4i8)
2858 return SDValue();
2859
2860 if (VT.isVector())
2861 return LowerSTOREVector(Op, DAG);
2862
2863 return SDValue();
2864}
2865
2866SDValue
2867NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2868 SDNode *N = Op.getNode();
2869 SDValue Val = N->getOperand(1);
2870 SDLoc DL(N);
2871 EVT ValVT = Val.getValueType();
2872
2873 auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
2874 if (!NumEltsAndEltVT)
2875 return SDValue();
2876 auto [NumElts, EltVT] = NumEltsAndEltVT.value();
2877
2878 MemSDNode *MemSD = cast<MemSDNode>(N);
2879 const DataLayout &TD = DAG.getDataLayout();
2880
2881 Align Alignment = MemSD->getAlign();
2882 Align PrefAlign = TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
2883 if (Alignment < PrefAlign) {
2884 // This store is not sufficiently aligned, so bail out and let this vector
2885 // store be scalarized. Note that we may still be able to emit smaller
2886 // vector stores. For example, if we are storing a <4 x float> with an
2887 // alignment of 8, this check will fail but the legalizer will try again
2888 // with 2 x <2 x float>, which will succeed with an alignment of 8.
2889 return SDValue();
2890 }
2891
2892 // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
2893 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
2894 // stored type to i16 and propagate the "real" type as the memory type.
2895 bool NeedExt = false;
2896 if (EltVT.getSizeInBits() < 16)
2897 NeedExt = true;
2898
2899 unsigned Opcode = 0;
2900 switch (NumElts) {
2901 default:
2902 return SDValue();
2903 case 2:
2904 Opcode = NVPTXISD::StoreV2;
2905 break;
2906 case 4:
2907 Opcode = NVPTXISD::StoreV4;
2908 break;
2909 }
2910
2912
2913 // First is the chain
2914 Ops.push_back(N->getOperand(0));
2915
2916 // Then the split values
2917 assert(NumElts <= ValVT.getVectorNumElements() &&
2918 "NumElts should not increase, only decrease or stay the same.");
2919 if (NumElts < ValVT.getVectorNumElements()) {
2920 // If the number of elements has decreased, getVectorLoweringShape has
2921 // upsized the element types
2922 assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
2923 EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
2924 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
2925 // stored as b32s
2926 unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
2927 for (unsigned i = 0; i < NumElts; ++i) {
2928 SmallVector<SDValue, 4> SubVectorElts;
2929 DAG.ExtractVectorElements(Val, SubVectorElts, i * NumEltsPerSubVector,
2930 NumEltsPerSubVector);
2931 SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
2932 Ops.push_back(SubVector);
2933 }
2934 } else {
2935 for (unsigned i = 0; i < NumElts; ++i) {
2936 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2937 DAG.getIntPtrConstant(i, DL));
2938 if (NeedExt)
2939 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
2940 Ops.push_back(ExtVal);
2941 }
2942 }
2943
2944 // Then any remaining arguments
2945 Ops.append(N->op_begin() + 2, N->op_end());
2946
2947 SDValue NewSt =
2948 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
2949 MemSD->getMemoryVT(), MemSD->getMemOperand());
2950
2951 // return DCI.CombineTo(N, NewSt, true);
2952 return NewSt;
2953}
2954
2955// st i1 v, addr
2956// =>
2957// v1 = zxt v to i16
2958// st.u8 i16, addr
2959SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
2960 SDNode *Node = Op.getNode();
2961 SDLoc dl(Node);
2962 StoreSDNode *ST = cast<StoreSDNode>(Node);
2963 SDValue Tmp1 = ST->getChain();
2964 SDValue Tmp2 = ST->getBasePtr();
2965 SDValue Tmp3 = ST->getValue();
2966 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
2967 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
2968 SDValue Result =
2969 DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i8,
2970 ST->getAlign(), ST->getMemOperand()->getFlags());
2971 return Result;
2972}
2973
2974SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
2975 SelectionDAG &DAG) const {
2976 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
2977 // operand so that it can pass the legalization.
2978
2979 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
2980 "Custom lowering for 128-bit CopyToReg only");
2981
2982 SDNode *Node = Op.getNode();
2983 SDLoc DL(Node);
2984
2985 SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
2986 SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
2987 DAG.getIntPtrConstant(0, DL));
2988 SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
2989 DAG.getIntPtrConstant(1, DL));
2990
2992 SmallVector<EVT, 3> ResultsType(Node->values());
2993
2994 NewOps[0] = Op->getOperand(0); // Chain
2995 NewOps[1] = Op->getOperand(1); // Dst Reg
2996 NewOps[2] = Lo; // Lower 64-bit
2997 NewOps[3] = Hi; // Higher 64-bit
2998 if (Op.getNumOperands() == 4)
2999 NewOps[4] = Op->getOperand(3); // Glue if exists
3000
3001 return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
3002}
3003
3004unsigned NVPTXTargetLowering::getNumRegisters(
3005 LLVMContext &Context, EVT VT,
3006 std::optional<MVT> RegisterVT = std::nullopt) const {
3007 if (VT == MVT::i128 && RegisterVT == MVT::i128)
3008 return 1;
3009 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
3010}
3011
3012bool NVPTXTargetLowering::splitValueIntoRegisterParts(
3013 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
3014 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
3015 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
3016 Parts[0] = Val;
3017 return true;
3018 }
3019 return false;
3020}
3021
3022// This creates target external symbol for a function parameter.
3023// Name of the symbol is composed from its index and the function name.
3024// Negative index corresponds to special parameter (unsized array) used for
3025// passing variable arguments.
3026SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
3027 EVT v) const {
3028 StringRef SavedStr = nvTM->getStrPool().save(
3030 return DAG.getTargetExternalSymbol(SavedStr.data(), v);
3031}
3032
3033SDValue NVPTXTargetLowering::LowerFormalArguments(
3034 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
3035 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
3036 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
3038 const DataLayout &DL = DAG.getDataLayout();
3039 auto PtrVT = getPointerTy(DAG.getDataLayout());
3040
3041 const Function *F = &MF.getFunction();
3042 const AttributeList &PAL = F->getAttributes();
3043 const TargetLowering *TLI = STI.getTargetLowering();
3044
3045 SDValue Root = DAG.getRoot();
3046 std::vector<SDValue> OutChains;
3047
3048 bool isABI = (STI.getSmVersion() >= 20);
3049 assert(isABI && "Non-ABI compilation is not supported");
3050 if (!isABI)
3051 return Chain;
3052
3053 std::vector<Type *> argTypes;
3054 std::vector<const Argument *> theArgs;
3055 for (const Argument &I : F->args()) {
3056 theArgs.push_back(&I);
3057 argTypes.push_back(I.getType());
3058 }
3059 // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
3060 // Ins.size() will be larger
3061 // * if there is an aggregate argument with multiple fields (each field
3062 // showing up separately in Ins)
3063 // * if there is a vector argument with more than typical vector-length
3064 // elements (generally if more than 4) where each vector element is
3065 // individually present in Ins.
3066 // So a different index should be used for indexing into Ins.
3067 // See similar issue in LowerCall.
3068 unsigned InsIdx = 0;
3069
3070 for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) {
3071 Type *Ty = argTypes[i];
3072
3073 if (theArgs[i]->use_empty()) {
3074 // argument is dead
3075 if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
3076 SmallVector<EVT, 16> vtparts;
3077
3078 ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
3079 if (vtparts.empty())
3080 report_fatal_error("Empty parameter types are not supported");
3081
3082 for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
3083 ++parti) {
3084 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3085 ++InsIdx;
3086 }
3087 if (vtparts.size() > 0)
3088 --InsIdx;
3089 continue;
3090 }
3091 if (Ty->isVectorTy()) {
3092 EVT ObjectVT = getValueType(DL, Ty);
3093 unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
3094 for (unsigned parti = 0; parti < NumRegs; ++parti) {
3095 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3096 ++InsIdx;
3097 }
3098 if (NumRegs > 0)
3099 --InsIdx;
3100 continue;
3101 }
3102 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3103 continue;
3104 }
3105
3106 // In the following cases, assign a node order of "i+1"
3107 // to newly created nodes. The SDNodes for params have to
3108 // appear in the same order as their order of appearance
3109 // in the original function. "i+1" holds that order.
3110 if (!PAL.hasParamAttr(i, Attribute::ByVal)) {
3111 bool aggregateIsPacked = false;
3112 if (StructType *STy = dyn_cast<StructType>(Ty))
3113 aggregateIsPacked = STy->isPacked();
3114
3117 ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
3118 if (VTs.empty())
3119 report_fatal_error("Empty parameter types are not supported");
3120
3123 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
3124
3125 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3126 int VecIdx = -1; // Index of the first element of the current vector.
3127 for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
3128 if (VectorInfo[parti] & PVF_FIRST) {
3129 assert(VecIdx == -1 && "Orphaned vector.");
3130 VecIdx = parti;
3131 }
3132
3133 // That's the last element of this store op.
3134 if (VectorInfo[parti] & PVF_LAST) {
3135 unsigned NumElts = parti - VecIdx + 1;
3136 EVT EltVT = VTs[parti];
3137 // i1 is loaded/stored as i8.
3138 EVT LoadVT = EltVT;
3139 if (EltVT == MVT::i1)
3140 LoadVT = MVT::i8;
3141 else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
3142 // getLoad needs a vector type, but it can't handle
3143 // vectors which contain v2f16 or v2bf16 elements. So we must load
3144 // using i32 here and then bitcast back.
3145 LoadVT = MVT::i32;
3146
3147 EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
3148 SDValue VecAddr =
3149 DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
3150 DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
3152 EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
3153
3154 const MaybeAlign PartAlign = [&]() -> MaybeAlign {
3155 if (aggregateIsPacked)
3156 return Align(1);
3157 if (NumElts != 1)
3158 return std::nullopt;
3159 Align PartAlign =
3160 DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
3161 return commonAlignment(PartAlign, Offsets[parti]);
3162 }();
3163 SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
3164 MachinePointerInfo(srcValue), PartAlign,
3167 if (P.getNode())
3168 P.getNode()->setIROrder(i + 1);
3169 for (unsigned j = 0; j < NumElts; ++j) {
3170 SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
3171 DAG.getIntPtrConstant(j, dl));
3172 // We've loaded i1 as an i8 and now must truncate it back to i1
3173 if (EltVT == MVT::i1)
3174 Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
3175 // v2f16 was loaded as an i32. Now we must bitcast it back.
3176 else if (EltVT != LoadVT)
3177 Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
3178
3179 // If a promoted integer type is used, truncate down to the original
3180 MVT PromotedVT;
3181 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
3182 Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
3183 }
3184
3185 // Extend the element if necessary (e.g. an i8 is loaded
3186 // into an i16 register)
3187 if (Ins[InsIdx].VT.isInteger() &&
3188 Ins[InsIdx].VT.getFixedSizeInBits() >
3189 LoadVT.getFixedSizeInBits()) {
3190 unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
3192 Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
3193 }
3194 InVals.push_back(Elt);
3195 }
3196
3197 // Reset vector tracking state.
3198 VecIdx = -1;
3199 }
3200 ++InsIdx;
3201 }
3202 if (VTs.size() > 0)
3203 --InsIdx;
3204 continue;
3205 }
3206
3207 // Param has ByVal attribute
3208 // Return MoveParam(param symbol).
3209 // Ideally, the param symbol can be returned directly,
3210 // but when SDNode builder decides to use it in a CopyToReg(),
3211 // machine instruction fails because TargetExternalSymbol
3212 // (not lowered) is target dependent, and CopyToReg assumes
3213 // the source is lowered.
3214 EVT ObjectVT = getValueType(DL, Ty);
3215 assert(ObjectVT == Ins[InsIdx].VT &&
3216 "Ins type did not match function type");
3217 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3218 SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3219 if (p.getNode())
3220 p.getNode()->setIROrder(i + 1);
3221 InVals.push_back(p);
3222 }
3223
3224 if (!OutChains.empty())
3225 DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
3226
3227 return Chain;
3228}
3229
3230// Use byte-stores when the param address of the return value is unaligned.
3231// This may happen when the return value is a field of a packed structure.
3232static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
3233 uint64_t Offset, EVT ElementType,
3234 SDValue RetVal, const SDLoc &dl) {
3235 // Bit logic only works on integer types
3236 if (adjustElementType(ElementType))
3237 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
3238
3239 // Store each byte
3240 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
3241 // Shift the byte to the last byte position
3242 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
3243 DAG.getConstant(i * 8, dl, MVT::i32));
3244 SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
3245 ShiftVal};
3246 // Trunc store only the last byte by using
3247 // st.param.b8
3248 // The register type can be larger than b8.
3250 DAG.getVTList(MVT::Other), StoreOperands,
3251 MVT::i8, MachinePointerInfo(), std::nullopt,
3253 }
3254 return Chain;
3255}
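// Editor's note: the sketch below is illustrative only and not part of the
// original file; the helper name is hypothetical and it assumes <cstdint> and
// <vector> are available, with SizeInBytes at most 8. It models the byte
// decomposition performed above: byte i of the integer value is obtained by
// shifting right by i * 8 and truncating to 8 bits, which is what each
// emitted st.param.b8 stores at Offset + i.
static inline std::vector<uint8_t> modelSplitToBytes(uint64_t Bits,
                                                     unsigned SizeInBytes) {
  std::vector<uint8_t> Bytes;
  for (unsigned I = 0; I < SizeInBytes; ++I)
    Bytes.push_back(static_cast<uint8_t>(Bits >> (I * 8)));
  return Bytes;
}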
3256
3257SDValue
3258NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
3259 bool isVarArg,
3260 const SmallVectorImpl<ISD::OutputArg> &Outs,
3261 const SmallVectorImpl<SDValue> &OutVals,
3262 const SDLoc &dl, SelectionDAG &DAG) const {
3263 const MachineFunction &MF = DAG.getMachineFunction();
3264 const Function &F = MF.getFunction();
3266
3267 bool isABI = (STI.getSmVersion() >= 20);
3268 assert(isABI && "Non-ABI compilation is not supported");
3269 if (!isABI)
3270 return Chain;
3271
3272 const DataLayout &DL = DAG.getDataLayout();
3273 SmallVector<SDValue, 16> PromotedOutVals;
3274 SmallVector<EVT, 16> VTs;
3275 SmallVector<uint64_t, 16> Offsets;
3276 ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
3277 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
3278
3279 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3280 SDValue PromotedOutVal = OutVals[i];
3281 MVT PromotedVT;
3282 if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) {
3283 VTs[i] = EVT(PromotedVT);
3284 }
3285 if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) {
3287 Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
3288 PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal);
3289 }
3290 PromotedOutVals.push_back(PromotedOutVal);
3291 }
3292
3293 auto VectorInfo = VectorizePTXValueVTs(
3294 VTs, Offsets,
3296 : Align(1));
3297
3298 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
3299 // 32-bits are sign extended or zero extended, depending on whether
3300 // they are signed or unsigned types.
3301 bool ExtendIntegerRetVal =
3302 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
3303
3304 SmallVector<SDValue, 6> StoreOperands;
3305 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3306 SDValue OutVal = OutVals[i];
3307 SDValue RetVal = PromotedOutVals[i];
3308
3309 if (ExtendIntegerRetVal) {
3310 RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
3312 dl, MVT::i32, RetVal);
3313 } else if (OutVal.getValueSizeInBits() < 16) {
3314 // Use 16-bit registers for small load-stores as it's the
3315 // smallest general purpose register size supported by NVPTX.
3316 RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
3317 }
3318
3319 // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
3320 // for a scalar store. In such cases, fall back to byte stores.
3321 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
3322 EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3323 Align ElementTypeAlign =
3324 DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
3325 Align ElementAlign =
3326 commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
3327 if (ElementAlign < ElementTypeAlign) {
3328 assert(StoreOperands.empty() && "Orphaned operand list.");
3329 Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
3330 RetVal, dl);
3331
3332 // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
3333 // into the graph, so just move on to the next element.
3334 continue;
3335 }
3336 }
3337
3338 // New load/store. Record chain and offset operands.
3339 if (VectorInfo[i] & PVF_FIRST) {
3340 assert(StoreOperands.empty() && "Orphaned operand list.");
3341 StoreOperands.push_back(Chain);
3342 StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
3343 }
3344
3345 // Record the value to return.
3346 StoreOperands.push_back(RetVal);
3347
3348 // That's the last element of this store op.
3349 if (VectorInfo[i] & PVF_LAST) {
3350 NVPTXISD::NodeType Op;
3351 unsigned NumElts = StoreOperands.size() - 2;
3352 switch (NumElts) {
3353 case 1:
3354 Op = NVPTXISD::StoreRetval;
3355 break;
3356 case 2:
3357 Op = NVPTXISD::StoreRetvalV2;
3358 break;
3359 case 4:
3360 Op = NVPTXISD::StoreRetvalV4;
3361 break;
3362 default:
3363 llvm_unreachable("Invalid vector info.");
3364 }
3365
3366 // Adjust type of load/store op if we've extended the scalar
3367 // return value.
3368 EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3369 Chain = DAG.getMemIntrinsicNode(
3370 Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
3372 // Cleanup vector state.
3373 StoreOperands.clear();
3374 }
3375 }
3376
3377 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
3378}
3379
3380void NVPTXTargetLowering::LowerAsmOperandForConstraint(
3381 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
3382 SelectionDAG &DAG) const {
3383 if (Constraint.size() > 1)
3384 return;
3385 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
3386}
3387
3388// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
3389// TgtMemIntrinsic because we need information that is only available in the
3390// "Value" type of the destination pointer, in particular its address space
3391// information.
3392
3393bool NVPTXTargetLowering::getTgtMemIntrinsic(
3394 IntrinsicInfo &Info, const CallInst &I,
3395 MachineFunction &MF, unsigned Intrinsic) const {
3396 switch (Intrinsic) {
3397 default:
3398 return false;
3399 case Intrinsic::nvvm_match_all_sync_i32p:
3400 case Intrinsic::nvvm_match_all_sync_i64p:
3402 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
3403 // in order to model data exchange with other threads, but perform no real
3404 // memory accesses.
3405 Info.memVT = MVT::i1;
3406
3407 // Our result depends on both our and other thread's arguments.
3409 return true;
3410 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
3411 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
3412 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
3413 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
3414 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
3415 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
3416 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
3417 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
3418 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
3419 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
3420 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
3421 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
3422 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
3423 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
3424 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
3425 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
3426 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
3427 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
3428 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
3429 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
3430 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
3431 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
3432 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
3433 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
3435 Info.memVT = MVT::v8f16;
3436 Info.ptrVal = I.getArgOperand(0);
3437 Info.offset = 0;
3439 Info.align = Align(16);
3440 return true;
3441 }
3442 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
3443 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
3444 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
3445 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
3446 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
3447 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
3448 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
3449 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
3450 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
3451 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
3452 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
3453 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
3454 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
3455 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
3456 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
3457 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
3458 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
3459 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
3460 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
3461 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
3462 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
3463 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
3464 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
3465 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
3467 Info.memVT = MVT::v2i32;
3468 Info.ptrVal = I.getArgOperand(0);
3469 Info.offset = 0;
3471 Info.align = Align(8);
3472 return true;
3473 }
3474
3475 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
3476 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
3477 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
3478 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
3479 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
3480 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
3481 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
3482 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
3483 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
3484 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
3485 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
3486 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
3487 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
3488 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
3489 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
3490 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
3491
3492 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
3493 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
3494 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
3495 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
3496 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
3497 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
3498 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
3499 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
3500 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
3501 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
3502 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
3503 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
3504 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
3505 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
3506 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
3507 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
3508 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
3509 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
3511 Info.memVT = MVT::v4i32;
3512 Info.ptrVal = I.getArgOperand(0);
3513 Info.offset = 0;
3515 Info.align = Align(16);
3516 return true;
3517 }
3518
3519 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
3520 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
3521 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
3522 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
3523 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
3524 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
3525 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
3526 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
3527
3528 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
3529 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
3530 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
3531 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
3532 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
3533 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
3534 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
3535 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
3536 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
3537 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
3538 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
3539 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
3540 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
3541 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
3542 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
3543 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
3544 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
3545 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
3546 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
3547 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
3548 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
3549 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
3551 Info.memVT = MVT::i32;
3552 Info.ptrVal = I.getArgOperand(0);
3553 Info.offset = 0;
3555 Info.align = Align(4);
3556 return true;
3557 }
3558
3559 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
3560 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
3561 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
3562 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
3563 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
3564 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
3565 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
3566 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
3567 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
3568 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
3569 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
3570 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
3572 Info.memVT = MVT::v4f16;
3573 Info.ptrVal = I.getArgOperand(0);
3574 Info.offset = 0;
3576 Info.align = Align(16);
3577 return true;
3578 }
3579
3580 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
3581 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
3582 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
3583 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
3584 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
3585 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
3586 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
3587 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
3588 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
3589 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
3590 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
3591 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
3592 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
3593 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
3594 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
3595 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
3597 Info.memVT = MVT::v8f32;
3598 Info.ptrVal = I.getArgOperand(0);
3599 Info.offset = 0;
3601 Info.align = Align(16);
3602 return true;
3603 }
3604
3605 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
3606 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
3607 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
3608 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
3609
3610 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
3611 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
3612 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
3613 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
3614
3615 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
3616 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
3617 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
3618 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
3619 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
3620 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
3621 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
3622 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
3623 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
3624 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
3625 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
3626 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
3628 Info.memVT = MVT::v8i32;
3629 Info.ptrVal = I.getArgOperand(0);
3630 Info.offset = 0;
3632 Info.align = Align(16);
3633 return true;
3634 }
3635
3636 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
3637 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
3638 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
3639 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
3640 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
3641 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
3642 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
3643 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
3644 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
3645 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
3647 Info.memVT = MVT::v2i32;
3648 Info.ptrVal = I.getArgOperand(0);
3649 Info.offset = 0;
3651 Info.align = Align(8);
3652 return true;
3653 }
3654
3655 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
3656 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
3657 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
3658 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
3659
3660 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
3661 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
3662 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
3663 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
3665 Info.memVT = MVT::f64;
3666 Info.ptrVal = I.getArgOperand(0);
3667 Info.offset = 0;
3669 Info.align = Align(8);
3670 return true;
3671 }
3672
3673 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
3674 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
3675 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
3676 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
3678 Info.memVT = MVT::v2f64;
3679 Info.ptrVal = I.getArgOperand(0);
3680 Info.offset = 0;
3682 Info.align = Align(16);
3683 return true;
3684 }
3685
3686 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
3687 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
3688 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
3689 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
3690 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
3691 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
3692 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
3693 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
3694 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
3695 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
3696 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
3697 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
3699 Info.memVT = MVT::v4f16;
3700 Info.ptrVal = I.getArgOperand(0);
3701 Info.offset = 0;
3703 Info.align = Align(16);
3704 return true;
3705 }
3706
3707 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
3708 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
3709 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
3710 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
3711 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
3712 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
3713 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
3714 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
3715 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
3716 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
3717 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
3718 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
3719 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
3720 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
3721 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
3722 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
3724 Info.memVT = MVT::v8f32;
3725 Info.ptrVal = I.getArgOperand(0);
3726 Info.offset = 0;
3728 Info.align = Align(16);
3729 return true;
3730 }
3731
3732 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
3733 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
3734 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
3735 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
3736 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
3737 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
3738 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
3739 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
3740 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
3741 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
3742 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
3743 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
3745 Info.memVT = MVT::v8i32;
3746 Info.ptrVal = I.getArgOperand(0);
3747 Info.offset = 0;
3749 Info.align = Align(16);
3750 return true;
3751 }
3752
3753 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
3754 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
3755 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
3756 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
3757 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
3758 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
3759 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
3760 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
3762 Info.memVT = MVT::v2i32;
3763 Info.ptrVal = I.getArgOperand(0);
3764 Info.offset = 0;
3766 Info.align = Align(8);
3767 return true;
3768 }
3769
3770 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
3771 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
3772 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
3773 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
3775 Info.memVT = MVT::v2f64;
3776 Info.ptrVal = I.getArgOperand(0);
3777 Info.offset = 0;
3779 Info.align = Align(16);
3780 return true;
3781 }
3782
3783 case Intrinsic::nvvm_atomic_load_inc_32:
3784 case Intrinsic::nvvm_atomic_load_dec_32:
3785
3786 case Intrinsic::nvvm_atomic_add_gen_f_cta:
3787 case Intrinsic::nvvm_atomic_add_gen_f_sys:
3788 case Intrinsic::nvvm_atomic_add_gen_i_cta:
3789 case Intrinsic::nvvm_atomic_add_gen_i_sys:
3790 case Intrinsic::nvvm_atomic_and_gen_i_cta:
3791 case Intrinsic::nvvm_atomic_and_gen_i_sys:
3792 case Intrinsic::nvvm_atomic_cas_gen_i_cta:
3793 case Intrinsic::nvvm_atomic_cas_gen_i_sys:
3794 case Intrinsic::nvvm_atomic_dec_gen_i_cta:
3795 case Intrinsic::nvvm_atomic_dec_gen_i_sys:
3796 case Intrinsic::nvvm_atomic_inc_gen_i_cta:
3797 case Intrinsic::nvvm_atomic_inc_gen_i_sys:
3798 case Intrinsic::nvvm_atomic_max_gen_i_cta:
3799 case Intrinsic::nvvm_atomic_max_gen_i_sys:
3800 case Intrinsic::nvvm_atomic_min_gen_i_cta:
3801 case Intrinsic::nvvm_atomic_min_gen_i_sys:
3802 case Intrinsic::nvvm_atomic_or_gen_i_cta:
3803 case Intrinsic::nvvm_atomic_or_gen_i_sys:
3804 case Intrinsic::nvvm_atomic_exch_gen_i_cta:
3805 case Intrinsic::nvvm_atomic_exch_gen_i_sys:
3806 case Intrinsic::nvvm_atomic_xor_gen_i_cta:
3807 case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
3808 auto &DL = I.getDataLayout();
3810 Info.memVT = getValueType(DL, I.getType());
3811 Info.ptrVal = I.getArgOperand(0);
3812 Info.offset = 0;
3814 Info.align.reset();
3815 return true;
3816 }
3817
3818 case Intrinsic::nvvm_ldu_global_i:
3819 case Intrinsic::nvvm_ldu_global_f:
3820 case Intrinsic::nvvm_ldu_global_p: {
3821 auto &DL = I.getDataLayout();
3823 if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
3824 Info.memVT = getValueType(DL, I.getType());
3825     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
3826 Info.memVT = getPointerTy(DL);
3827 else
3828 Info.memVT = getValueType(DL, I.getType());
3829 Info.ptrVal = I.getArgOperand(0);
3830 Info.offset = 0;
3832 Info.align = cast<ConstantInt>(I.getArgOperand(1))->getMaybeAlignValue();
3833
3834 return true;
3835 }
3836 case Intrinsic::nvvm_tex_1d_v4f32_s32:
3837 case Intrinsic::nvvm_tex_1d_v4f32_f32:
3838 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
3839 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
3840 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
3841 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
3842 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
3843 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
3844 case Intrinsic::nvvm_tex_2d_v4f32_s32:
3845 case Intrinsic::nvvm_tex_2d_v4f32_f32:
3846 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
3847 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
3848 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
3849 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
3850 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
3851 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
3852 case Intrinsic::nvvm_tex_3d_v4f32_s32:
3853 case Intrinsic::nvvm_tex_3d_v4f32_f32:
3854 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
3855 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
3856 case Intrinsic::nvvm_tex_cube_v4f32_f32:
3857 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
3858 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
3859 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
3860 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
3861 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
3862 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
3863 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
3864 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
3865 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
3866 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
3867 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
3868 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
3869 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
3870 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
3871 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
3872 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
3873 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
3874 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
3875 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
3876 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
3877 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
3878 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
3879 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
3880 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
3881 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
3882 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
3883 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
3884 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
3885 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
3886 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
3887 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
3888 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
3889 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
3890 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
3891 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
3892 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
3893 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
3895 Info.memVT = MVT::v4f32;
3896 Info.ptrVal = nullptr;
3897 Info.offset = 0;
3899 Info.align = Align(16);
3900 return true;
3901
3902 case Intrinsic::nvvm_tex_1d_v4s32_s32:
3903 case Intrinsic::nvvm_tex_1d_v4s32_f32:
3904 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
3905 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
3906 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
3907 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
3908 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
3909 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
3910 case Intrinsic::nvvm_tex_2d_v4s32_s32:
3911 case Intrinsic::nvvm_tex_2d_v4s32_f32:
3912 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
3913 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
3914 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
3915 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
3916 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
3917 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
3918 case Intrinsic::nvvm_tex_3d_v4s32_s32:
3919 case Intrinsic::nvvm_tex_3d_v4s32_f32:
3920 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
3921 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
3922 case Intrinsic::nvvm_tex_cube_v4s32_f32:
3923 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
3924 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
3925 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
3926 case Intrinsic::nvvm_tex_cube_v4u32_f32:
3927 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
3928 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
3929 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
3930 case Intrinsic::nvvm_tex_1d_v4u32_s32:
3931 case Intrinsic::nvvm_tex_1d_v4u32_f32:
3932 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
3933 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
3934 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
3935 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
3936 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
3937 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
3938 case Intrinsic::nvvm_tex_2d_v4u32_s32:
3939 case Intrinsic::nvvm_tex_2d_v4u32_f32:
3940 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
3941 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
3942 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
3943 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
3944 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
3945 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
3946 case Intrinsic::nvvm_tex_3d_v4u32_s32:
3947 case Intrinsic::nvvm_tex_3d_v4u32_f32:
3948 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
3949 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
3950 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
3951 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
3952 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
3953 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
3954 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
3955 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
3956 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
3957 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
3958 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
3959 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
3960 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
3961 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
3962 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
3963 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
3964 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
3965 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
3966 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
3967 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
3968 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
3969 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
3970 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
3971 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
3972 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
3973 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
3974 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
3975 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
3976 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
3977 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
3978 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
3979 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
3980 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
3981 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
3982 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
3983 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
3984 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
3985 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
3986 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
3987 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
3988 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
3989 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
3990 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
3991 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
3992 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
3993 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
3994 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
3995 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
3996 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
3997 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
3998 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
3999 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4000 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4001 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4002 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4003 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4004 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4005 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4006 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4007 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4008 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4009 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4010 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4011 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4012 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4013 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4014 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4015 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4016 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4017 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4019 Info.memVT = MVT::v4i32;
4020 Info.ptrVal = nullptr;
4021 Info.offset = 0;
4023 Info.align = Align(16);
4024 return true;
4025
4026 case Intrinsic::nvvm_suld_1d_i8_clamp:
4027 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4028 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4029 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4030 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4031 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4032 case Intrinsic::nvvm_suld_2d_i8_clamp:
4033 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4034 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4035 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4036 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4037 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4038 case Intrinsic::nvvm_suld_3d_i8_clamp:
4039 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4040 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4041 case Intrinsic::nvvm_suld_1d_i8_trap:
4042 case Intrinsic::nvvm_suld_1d_v2i8_trap:
4043 case Intrinsic::nvvm_suld_1d_v4i8_trap:
4044 case Intrinsic::nvvm_suld_1d_array_i8_trap:
4045 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
4046 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
4047 case Intrinsic::nvvm_suld_2d_i8_trap:
4048 case Intrinsic::nvvm_suld_2d_v2i8_trap:
4049 case Intrinsic::nvvm_suld_2d_v4i8_trap:
4050 case Intrinsic::nvvm_suld_2d_array_i8_trap:
4051 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
4052 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
4053 case Intrinsic::nvvm_suld_3d_i8_trap:
4054 case Intrinsic::nvvm_suld_3d_v2i8_trap:
4055 case Intrinsic::nvvm_suld_3d_v4i8_trap:
4056 case Intrinsic::nvvm_suld_1d_i8_zero:
4057 case Intrinsic::nvvm_suld_1d_v2i8_zero:
4058 case Intrinsic::nvvm_suld_1d_v4i8_zero:
4059 case Intrinsic::nvvm_suld_1d_array_i8_zero:
4060 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
4061 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
4062 case Intrinsic::nvvm_suld_2d_i8_zero:
4063 case Intrinsic::nvvm_suld_2d_v2i8_zero:
4064 case Intrinsic::nvvm_suld_2d_v4i8_zero:
4065 case Intrinsic::nvvm_suld_2d_array_i8_zero:
4066 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
4067 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
4068 case Intrinsic::nvvm_suld_3d_i8_zero:
4069 case Intrinsic::nvvm_suld_3d_v2i8_zero:
4070 case Intrinsic::nvvm_suld_3d_v4i8_zero:
4072 Info.memVT = MVT::i8;
4073 Info.ptrVal = nullptr;
4074 Info.offset = 0;
4076 Info.align = Align(16);
4077 return true;
4078
4079 case Intrinsic::nvvm_suld_1d_i16_clamp:
4080 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
4081 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
4082 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
4083 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
4084 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
4085 case Intrinsic::nvvm_suld_2d_i16_clamp:
4086 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
4087 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
4088 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
4089 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
4090 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
4091 case Intrinsic::nvvm_suld_3d_i16_clamp:
4092 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
4093 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
4094 case Intrinsic::nvvm_suld_1d_i16_trap:
4095 case Intrinsic::nvvm_suld_1d_v2i16_trap:
4096 case Intrinsic::nvvm_suld_1d_v4i16_trap:
4097 case Intrinsic::nvvm_suld_1d_array_i16_trap:
4098 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
4099 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
4100 case Intrinsic::nvvm_suld_2d_i16_trap:
4101 case Intrinsic::nvvm_suld_2d_v2i16_trap:
4102 case Intrinsic::nvvm_suld_2d_v4i16_trap:
4103 case Intrinsic::nvvm_suld_2d_array_i16_trap:
4104 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
4105 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
4106 case Intrinsic::nvvm_suld_3d_i16_trap:
4107 case Intrinsic::nvvm_suld_3d_v2i16_trap:
4108 case Intrinsic::nvvm_suld_3d_v4i16_trap:
4109 case Intrinsic::nvvm_suld_1d_i16_zero:
4110 case Intrinsic::nvvm_suld_1d_v2i16_zero:
4111 case Intrinsic::nvvm_suld_1d_v4i16_zero:
4112 case Intrinsic::nvvm_suld_1d_array_i16_zero:
4113 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
4114 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
4115 case Intrinsic::nvvm_suld_2d_i16_zero:
4116 case Intrinsic::nvvm_suld_2d_v2i16_zero:
4117 case Intrinsic::nvvm_suld_2d_v4i16_zero:
4118 case Intrinsic::nvvm_suld_2d_array_i16_zero:
4119 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
4120 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
4121 case Intrinsic::nvvm_suld_3d_i16_zero:
4122 case Intrinsic::nvvm_suld_3d_v2i16_zero:
4123 case Intrinsic::nvvm_suld_3d_v4i16_zero:
4125 Info.memVT = MVT::i16;
4126 Info.ptrVal = nullptr;
4127 Info.offset = 0;
4129 Info.align = Align(16);
4130 return true;
4131
4132 case Intrinsic::nvvm_suld_1d_i32_clamp:
4133 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
4134 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
4135 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
4136 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
4137 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
4138 case Intrinsic::nvvm_suld_2d_i32_clamp:
4139 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
4140 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
4141 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
4142 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
4143 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
4144 case Intrinsic::nvvm_suld_3d_i32_clamp:
4145 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
4146 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
4147 case Intrinsic::nvvm_suld_1d_i32_trap:
4148 case Intrinsic::nvvm_suld_1d_v2i32_trap:
4149 case Intrinsic::nvvm_suld_1d_v4i32_trap:
4150 case Intrinsic::nvvm_suld_1d_array_i32_trap:
4151 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
4152 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
4153 case Intrinsic::nvvm_suld_2d_i32_trap:
4154 case Intrinsic::nvvm_suld_2d_v2i32_trap:
4155 case Intrinsic::nvvm_suld_2d_v4i32_trap:
4156 case Intrinsic::nvvm_suld_2d_array_i32_trap:
4157 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
4158 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
4159 case Intrinsic::nvvm_suld_3d_i32_trap:
4160 case Intrinsic::nvvm_suld_3d_v2i32_trap:
4161 case Intrinsic::nvvm_suld_3d_v4i32_trap:
4162 case Intrinsic::nvvm_suld_1d_i32_zero:
4163 case Intrinsic::nvvm_suld_1d_v2i32_zero:
4164 case Intrinsic::nvvm_suld_1d_v4i32_zero:
4165 case Intrinsic::nvvm_suld_1d_array_i32_zero:
4166 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
4167 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
4168 case Intrinsic::nvvm_suld_2d_i32_zero:
4169 case Intrinsic::nvvm_suld_2d_v2i32_zero:
4170 case Intrinsic::nvvm_suld_2d_v4i32_zero:
4171 case Intrinsic::nvvm_suld_2d_array_i32_zero:
4172 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
4173 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
4174 case Intrinsic::nvvm_suld_3d_i32_zero:
4175 case Intrinsic::nvvm_suld_3d_v2i32_zero:
4176 case Intrinsic::nvvm_suld_3d_v4i32_zero:
4178 Info.memVT = MVT::i32;
4179 Info.ptrVal = nullptr;
4180 Info.offset = 0;
4182 Info.align = Align(16);
4183 return true;
4184
4185 case Intrinsic::nvvm_suld_1d_i64_clamp:
4186 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
4187 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
4188 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
4189 case Intrinsic::nvvm_suld_2d_i64_clamp:
4190 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
4191 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
4192 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
4193 case Intrinsic::nvvm_suld_3d_i64_clamp:
4194 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
4195 case Intrinsic::nvvm_suld_1d_i64_trap:
4196 case Intrinsic::nvvm_suld_1d_v2i64_trap:
4197 case Intrinsic::nvvm_suld_1d_array_i64_trap:
4198 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
4199 case Intrinsic::nvvm_suld_2d_i64_trap:
4200 case Intrinsic::nvvm_suld_2d_v2i64_trap:
4201 case Intrinsic::nvvm_suld_2d_array_i64_trap:
4202 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
4203 case Intrinsic::nvvm_suld_3d_i64_trap:
4204 case Intrinsic::nvvm_suld_3d_v2i64_trap:
4205 case Intrinsic::nvvm_suld_1d_i64_zero:
4206 case Intrinsic::nvvm_suld_1d_v2i64_zero:
4207 case Intrinsic::nvvm_suld_1d_array_i64_zero:
4208 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
4209 case Intrinsic::nvvm_suld_2d_i64_zero:
4210 case Intrinsic::nvvm_suld_2d_v2i64_zero:
4211 case Intrinsic::nvvm_suld_2d_array_i64_zero:
4212 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
4213 case Intrinsic::nvvm_suld_3d_i64_zero:
4214 case Intrinsic::nvvm_suld_3d_v2i64_zero:
4216 Info.memVT = MVT::i64;
4217 Info.ptrVal = nullptr;
4218 Info.offset = 0;
4220 Info.align = Align(16);
4221 return true;
4222 }
4223 return false;
4224}
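// As a worked example of the hook above (illustrative only; %p stands for a
// hypothetical pointer argument): for a call to
// Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col with pointer operand %p, the
// corresponding case reports
//
//   Info.memVT  = MVT::v8f16;   // one 16-byte fragment per thread
//   Info.ptrVal = %p;           // IR value that carries the address space
//   Info.offset = 0;
//   Info.align  = Align(16);
//
// which lets SelectionDAG attach an accurate MachineMemOperand to the
// resulting intrinsic node.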
4225
4226/// getFunctionParamOptimizedAlign - since function arguments are passed via
4227/// .param space, we may want to increase their alignment in a way that
4228/// ensures that we can effectively vectorize their loads & stores. We can
4229/// increase alignment only if the function has internal or private
4230/// linkage, as for other linkage types callers may already rely on the
4231/// default alignment. To allow 128-bit vectorized loads/stores, this
4232/// function ensures that alignment is 16 or greater.
4233Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
4234     const Function *F, Type *ArgTy, const DataLayout &DL) const {
4235 // Capping the alignment to 128 bytes as that is the maximum alignment
4236 // supported by PTX.
4237 const Align ABITypeAlign = std::min(Align(128), DL.getABITypeAlign(ArgTy));
4238
4239   // If a function has linkage other than internal or private, we must
4240   // use the default ABI alignment, as external users rely on it. The same
4241   // holds for a function that may be called through a function pointer.
4242 if (!F || !F->hasLocalLinkage() ||
4243 F->hasAddressTaken(/*Users=*/nullptr,
4244 /*IgnoreCallbackUses=*/false,
4245 /*IgnoreAssumeLikeCalls=*/true,
4246 /*IgnoreLLVMUsed=*/true))
4247 return ABITypeAlign;
4248
4249 assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
4250 return std::max(Align(16), ABITypeAlign);
4251}
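// A minimal, self-contained sketch of the same policy (hypothetical helper,
// not used by the backend; the field values correspond to the queries above):
//
//   #include <algorithm>
//   #include <cstdint>
//
//   struct ParamDesc {
//     uint64_t ABITypeAlign;  // DL.getABITypeAlign(ArgTy).value()
//     bool HasLocalLinkage;   // F->hasLocalLinkage()
//     bool AddressTaken;      // F->hasAddressTaken(...)
//   };
//
//   uint64_t chooseParamAlign(const ParamDesc &P) {
//     // PTX caps alignment at 128 bytes.
//     uint64_t A = std::min<uint64_t>(128, P.ABITypeAlign);
//     // External or address-taken functions must keep the ABI alignment.
//     if (!P.HasLocalLinkage || P.AddressTaken)
//       return A;
//     // Otherwise raise to 16 so .param accesses can use 128-bit vectors.
//     return std::max<uint64_t>(16, A);
//   }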
4252
4253/// Helper for computing alignment of a device function byval parameter.
4254Align NVPTXTargetLowering::getFunctionByValParamAlign(
4255     const Function *F, Type *ArgTy, Align InitialAlign,
4256 const DataLayout &DL) const {
4257 Align ArgAlign = InitialAlign;
4258 // Try to increase alignment to enhance vectorization options.
4259 if (F)
4260 ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
4261
4262   // Old ptx versions have a bug. When PTX code takes the address of a
4263   // byval parameter with alignment < 4, ptxas generates code to
4264   // spill the argument into memory. Alas, on sm_50+ ptxas generates
4265   // SASS code that fails with a misaligned access. To work around
4266   // the problem, make sure that we align byval parameters by at
4267   // least 4. This bug seems to be fixed at least starting from
4268   // ptxas > 9.0.
4269   // TODO: remove this after verifying the bug is not reproduced
4270   // on non-deprecated ptxas versions.
4271   if (ForceMinByValParamAlign)
4272     ArgAlign = std::max(ArgAlign, Align(4));
4273
4274 return ArgAlign;
4275}
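// A minimal usage sketch, assuming -nvptx-force-min-byval-param-align is set
// (TLI, F, ByValTy and DL are hypothetical variables of the obvious types):
//
//   Align A =
//       TLI.getFunctionByValParamAlign(F, ByValTy, /*InitialAlign=*/Align(1), DL);
//   // With the flag set, A is at least Align(4); it is at least Align(16)
//   // when F has local linkage and its address is not taken.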
4276
4277// Helper for getting a function parameter name. The name is composed of
4278// the parameter's index and the function name. A negative index denotes the
4279// special parameter (an unsized array) used for passing variable arguments.
4280std::string NVPTXTargetLowering::getParamName(const Function *F,
4281                                              int Idx) const {
4282 std::string ParamName;
4283 raw_string_ostream ParamStr(ParamName);
4284
4285 ParamStr << getTargetMachine().getSymbol(F)->getName();
4286 if (Idx < 0)
4287 ParamStr << "_vararg";
4288 else
4289 ParamStr << "_param_" << Idx;
4290
4291 return ParamName;
4292}
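// For example, for a function whose symbol is "foo" (a hypothetical name),
// the helper above produces:
//
//   getParamName(F, 0);   // "foo_param_0"
//   getParamName(F, 2);   // "foo_param_2"
//   getParamName(F, -1);  // "foo_vararg"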
4293
4294/// isLegalAddressingMode - Return true if the addressing mode represented
4295/// by AM is legal for this target, for a load/store of the specified type.
4296/// Used to guide target specific optimizations, like loop strength reduction
4297/// (LoopStrengthReduce.cpp) and memory optimization for address mode
4298/// (CodeGenPrepare.cpp)
4299bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
4300                                                const AddrMode &AM, Type *Ty,
4301 unsigned AS, Instruction *I) const {
4302 // AddrMode - This represents an addressing mode of:
4303 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
4304 //
4305 // The legal address modes are
4306 // - [avar]
4307 // - [areg]
4308 // - [areg+immoff]
4309 // - [immAddr]
4310
4311 // immoff must fit in a signed 32-bit int
4312 if (!APInt(64, AM.BaseOffs).isSignedIntN(32))
4313 return false;
4314
4315 if (AM.BaseGV)
4316 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
4317
4318 switch (AM.Scale) {
4319 case 0: // "r", "r+i" or "i" is allowed
4320 break;
4321 case 1:
4322 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
4323 return false;
4324 // Otherwise we have r+i.
4325 break;
4326 default:
4327 // No scale > 1 is allowed
4328 return false;
4329 }
4330 return true;
4331}
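// A few concrete AddrMode combinations checked against the rules above
// (sketch; AM is a TargetLowering::AddrMode, @g a hypothetical global):
//
//   {BaseGV=@g}                    -> legal    ([avar])
//   {HasBaseReg=1}                 -> legal    ([areg])
//   {HasBaseReg=1, BaseOffs=16}    -> legal    ([areg+immoff])
//   {BaseOffs=32}                  -> legal    ([immAddr])
//   {BaseGV=@g, BaseOffs=8}        -> illegal  (global with offset)
//   {HasBaseReg=1, Scale=1}        -> illegal  (reg + reg)
//   {Scale=2}                      -> illegal  (scale > 1)
//   {BaseOffs=1LL << 40}           -> illegal  (offset does not fit in i32)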
4332
4333//===----------------------------------------------------------------------===//
4334// NVPTX Inline Assembly Support
4335//===----------------------------------------------------------------------===//
4336
4337/// getConstraintType - Given a constraint letter, return the type of
4338/// constraint it is for this target.
4339NVPTXTargetLowering::ConstraintType
4340NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
4341   if (Constraint.size() == 1) {
4342 switch (Constraint[0]) {
4343 default:
4344 break;
4345 case 'b':
4346 case 'r':
4347 case 'h':
4348 case 'c':
4349 case 'l':
4350 case 'f':
4351 case 'd':
4352 case 'q':
4353 case '0':
4354 case 'N':
4355 return C_RegisterClass;
4356 }
4357 }
4358 return TargetLowering::getConstraintType(Constraint);
4359}
4360
4361std::pair<unsigned, const TargetRegisterClass *>
4362NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
4363                                                  StringRef Constraint,
4364 MVT VT) const {
4365 if (Constraint.size() == 1) {
4366 switch (Constraint[0]) {
4367 case 'b':
4368 return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
4369 case 'c':
4370 return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
4371 case 'h':
4372 return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
4373 case 'r':
4374 return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
4375 case 'l':
4376 case 'N':
4377 return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
4378 case 'q': {
4379 if (STI.getSmVersion() < 70)
4380 report_fatal_error("Inline asm with 128 bit operands is only "
4381 "supported for sm_70 and higher!");
4382 return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
4383 }
4384 case 'f':
4385 return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
4386 case 'd':
4387 return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
4388 }
4389 }
4390 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
4391}
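// For example, in CUDA device code these letters pick the PTX register class
// for inline-asm operands (a small illustrative snippet; the surrounding
// device function is assumed):
//
//   unsigned lane;
//   asm("mov.u32 %0, %%laneid;" : "=r"(lane));          // 'r' -> Int32Regs
//
//   float y, x = 2.0f;
//   asm("ex2.approx.f32 %0, %1;" : "=f"(y) : "f"(x));   // 'f' -> Float32Regs
//
//   unsigned long long clk;
//   asm volatile("mov.u64 %0, %%clock64;" : "=l"(clk)); // 'l' -> Int64Regs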
4392
4393//===----------------------------------------------------------------------===//
4394// NVPTX DAG Combining
4395//===----------------------------------------------------------------------===//
4396
4397bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
4398                                   CodeGenOptLevel OptLevel) const {
4399 // Always honor command-line argument
4400 if (FMAContractLevelOpt.getNumOccurrences() > 0)
4401 return FMAContractLevelOpt > 0;
4402
4403 // Do not contract if we're not optimizing the code.
4404 if (OptLevel == CodeGenOptLevel::None)
4405 return false;