NVPTXISelLowering.cpp
1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
16#include "NVPTX.h"
17#include "NVPTXISelDAGToDAG.h"
19#include "NVPTXSubtarget.h"
20#include "NVPTXTargetMachine.h"
22#include "NVPTXUtilities.h"
23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/APInt.h"
25#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/StringRef.h"
39#include "llvm/IR/Argument.h"
40#include "llvm/IR/Attributes.h"
41#include "llvm/IR/Constants.h"
42#include "llvm/IR/DataLayout.h"
45#include "llvm/IR/FPEnv.h"
46#include "llvm/IR/Function.h"
47#include "llvm/IR/GlobalValue.h"
48#include "llvm/IR/IRBuilder.h"
49#include "llvm/IR/Instruction.h"
51#include "llvm/IR/IntrinsicsNVPTX.h"
52#include "llvm/IR/Module.h"
53#include "llvm/IR/Type.h"
54#include "llvm/IR/Value.h"
66#include <algorithm>
67#include <cassert>
68#include <cmath>
69#include <cstdint>
70#include <iterator>
71#include <optional>
72#include <string>
73#include <tuple>
74#include <utility>
75#include <vector>
76
77#define DEBUG_TYPE "nvptx-lower"
78
79using namespace llvm;
80
82 "nvptx-sched4reg",
83 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
84
86 "nvptx-fma-level", cl::Hidden,
87 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
88 " 1: do it 2: do it aggressively"),
89 cl::init(2));
90
92 "nvptx-prec-divf32", cl::Hidden,
94 "NVPTX Specific: Override the precision of the lowering for f32 fdiv"),
96 clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"),
97 clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
99 "Use IEEE Compliant F32 div.rnd if available (default)"),
101 "Use IEEE Compliant F32 div.rnd if available, no FTZ")),
103
105 "nvptx-prec-sqrtf32", cl::Hidden,
106 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
107 cl::init(true));
108
109/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
110/// does NOT use lg2.approx for log2, so this is disabled by default.
112 "nvptx-approx-log2f32",
113 cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
114 cl::init(false));
115
117 "nvptx-force-min-byval-param-align", cl::Hidden,
118 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
119 " params of device functions."),
120 cl::init(false));
121
124 const SDNode &N) const {
125 // If nvptx-prec-divf32=N is used on the command-line, always honor it
126 if (UsePrecDivF32.getNumOccurrences() > 0)
127 return UsePrecDivF32;
128
129 const SDNodeFlags Flags = N.getFlags();
130 if (Flags.hasApproximateFuncs())
132
134}
135
137 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
138 if (UsePrecSqrtF32.getNumOccurrences() > 0)
139 return UsePrecSqrtF32;
140
141 if (N) {
142 const SDNodeFlags Flags = N->getFlags();
143 if (Flags.hasApproximateFuncs())
144 return false;
145 }
146
147 return true;
148}
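
// For illustration, assuming neither command-line flag above was given, the
// two precision queries resolve roughly as follows:
//   fsqrt with 'afn'     -> approximate sqrt is allowed (usePrecSqrtF32 -> false)
//   fsqrt without 'afn'  -> precise sqrt.rn is requested (usePrecSqrtF32 -> true)
//   fdiv with 'afn'      -> DivPrecisionLevel::Approx
// An explicit -nvptx-prec-sqrtf32 / -nvptx-prec-divf32 setting always takes
// precedence over the per-node fast-math flags.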
149
154
155static bool IsPTXVectorType(MVT VT) {
156 switch (VT.SimpleTy) {
157 default:
158 return false;
159 case MVT::v2i1:
160 case MVT::v4i1:
161 case MVT::v2i8:
162 case MVT::v4i8:
163 case MVT::v8i8: // <2 x i8x4>
164 case MVT::v16i8: // <4 x i8x4>
165 case MVT::v2i16:
166 case MVT::v4i16:
167 case MVT::v8i16: // <4 x i16x2>
168 case MVT::v2i32:
169 case MVT::v4i32:
170 case MVT::v2i64:
171 case MVT::v2f16:
172 case MVT::v4f16:
173 case MVT::v8f16: // <4 x f16x2>
174 case MVT::v2bf16:
175 case MVT::v4bf16:
176 case MVT::v8bf16: // <4 x bf16x2>
177 case MVT::v2f32:
178 case MVT::v4f32:
179 case MVT::v2f64:
180 case MVT::v4i64:
181 case MVT::v4f64:
182 case MVT::v8i32:
183 case MVT::v8f32:
184 case MVT::v16f16: // <8 x f16x2>
185 case MVT::v16bf16: // <8 x bf16x2>
186 case MVT::v16i16: // <8 x i16x2>
187 case MVT::v32i8: // <8 x i8x4>
188 return true;
189 }
190}
191
192// When legalizing vector loads/stores, this function is called, which does two
193// things:
194// 1. Determines whether the vector is something we want to custom lower;
195// std::nullopt is returned if we do not want to custom lower it.
196// 2. If we do want to handle it, returns two parameters:
197// - unsigned int NumElts - The number of elements in the final vector
198// - EVT EltVT - The type of the elements in the final vector
199static std::optional<std::pair<unsigned int, MVT>>
201 unsigned AddressSpace) {
202 const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
203
204 if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
205 VectorEVT.getSizeInBits() == 256)
206 return {{4, MVT::i64}};
207
208 if (!VectorEVT.isSimple())
209 return std::nullopt;
210 const MVT VectorVT = VectorEVT.getSimpleVT();
211
212 if (!VectorVT.isVector()) {
213 if (VectorVT == MVT::i128 || VectorVT == MVT::f128)
214 return {{2, MVT::i64}};
215 return std::nullopt;
216 }
217
218 const MVT EltVT = VectorVT.getVectorElementType();
219 const unsigned NumElts = VectorVT.getVectorNumElements();
220
221 // The size of the PTX virtual register that holds a packed type.
222 unsigned PackRegSize;
223
224 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
225 // legal. We can (and should) split that into 2 stores of <2 x double> here
226 // but I'm leaving that as a TODO for now.
227 switch (VectorVT.SimpleTy) {
228 default:
229 return std::nullopt;
230
231 case MVT::v4i64:
232 case MVT::v4f64:
233 // This is a "native" vector type iff the address space is global and the
234 // target supports 256-bit loads/stores
235 if (!CanLowerTo256Bit)
236 return std::nullopt;
237 [[fallthrough]];
238 case MVT::v2i8:
239 case MVT::v2i64:
240 case MVT::v2f64:
241 // This is a "native" vector type
242 return std::pair(NumElts, EltVT);
243
244 case MVT::v16f16: // <8 x f16x2>
245 case MVT::v16bf16: // <8 x bf16x2>
246 case MVT::v16i16: // <8 x i16x2>
247 case MVT::v32i8: // <8 x i8x4>
248 // This can be upsized into a "native" vector type iff the address space is
249 // global and the target supports 256-bit loads/stores.
250 if (!CanLowerTo256Bit)
251 return std::nullopt;
252 [[fallthrough]];
253 case MVT::v2i16: // <1 x i16x2>
254 case MVT::v2f16: // <1 x f16x2>
255 case MVT::v2bf16: // <1 x bf16x2>
256 case MVT::v4i8: // <1 x i8x4>
257 case MVT::v4i16: // <2 x i16x2>
258 case MVT::v4f16: // <2 x f16x2>
259 case MVT::v4bf16: // <2 x bf16x2>
260 case MVT::v8i8: // <2 x i8x4>
261 case MVT::v8f16: // <4 x f16x2>
262 case MVT::v8bf16: // <4 x bf16x2>
263 case MVT::v8i16: // <4 x i16x2>
264 case MVT::v16i8: // <4 x i8x4>
265 PackRegSize = 32;
266 break;
267
268 case MVT::v8f32: // <4 x f32x2>
269 case MVT::v8i32: // <4 x i32x2>
270 // This is a "native" vector type iff the address space is global and the
271 // target supports 256-bit loads/stores
272 if (!CanLowerTo256Bit)
273 return std::nullopt;
274 [[fallthrough]];
275 case MVT::v2f32: // <1 x f32x2>
276 case MVT::v4f32: // <2 x f32x2>
277 case MVT::v2i32: // <1 x i32x2>
278 case MVT::v4i32: // <2 x i32x2>
279 if (!STI.hasF32x2Instructions())
280 return std::pair(NumElts, EltVT);
281 PackRegSize = 64;
282 break;
283 }
284
285 // If we reach here, then we can pack 2 or more elements into a single 32-bit
286 // or 64-bit PTX register and treat the vector as a new vector containing
287 // packed elements.
288
289 // Number of elements to pack in one word.
290 const unsigned NPerReg = PackRegSize / EltVT.getSizeInBits();
291
292 return std::pair(NumElts / NPerReg, MVT::getVectorVT(EltVT, NPerReg));
293}
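
// A few concrete examples of the mapping above, assuming the address-space
// and subtarget checks pass:
//   v8f16           -> {4, v2f16}  (four 32-bit registers, each an f16x2)
//   v16i8           -> {4, v4i8}   (four 32-bit registers, each an i8x4)
//   v8f32 w/ f32x2  -> {4, v2f32}  (four 64-bit registers)
//   v8f32 w/o f32x2 -> {8, f32}    (falls back to scalar f32 elements)
//   i256 (scalar)   -> {4, i64}    (only with 256-bit load/store support)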
294
295/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
296/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
297/// the types as required by the calling convention (with special handling for
298/// i8s).
299/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
300/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
301/// LowerCall, and LowerReturn.
302static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
303 LLVMContext &Ctx, CallingConv::ID CallConv,
304 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
306 uint64_t StartingOffset = 0) {
307 SmallVector<EVT, 16> TempVTs;
308 SmallVector<uint64_t, 16> TempOffsets;
309 ComputeValueVTs(TLI, DL, Ty, TempVTs, /*MemVTs=*/nullptr, &TempOffsets,
310 StartingOffset);
311
312 for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
313 MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
314 unsigned NumRegs = TLI.getNumRegistersForCallingConv(Ctx, CallConv, VT);
315
316 // Since we actually can load/store b8, we need to ensure that we'll use
317 // the original sized type for any i8s or i8 vectors.
318 if (VT.getScalarType() == MVT::i8) {
319 if (RegisterVT == MVT::i16)
320 RegisterVT = MVT::i8;
321 else if (RegisterVT == MVT::v2i16)
322 RegisterVT = MVT::v2i8;
323 else
324 assert(RegisterVT == MVT::v4i8 &&
325 "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
326 }
327
328 // TODO: This is horribly incorrect for cases where the vector elements are
329 // not a multiple of bytes (e.g. i1) and legal or i8. However, this problem
330 // has existed for as long as NVPTX has and no one has complained, so we'll
331 // leave it for now.
332 for (unsigned I : seq(NumRegs)) {
333 ValueVTs.push_back(RegisterVT);
334 Offsets.push_back(Off + I * RegisterVT.getStoreSize());
335 }
336 }
337}
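
// A small illustration of the i8 special case above: for a struct such as
// { i32, i8 }, ComputeValueVTs reports { i32 @ 0, i8 @ 4 }. The generic
// calling-convention register type for i8 would be i16, but since PTX can
// load/store b8 directly we keep the original width, so ValueVTs ends up as
// { i32, i8 } with Offsets { 0, 4 } rather than { i32, i16 }.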
338
339// We return an EVT that can hold N VTs
340// If the VT is a vector, the resulting EVT is a flat vector with the same
341// element type as VT's element type.
342static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
343 if (N == 1)
344 return VT;
345
346 return VT.isVector() ? EVT::getVectorVT(C, VT.getScalarType(),
347 VT.getVectorNumElements() * N)
348 : EVT::getVectorVT(C, VT, N);
349}
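
// For reference, a few sample results of the helper above:
//   getVectorizedVT(f32, 2)   == v2f32
//   getVectorizedVT(v2f16, 4) == v8f16
//   getVectorizedVT(i64, 1)   == i64 (unchanged)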
350
352 const SDLoc &dl, SelectionDAG &DAG) {
353 if (V.getValueType() == VT) {
354 assert(I == 0 && "Index must be 0 for scalar value");
355 return V;
356 }
357
358 if (!VT.isVector())
359 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, V,
360 DAG.getVectorIdxConstant(I, dl));
361
362 return DAG.getNode(
363 ISD::EXTRACT_SUBVECTOR, dl, VT, V,
365}
366
367template <typename T>
368static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
369 SelectionDAG &DAG, T GetElement) {
370 if (N == 1)
371 return GetElement(0);
372
374 for (const unsigned I : llvm::seq(N)) {
375 SDValue Val = GetElement(I);
376 if (Val.getValueType().isVector())
377 DAG.ExtractVectorElements(Val, Values);
378 else
379 Values.push_back(Val);
380 }
381
382 EVT VT = EVT::getVectorVT(*DAG.getContext(), Values[0].getValueType(),
383 Values.size());
384 return DAG.getBuildVector(VT, dl, Values);
385}
386
387/// PromoteScalarIntegerPTX
388/// Used to make sure the arguments/returns are suitable for passing
389/// and promote them to a larger size if they're not.
390///
391/// Returns the promoted type, or \p VT unchanged if no promotion is needed.
393 if (VT.isScalarInteger()) {
394 switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
395 default:
397 "Promotion is not suitable for scalars of size larger than 64-bits");
398 case 1:
399 return MVT::i1;
400 case 2:
401 case 4:
402 case 8:
403 return MVT::i8;
404 case 16:
405 return MVT::i16;
406 case 32:
407 return MVT::i32;
408 case 64:
409 return MVT::i64;
410 }
411 }
412 return VT;
413}
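
// Sample promotions performed by the helper above:
//   i1 -> i1,  i3 -> i8,  i12 -> i16,  i24 -> i32,  i48 -> i64
// Non-integer (and vector) types are returned unchanged.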
414
415// Check whether we can merge loads/stores of some of the pieces of a
416// flattened function parameter or return value into a single vector
417// load/store.
418//
419// The flattened parameter is represented as a list of EVTs and
420// offsets, and the whole structure is aligned to ParamAlignment. This
421// function determines whether we can load/store pieces of the
422// parameter starting at index Idx using a single vectorized op of
423// size AccessSize. If so, it returns the number of param pieces
424// covered by the vector op. Otherwise, it returns 1.
425template <typename T>
427 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
428 const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
429
430 // Can't vectorize if param alignment is not sufficient.
431 if (ParamAlignment < AccessSize)
432 return 1;
433 // Can't vectorize if offset is not aligned.
434 if (Offsets[Idx] & (AccessSize - 1))
435 return 1;
436
437 EVT EltVT = ValueVTs[Idx];
438 unsigned EltSize = EltVT.getStoreSize();
439
440 // Element is too large to vectorize.
441 if (EltSize >= AccessSize)
442 return 1;
443
444 unsigned NumElts = AccessSize / EltSize;
445 // Can't vectorize if AccessSize is not a multiple of EltSize.
446 if (AccessSize != EltSize * NumElts)
447 return 1;
448
449 // We don't have enough elements to vectorize.
450 if (Idx + NumElts > ValueVTs.size())
451 return 1;
452
453 // PTX ISA can only deal with 2- and 4-element vector ops.
454 if (NumElts != 4 && NumElts != 2)
455 return 1;
456
457 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
458 // Types do not match.
459 if (ValueVTs[j] != EltVT)
460 return 1;
461
462 // Elements are not contiguous.
463 if (Offsets[j] - Offsets[j - 1] != EltSize)
464 return 1;
465 }
466 // OK. We can vectorize ValueVTs[Idx..Idx+NumElts)
467 return NumElts;
468}
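
// For example, with ValueVTs = {f32, f32, f32, f32}, Offsets = {0, 4, 8, 12}
// and ParamAlignment = 16, a query at Idx = 0 with AccessSize = 16 returns 4
// (one 128-bit access). If ParamAlignment were only 8, that same query would
// fail the alignment check and return 1, while an 8-byte query at Idx = 0
// would still return 2.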
469
470// Computes whether and how we can vectorize the loads/stores of a
471// flattened function parameter or return value.
472//
473// The flattened parameter is represented as the list of ValueVTs and
474// Offsets, and is aligned to ParamAlignment bytes. We return a vector
475// of the same size as ValueVTs indicating how each piece should be
476// loaded/stored (i.e. as a scalar, or as part of a vector
477// load/store).
478template <typename T>
481 const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
482 bool IsVAArg = false) {
483 // Set vector size to match ValueVTs and mark all elements as
484 // scalars by default.
485
486 if (IsVAArg)
487 return SmallVector<unsigned>(ValueVTs.size(), 1);
488
489 SmallVector<unsigned, 16> VectorInfo;
490
491 const auto GetNumElts = [&](unsigned I) -> unsigned {
492 for (const unsigned AccessSize : {16, 8, 4, 2}) {
493 const unsigned NumElts = canMergeParamLoadStoresStartingAt(
494 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
495 assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
496 "Unexpected vectorization size");
497 if (NumElts != 1)
498 return NumElts;
499 }
500 return 1;
501 };
502
503 // Check what we can vectorize using 128/64/32-bit accesses.
504 for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
505 const unsigned NumElts = GetNumElts(I);
506 VectorInfo.push_back(NumElts);
507 I += NumElts;
508 }
509 assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
510 ValueVTs.size());
511 return VectorInfo;
512}
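
// As a worked example, six contiguous f32 pieces at offsets 0..20 with a
// 16-byte parameter alignment yield VectorInfo = {4, 2}: one 128-bit access
// covering pieces [0..3] and one 64-bit access covering pieces [4..5], so the
// entries sum to ValueVTs.size() as asserted above.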
513
514// NVPTXTargetLowering Constructor.
516 const NVPTXSubtarget &STI)
517 : TargetLowering(TM, STI), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
518 // Always lower memset, memcpy, and memmove intrinsics to load/store
519 // instructions, rather than generating calls to memset, memcpy, or
520 // memmove.
524
527
528 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
529 // condition branches.
530 setJumpIsExpensive(true);
531
532 // Wide divides are _very_ slow. Try to reduce the width of the divide if
533 // possible.
534 addBypassSlowDiv(64, 32);
535
536 // By default, use the Source scheduling
537 if (sched4reg)
539 else
541
542 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
543 LegalizeAction NoF16Action) {
544 bool IsOpSupported = STI.allowFP16Math();
545 switch (Op) {
546 // Several FP16 instructions are available on sm_80 only.
547 case ISD::FMINNUM:
548 case ISD::FMAXNUM:
549 case ISD::FMAXNUM_IEEE:
550 case ISD::FMINNUM_IEEE:
551 case ISD::FMAXIMUM:
552 case ISD::FMINIMUM:
553 case ISD::FMAXIMUMNUM:
554 case ISD::FMINIMUMNUM:
555 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
556 break;
557 case ISD::FEXP2:
558 IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
559 break;
560 }
561 setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action);
562 };
563
564 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
565 LegalizeAction NoBF16Action) {
566 bool IsOpSupported = STI.hasNativeBF16Support(Op);
568 Op, VT, IsOpSupported ? Action : NoBF16Action);
569 };
570
571 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
572 LegalizeAction NoI16x2Action) {
573 bool IsOpSupported = false;
574 // These instructions are available on sm_90 only.
575 switch (Op) {
576 case ISD::ADD:
577 case ISD::SMAX:
578 case ISD::SMIN:
579 case ISD::UMIN:
580 case ISD::UMAX:
581 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
582 break;
583 }
584 setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
585 };
586
587 addRegisterClass(MVT::i1, &NVPTX::B1RegClass);
588 addRegisterClass(MVT::i16, &NVPTX::B16RegClass);
589 addRegisterClass(MVT::v2i16, &NVPTX::B32RegClass);
590 addRegisterClass(MVT::v4i8, &NVPTX::B32RegClass);
591 addRegisterClass(MVT::i32, &NVPTX::B32RegClass);
592 addRegisterClass(MVT::i64, &NVPTX::B64RegClass);
593 addRegisterClass(MVT::f32, &NVPTX::B32RegClass);
594 addRegisterClass(MVT::f64, &NVPTX::B64RegClass);
595 addRegisterClass(MVT::f16, &NVPTX::B16RegClass);
596 addRegisterClass(MVT::v2f16, &NVPTX::B32RegClass);
597 addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
598 addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
599
600 if (STI.hasF32x2Instructions()) {
601 addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
602 addRegisterClass(MVT::v2i32, &NVPTX::B64RegClass);
603 }
604
605 // Conversion to/from FP16/FP16x2 is always legal.
610
611 setOperationAction(ISD::READCYCLECOUNTER, MVT::i64, Legal);
612 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
613 setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64, Legal);
614
615 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
616 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
617
618 // Conversion to/from BF16/BF16x2 is always legal.
623
624 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
625 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
626 if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote)
627 AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32);
628
629 // Conversion to/from i16/i16x2 is always legal.
634
639
640 // No support for these operations with v2f32/v2i32
641 setOperationAction(ISD::INSERT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32}, Expand);
642 setOperationAction(ISD::VECTOR_SHUFFLE, {MVT::v2f32, MVT::v2i32}, Expand);
643
646 MVT::v2i32, Expand);
647
648 // Need custom lowering in case the index is dynamic.
649 if (STI.hasF32x2Instructions())
650 setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
651 Custom);
652
653 // Custom conversions to/from v2i8.
654 setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
655
656 // Only logical ops can be done on v4i8/v2i32 directly, others must be done
657 // elementwise.
674 {MVT::v4i8, MVT::v2i32}, Expand);
675
676 // Operations not directly supported by NVPTX.
677 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
678 MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
679 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
681 setOperationAction(ISD::BR_CC, VT, Expand);
682 }
683
684 // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
685 setOperationAction(ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);
686
687 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
688 // For others we will expand to a SHL/SRA pair.
694 setOperationAction(ISD::SIGN_EXTEND_INREG, {MVT::v2i16, MVT::v2i32}, Expand);
695
702
705
707 {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
708 Expand);
709
710 if (STI.hasHWROT32()) {
713 Custom);
714 }
715
716 setOperationAction(ISD::BR_JT, MVT::Other, Custom);
717 setOperationAction(ISD::BRIND, MVT::Other, Expand);
718
719 // We want to legalize constant-related memmove and memcpy
720 // intrinsics.
722
723 // FP extload/truncstore is not legal in PTX. We need to expand all these.
724 for (auto FloatVTs :
726 for (MVT ValVT : FloatVTs) {
727 for (MVT MemVT : FloatVTs) {
728 setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Expand);
729 setTruncStoreAction(ValVT, MemVT, Expand);
730 }
731 }
732 }
733
734 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
735 // how they'll be lowered in ISel anyway, and by doing this a little earlier
736 // we allow for more DAG combine opportunities.
737 for (auto IntVTs :
739 for (MVT ValVT : IntVTs)
740 for (MVT MemVT : IntVTs)
741 if (isTypeLegal(ValVT))
742 setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Custom);
743
744 // PTX does not support load / store predicate registers
745 setOperationAction({ISD::LOAD, ISD::STORE}, MVT::i1, Custom);
746 for (MVT VT : MVT::integer_valuetypes()) {
748 Promote);
749 setTruncStoreAction(VT, MVT::i1, Expand);
750 }
751
752 // Disable generation of extload/truncstore for v2i32/v2i16/v2i8. The generic
753 // expansion for these nodes when they are unaligned is incorrect if the
754 // type is a vector.
755 //
756 // TODO: Fix the generic expansion for these nodes found in
757 // TargetLowering::expandUnalignedLoad/Store.
759 MVT::v2i8, Expand);
761 {MVT::v2i8, MVT::v2i16}, Expand);
762 setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
763 setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
764 setTruncStoreAction(MVT::v2i32, MVT::v2i8, Expand);
765
766 // Register custom handling for illegal type loads/stores. We'll try to custom
767 // lower almost all illegal types and logic in the lowering will discard cases
768 // we can't handle.
769 setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::i256, MVT::f128},
770 Custom);
772 if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
773 setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
774 Custom);
775
776 // Custom legalization for LDU intrinsics.
777 // TODO: The logic to lower these is not very robust and we should rewrite it.
778 // Perhaps LDU should not be represented as an intrinsic at all.
781 if (IsPTXVectorType(VT))
783
787 MVT::i1, Expand);
788
789 // This is legal in NVPTX
794
795 setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom);
796 setOperationAction({ISD::STACKRESTORE, ISD::STACKSAVE}, MVT::Other, Custom);
797
798 // TRAP can be lowered to PTX trap
799 setOperationAction(ISD::TRAP, MVT::Other, Legal);
800 // DEBUGTRAP can be lowered to PTX brkpt
801 setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
802
803 // Support varargs.
804 setOperationAction(ISD::VASTART, MVT::Other, Custom);
805 setOperationAction(ISD::VAARG, MVT::Other, Custom);
806 setOperationAction(ISD::VACOPY, MVT::Other, Expand);
807 setOperationAction(ISD::VAEND, MVT::Other, Expand);
808
810 {MVT::i16, MVT::i32, MVT::i64}, Legal);
811
813 Promote);
816
817 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
818 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
819 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
820 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
821 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
822 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
823 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
824
825 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
826 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
827 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
828 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
829 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
830 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
831
832 // Other arithmetic and logic ops are unsupported.
836 {MVT::v2i16, MVT::v2i32}, Expand);
837
838 // v2i32 is not supported for any arithmetic operations
843 MVT::v2i32, Expand);
844
849 if (STI.getPTXVersion() >= 43) {
854 }
855
857 setOperationAction(ISD::CTTZ, {MVT::v2i16, MVT::v2i32}, Expand);
860
861 // PTX does not directly support SELP of i1, so promote to i32 first
863
864 // PTX cannot multiply two i64s in a single instruction.
867
868 // We have some custom DAG combine patterns for these nodes
871 ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
872 ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
873 ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
875 ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
876 ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
877
878 // setcc for f16x2 and bf16x2 needs special handling to prevent
879 // legalizer's attempt to scalarize it due to v2i1 not being legal.
880 if (STI.allowFP16Math() || STI.hasBF16Math())
882
883 // Vector reduction operations. These may be turned into shuffle or tree
884 // reductions depending on what instructions are available for each type.
886 MVT EltVT = VT.getVectorElementType();
887 if (EltVT == MVT::f32 || EltVT == MVT::f64) {
888 setOperationAction({ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
889 ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
890 VT, Custom);
891 }
892 }
893
894 // Promote fp16 arithmetic if fp16 hardware isn't available or the
895 // user passed --nvptx-no-fp16-math. The flag is useful because,
896 // although sm_53+ GPUs have some sort of FP16 support in
897 // hardware, only sm_53 and sm_60 have full implementation. Others
898 // only have token amount of hardware and are likely to run faster
899 // by using fp32 units instead.
900 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
901 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
902 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
903 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
904 // bf16 must be promoted to f32.
905 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
906 if (getOperationAction(Op, MVT::bf16) == Promote)
907 AddPromotedToType(Op, MVT::bf16, MVT::f32);
908 setOperationAction(Op, MVT::v2f32,
909 STI.hasF32x2Instructions() ? Legal : Expand);
910 }
911
912 // On SM80, we select add/mul/sub as fma to avoid promotion to float
913 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
914 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
915 if (!STI.hasNativeBF16Support(Op) && STI.hasNativeBF16Support(ISD::FMA)) {
917 }
918 }
919 }
920
921 // f16/f16x2 neg was introduced in PTX 60, SM_53.
922 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
923 STI.getPTXVersion() >= 60 &&
924 STI.allowFP16Math();
925 for (const auto &VT : {MVT::f16, MVT::v2f16})
926 setOperationAction(ISD::FNEG, VT,
927 IsFP16FP16x2NegAvailable ? Legal : Expand);
928
929 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
930 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
931 setOperationAction(ISD::FNEG, MVT::v2f32, Expand);
932 // (would be) Library functions.
933
934 // These map to conversion instructions for scalar FP types.
935 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
936 ISD::FROUNDEVEN, ISD::FTRUNC}) {
937 setOperationAction(Op, MVT::f16, Legal);
938 setOperationAction(Op, MVT::f32, Legal);
939 setOperationAction(Op, MVT::f64, Legal);
940 setOperationAction(Op, MVT::v2f16, Expand);
941 setOperationAction(Op, MVT::v2bf16, Expand);
942 setOperationAction(Op, MVT::v2f32, Expand);
943 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
944 if (getOperationAction(Op, MVT::bf16) == Promote)
945 AddPromotedToType(Op, MVT::bf16, MVT::f32);
946 }
947
948 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
949 setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
950 }
951 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
952 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
953 setOperationAction(ISD::FP_EXTEND, VT, Custom);
955 }
956 }
957
958 // Expand v2f32 = fp_extend
959 setOperationAction(ISD::FP_EXTEND, MVT::v2f32, Expand);
960 // Expand v2[b]f16 = fp_round v2f32
961 setOperationAction(ISD::FP_ROUND, {MVT::v2bf16, MVT::v2f16}, Expand);
962
963 // sm_80 only has conversions between f32 and bf16. Custom lower all other
964 // bf16 conversions.
965 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
966 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
969 VT, Custom);
970 }
973 MVT::bf16, Custom);
974 }
975
976 setOperationAction(ISD::FROUND, MVT::f16, Promote);
977 setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
978 setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
979 setOperationAction(ISD::FROUND, MVT::f32, Custom);
980 setOperationAction(ISD::FROUND, MVT::f64, Custom);
981 setOperationAction(ISD::FROUND, MVT::bf16, Promote);
982 AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32);
983
984 // 'Expand' implements FCOPYSIGN without calling an external library.
991
992 // These map to corresponding instructions for f32/f64. f16 must be
993 // promoted to f32. v2f16 is expanded to f16, which is then promoted
994 // to f32.
995 for (const auto &Op :
996 {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
997 setOperationAction(Op, MVT::f16, Promote);
998 setOperationAction(Op, MVT::f32, Legal);
999 // only div/rem/sqrt are legal for f64
1000 if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
1001 setOperationAction(Op, MVT::f64, Legal);
1002 }
1003 setOperationAction(Op, {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Expand);
1004 setOperationAction(Op, MVT::bf16, Promote);
1005 AddPromotedToType(Op, MVT::bf16, MVT::f32);
1006 }
1007 setOperationAction(ISD::FREM, {MVT::f32, MVT::f64}, Custom);
1008
1009 setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
1010 setOperationAction(ISD::FABS, MVT::v2f32, Expand);
1011 if (STI.getPTXVersion() >= 65) {
1012 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
1013 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
1014 } else {
1015 setOperationAction(ISD::FABS, MVT::f16, Promote);
1016 setOperationAction(ISD::FABS, MVT::v2f16, Expand);
1017 }
1018 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
1019 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
1020 if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
1021 AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);
1022
1023 for (const auto &Op :
1024 {ISD::FMINNUM, ISD::FMAXNUM, ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM}) {
1025 setOperationAction(Op, MVT::f32, Legal);
1026 setOperationAction(Op, MVT::f64, Legal);
1027 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
1028 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1029 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1030 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
1031 if (getOperationAction(Op, MVT::bf16) == Promote)
1032 AddPromotedToType(Op, MVT::bf16, MVT::f32);
1033 setOperationAction(Op, MVT::v2f32, Expand);
1034 }
1035 bool SupportsF32MinMaxNaN =
1036 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
1037 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
1038 setOperationAction(Op, MVT::f32, SupportsF32MinMaxNaN ? Legal : Expand);
1039 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
1040 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1041 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
1042 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1043 setOperationAction(Op, MVT::v2f32, Expand);
1044 }
1045
1046 // Custom lowering for inline asm with 128-bit operands
1049
1050 // FEXP2 support:
1051 // - f32
1052 // - f16/f16x2 (sm_70+, PTX 7.0+)
1053 // - bf16/bf16x2 (sm_90+, PTX 7.8+)
1054 // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
1055 setOperationAction(ISD::FEXP2, MVT::f32, Legal);
1056 setOperationAction(ISD::FEXP2, MVT::v2f32, Expand);
1057 setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
1058 setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
1059 setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
1060 setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
1061
1062 // FLOG2 supports f32 only
1063 // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
1064 if (UseApproxLog2F32) {
1065 setOperationAction(ISD::FLOG2, MVT::f32, Legal);
1066 setOperationPromotedToType(ISD::FLOG2, MVT::f16, MVT::f32);
1067 setOperationPromotedToType(ISD::FLOG2, MVT::bf16, MVT::f32);
1068 setOperationAction(ISD::FLOG2, {MVT::v2f16, MVT::v2bf16, MVT::v2f32},
1069 Expand);
1070 }
1071
1072 setOperationAction(ISD::ADDRSPACECAST, {MVT::i32, MVT::i64}, Custom);
1073
1074 setOperationAction(ISD::ATOMIC_LOAD_SUB, {MVT::i32, MVT::i64}, Expand);
1075
1076 // atom.b128 is legal in PTX but since we don't represent i128 as a legal
1077 // type, we need to custom lower it.
1078 setOperationAction({ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, MVT::i128,
1079 Custom);
1080
1081 // Now deduce the information based on the above mentioned
1082 // actions
1083 computeRegisterProperties(STI.getRegisterInfo());
1084
1085 // PTX support for 16-bit CAS is emulated. Only use 32+
1086 setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
1087 setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
1089
1090 // Custom lowering for tcgen05.ld vector operands
1092 {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1093 MVT::v32i32, MVT::v64i32, MVT::v128i32},
1094 Custom);
1095
1096 // Custom lowering for tcgen05.st vector operands
1098 {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1099 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
1100 Custom);
1101
1102 // Enable custom lowering for the following:
1103 // * MVT::i128 - clusterlaunchcontrol
1104 // * MVT::i32 - prmt
1105 // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
1106 // * MVT::Other - internal.addrspace.wrap
1108 {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
1109
1110 // Custom lowering for bswap
1111 setOperationAction(ISD::BSWAP, {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
1112 Custom);
1113}
1114
1117 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1118 VT.getScalarType() == MVT::i1)
1119 return TypeSplitVector;
1121}
1122
1124 int Enabled, int &ExtraSteps,
1125 bool &UseOneConst,
1126 bool Reciprocal) const {
1129 return SDValue();
1130
1131 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1132 ExtraSteps = 0;
1133
1134 SDLoc DL(Operand);
1135 EVT VT = Operand.getValueType();
1136 bool Ftz = useF32FTZ(DAG.getMachineFunction());
1137
1138 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1139 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
1140 DAG.getConstant(IID, DL, MVT::i32), Operand);
1141 };
1142
1143 // The sqrt and rsqrt refinement processes assume we always start out with an
1144 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1145 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1146 // any refinement, we must return a regular sqrt.
1147 if (Reciprocal || ExtraSteps > 0) {
1148 if (VT == MVT::f32)
1149 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1150 : Intrinsic::nvvm_rsqrt_approx_f);
1151 else if (VT == MVT::f64)
1152 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1153 else
1154 return SDValue();
1155 } else {
1156 if (VT == MVT::f32)
1157 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1158 : Intrinsic::nvvm_sqrt_approx_f);
1159 else {
1160 // There's no sqrt.approx.f64 instruction, so we emit
1161 // reciprocal(rsqrt(x)). This is faster than
1162 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1163 // x * rsqrt(x).)
1164 return DAG.getNode(
1166 DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32),
1167 MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1168 }
1169 }
1170}
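
// In rough terms: a reciprocal-sqrt request (or any request that will be
// refined) returns rsqrt.approx.f32 / rsqrt.approx.f64, a plain f32 sqrt
// returns sqrt.approx.f32 (FTZ variant when flush-to-zero is in effect), and
// a plain f64 sqrt is emitted as rcp.approx.ftz.f64(rsqrt.approx.f64(x))
// since there is no sqrt.approx.f64 instruction.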
1171
1173 const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
1175 std::optional<unsigned> FirstVAArg, const CallBase &CB,
1176 unsigned UniqueCallSite) const {
1177 auto PtrVT = getPointerTy(DL);
1178
1179 std::string Prototype;
1180 raw_string_ostream O(Prototype);
1181 O << "prototype_" << UniqueCallSite << " : .callprototype ";
1182
1183 if (RetTy->isVoidTy()) {
1184 O << "()";
1185 } else {
1186 O << "(";
1187 if (shouldPassAsArray(RetTy)) {
1188 const Align RetAlign = getArgumentAlignment(&CB, RetTy, 0, DL);
1189 O << ".param .align " << RetAlign.value() << " .b8 _["
1190 << DL.getTypeAllocSize(RetTy) << "]";
1191 } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
1192 unsigned size = 0;
1193 if (auto *ITy = dyn_cast<IntegerType>(RetTy)) {
1194 size = ITy->getBitWidth();
1195 } else {
1196 assert(RetTy->isFloatingPointTy() &&
1197 "Floating point type expected here");
1198 size = RetTy->getPrimitiveSizeInBits();
1199 }
1200 // PTX ABI requires all scalar return values to be at least 32
1201 // bits in size. fp16 normally uses .b16 as its storage type in
1202 // PTX, so its size must be adjusted here, too.
1204
1205 O << ".param .b" << size << " _";
1206 } else if (isa<PointerType>(RetTy)) {
1207 O << ".param .b" << PtrVT.getSizeInBits() << " _";
1208 } else {
1209 llvm_unreachable("Unknown return type");
1210 }
1211 O << ") ";
1212 }
1213 O << "_ (";
1214
1215 bool first = true;
1216
1217 const unsigned NumArgs = FirstVAArg.value_or(Args.size());
1218 auto AllOuts = ArrayRef(Outs);
1219 for (const unsigned I : llvm::seq(NumArgs)) {
1220 const auto ArgOuts =
1221 AllOuts.take_while([I](auto O) { return O.OrigArgIndex == I; });
1222 AllOuts = AllOuts.drop_front(ArgOuts.size());
1223
1224 Type *Ty = Args[I].Ty;
1225 if (!first) {
1226 O << ", ";
1227 }
1228 first = false;
1229
1230 if (ArgOuts[0].Flags.isByVal()) {
1231 // Indirect calls need strict ABI alignment so we disable optimizations by
1232 // not providing a function to optimize.
1233 Type *ETy = Args[I].IndirectType;
1234 Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1235 Align ParamByValAlign =
1236 getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
1237
1238 O << ".param .align " << ParamByValAlign.value() << " .b8 _["
1239 << ArgOuts[0].Flags.getByValSize() << "]";
1240 } else {
1241 if (shouldPassAsArray(Ty)) {
1242 Align ParamAlign =
1243 getArgumentAlignment(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
1244 O << ".param .align " << ParamAlign.value() << " .b8 _["
1245 << DL.getTypeAllocSize(Ty) << "]";
1246 continue;
1247 }
1248 // i8 types in IR will be i16 types in SDAG
1249 assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
1250 (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
1251 "type mismatch between callee prototype and arguments");
1252 // scalar type
1253 unsigned sz = 0;
1254 if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
1255 sz = promoteScalarArgumentSize(ITy->getBitWidth());
1256 } else if (isa<PointerType>(Ty)) {
1257 sz = PtrVT.getSizeInBits();
1258 } else {
1259 sz = Ty->getPrimitiveSizeInBits();
1260 }
1261 O << ".param .b" << sz << " _";
1262 }
1263 }
1264
1265 if (FirstVAArg)
1266 O << (first ? "" : ",") << " .param .align "
1267 << STI.getMaxRequiredAlignment() << " .b8 _[]";
1268 O << ")";
1269 if (shouldEmitPTXNoReturn(&CB, *nvTM))
1270 O << " .noreturn";
1271 O << ";";
1272
1273 return Prototype;
1274}
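
// For illustration, a callee of IR type 'i32 (i32, float)' would get a
// prototype string along the lines of:
//   prototype_1 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .b32 _);
// whereas aggregate returns, byval arguments, and aggregate arguments use the
// ".param .align <N> .b8 _[<size>]" array form emitted above.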
1275
1277 const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
1278 return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
1279}
1280
1281Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
1282 unsigned Idx,
1283 const DataLayout &DL) const {
1284 if (!CB) {
1285 // No call site available; fall back to the ABI type alignment.
1286 return DL.getABITypeAlign(Ty);
1287 }
1288
1289 const Function *DirectCallee = CB->getCalledFunction();
1290
1291 if (!DirectCallee) {
1292 // We don't have a direct function symbol, but that may be because of
1293 // constant cast instructions in the call.
1294
1295 // With bitcast'd call targets, the instruction will be the call
1296 if (const auto *CI = dyn_cast<CallInst>(CB)) {
1297 // Check if we have call alignment metadata
1298 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1299 return StackAlign.value();
1300 }
1301 DirectCallee = getMaybeBitcastedCallee(CB);
1302 }
1303
1304 // Check for function alignment information if we found that the
1305 // ultimate target is a Function
1306 if (DirectCallee)
1307 return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
1308
1309 // Call is indirect, fall back to the ABI type alignment
1310 return DL.getABITypeAlign(Ty);
1311}
1312
1314 const GlobalAddressSDNode *Func) {
1315 if (!Func)
1316 return false;
1317 if (auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal()))
1318 return CB->getFunctionType() != CalleeFunc->getFunctionType();
1319 return false;
1320}
1321
1323 const DataLayout &DL,
1324 const TargetLowering &TL) {
1325 if (Ptr->getOpcode() == ISD::FrameIndex) {
1326 auto Ty = TL.getPointerTy(DL, ADDRESS_SPACE_LOCAL);
1327 Ptr = DAG.getAddrSpaceCast(SDLoc(), Ty, Ptr, ADDRESS_SPACE_GENERIC,
1329
1331 }
1332
1333 // Peel off an addrspacecast to generic and load directly from the specific
1334 // address space.
1335 if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1336 const auto *ASC = cast<AddrSpaceCastSDNode>(Ptr);
1337 if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1338 Ptr = ASC->getOperand(0);
1339 return MachinePointerInfo(ASC->getSrcAddressSpace());
1340 }
1341 }
1342
1343 return MachinePointerInfo();
1344}
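
// In effect: a frame-index pointer is rewritten to a local-space address and
// reported as local memory, a generic pointer produced by an addrspacecast
// from a specific space is peeled back to that space, and anything else falls
// back to a plain MachinePointerInfo().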
1345
1347 if (Flags.isSExt())
1348 return ISD::SIGN_EXTEND;
1349 if (Flags.isZExt())
1350 return ISD::ZERO_EXTEND;
1351 return ISD::ANY_EXTEND;
1352}
1353
1355 ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
1356 SDLoc dl) {
1357 const EVT ActualVT = V.getValueType();
1358 assert((ActualVT == ExpectedVT ||
1359 (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
1360 "Non-integer argument type size mismatch");
1361 if (ExpectedVT.bitsGT(ActualVT))
1362 return DAG.getNode(getExtOpcode(Flags), dl, ExpectedVT, V);
1363 if (ExpectedVT.bitsLT(ActualVT))
1364 return DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, V);
1365
1366 return V;
1367}
1368
1370 SmallVectorImpl<SDValue> &InVals) const {
1371
1372 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1374 "Support for variadic functions (unsized array parameter) introduced "
1375 "in PTX ISA version 6.0 and requires target sm_30.");
1376
1377 SelectionDAG &DAG = CLI.DAG;
1378 SDLoc dl = CLI.DL;
1379 const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1380 SDValue Callee = CLI.Callee;
1381 ArgListTy &Args = CLI.getArgs();
1382 Type *RetTy = CLI.RetTy;
1383 const CallBase *CB = CLI.CB;
1384 const DataLayout &DL = DAG.getDataLayout();
1385 LLVMContext &Ctx = *DAG.getContext();
1386
1387 const auto GetI32 = [&](const unsigned I) {
1388 return DAG.getConstant(I, dl, MVT::i32);
1389 };
1390
1391 const unsigned UniqueCallSite = GlobalUniqueCallSite++;
1392 const SDValue CallChain = CLI.Chain;
1393 const SDValue StartChain =
1394 DAG.getCALLSEQ_START(CallChain, UniqueCallSite, 0, dl);
1395 SDValue DeclareGlue = StartChain.getValue(1);
1396
1397 SmallVector<SDValue, 16> CallPrereqs{StartChain};
1398
1399 const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1400 // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
1401 // loaded/stored using i16, so it's handled here as well.
1402 const unsigned SizeBits = promoteScalarArgumentSize(Size * 8);
1403 SDValue Declare =
1404 DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
1405 {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
1406 CallPrereqs.push_back(Declare);
1407 DeclareGlue = Declare.getValue(1);
1408 return Declare;
1409 };
1410
1411 const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1412 unsigned Size) {
1413 SDValue Declare = DAG.getNode(
1414 NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
1415 {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
1416 CallPrereqs.push_back(Declare);
1417 DeclareGlue = Declare.getValue(1);
1418 return Declare;
1419 };
1420
1421 // Variadic arguments.
1422 //
1423 // Normally, for each argument, we declare a param scalar or a param
1424 // byte array in the .param space, and store the argument value to that
1425 // param scalar or array starting at offset 0.
1426 //
1427 // In the case of the first variadic argument, we declare a vararg byte array
1428 // with size 0. The exact size of this array isn't known at this point, so
1429 // it'll be patched later. All the variadic arguments will be stored to this
1430 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1431 // initially set to 0, so it can be used for non-variadic arguments (which use
1432 // 0 offset) to simplify the code.
1433 //
1434 // After all varargs are processed, 'VAOffset' holds the size of the
1435 // vararg byte array.
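
// As a rough illustration: for a variadic callee '(ptr fmt, ...)' called with
// (fmt, i32 7, double 1.0), 'fmt' gets its own .param declaration, the vararg
// byte array is declared once with size 0, the i32 is stored at VAOffset 0
// (aligned to 4), the double at VAOffset 8 (aligned to 8), and the declaration
// is later patched to the final size of 16.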
1436 assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
1437 "Non-VarArg function with extra arguments");
1438
1439 const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1440 unsigned VAOffset = 0; // current offset in the param array
1441
1442 const SDValue VADeclareParam =
1443 CLI.Args.size() > FirstVAArg
1444 ? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32),
1445 Align(STI.getMaxRequiredAlignment()), 0)
1446 : SDValue();
1447
1448 // Args.size() and Outs.size() need not match.
1449 // Outs.size() will be larger
1450 // * if there is an aggregate argument with multiple fields (each field
1451 // showing up separately in Outs)
1452 // * if there is a vector argument with more than typical vector-length
1453 // elements (generally if more than 4) where each vector element is
1454 // individually present in Outs.
1455 // So a different index should be used for indexing into Outs/OutVals.
1456 // See similar issue in LowerFormalArguments.
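
// For example, a single aggregate argument of type { i32, i32, i32 } appears
// once in Args but contributes three entries to Outs/OutVals, all sharing the
// same OrigArgIndex; the take_while on OrigArgIndex below regroups them per
// IR-level argument.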
1457 auto AllOuts = ArrayRef(CLI.Outs);
1458 auto AllOutVals = ArrayRef(CLI.OutVals);
1459 assert(AllOuts.size() == AllOutVals.size() &&
1460 "Outs and OutVals must be the same size");
1461 // Declare the .params or .reg needed to pass values
1462 // to the function.
1463 for (const auto E : llvm::enumerate(Args)) {
1464 const auto ArgI = E.index();
1465 const auto Arg = E.value();
1466 const auto ArgOuts =
1467 AllOuts.take_while([&](auto O) { return O.OrigArgIndex == ArgI; });
1468 const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
1469 AllOuts = AllOuts.drop_front(ArgOuts.size());
1470 AllOutVals = AllOutVals.drop_front(ArgOuts.size());
1471
1472 const bool IsVAArg = (ArgI >= FirstVAArg);
1473 const bool IsByVal = Arg.IsByVal;
1474
1475 const SDValue ParamSymbol =
1476 getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
1477
1478 assert((!IsByVal || Arg.IndirectType) &&
1479 "byval arg must have indirect type");
1480 Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
1481
1482 const Align ArgAlign = [&]() {
1483 if (IsByVal) {
1484 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1485 // so we don't need to worry whether it's naturally aligned or not.
1486 // See TargetLowering::LowerCallTo().
1487 const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1489 InitialAlign, DL);
1490 }
1491 return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
1492 }();
1493
1494 const unsigned TySize = DL.getTypeAllocSize(ETy);
1495 assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
1496 "type size mismatch");
1497
1498 const SDValue ArgDeclare = [&]() {
1499 if (IsVAArg)
1500 return VADeclareParam;
1501
1502 if (IsByVal || shouldPassAsArray(Arg.Ty))
1503 return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
1504
1505 assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
1506 assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
1507 "Only int and float types are supported as non-array arguments");
1508
1509 return MakeDeclareScalarParam(ParamSymbol, TySize);
1510 }();
1511
1512 if (IsByVal) {
1513 assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
1514 SDValue SrcPtr = ArgOutVals[0];
1515 const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
1516 const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1517
1518 if (IsVAArg)
1519 VAOffset = alignTo(VAOffset, ArgAlign);
1520
1521 SmallVector<EVT, 4> ValueVTs, MemVTs;
1523 ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
1524
1525 unsigned J = 0;
1526 const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
1527 for (const unsigned NumElts : VI) {
1528 EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
1529 Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
1530 SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]);
1531 SDValue SrcLoad =
1532 DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
1533
1534 TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
1535 Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
1536 SDValue ParamAddr =
1537 DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
1538 SDValue StoreParam =
1539 DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr,
1541 CallPrereqs.push_back(StoreParam);
1542
1543 J += NumElts;
1544 }
1545 if (IsVAArg)
1546 VAOffset += TySize;
1547 } else {
1550 ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets,
1551 VAOffset);
1552 assert(VTs.size() == Offsets.size() && "Size mismatch");
1553 assert(VTs.size() == ArgOuts.size() && "Size mismatch");
1554
1555 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1556 // than 32-bits are sign extended or zero extended, depending on
1557 // whether they are signed or unsigned types. This case applies
1558 // only to scalar parameters and not to aggregate values.
1559 const bool ExtendIntegerParam =
1560 Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
1561
1562 const auto GetStoredValue = [&](const unsigned I) {
1563 SDValue StVal = ArgOutVals[I];
1565 StVal.getValueType() &&
1566 "OutVal type should always be legal");
1567
1568 const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
1569 const EVT StoreVT =
1570 ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1571
1572 return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
1573 };
1574
1575 unsigned J = 0;
1576 const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
1577 for (const unsigned NumElts : VI) {
1578 const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
1579
1580 unsigned Offset;
1581 if (IsVAArg) {
1582 // TODO: We may need to support vector types that can be passed
1583 // as scalars in variadic arguments.
1584 assert(NumElts == 1 &&
1585 "Vectorization should be disabled for vaargs.");
1586
1587 // Align each part of the variadic argument to their type.
1588 VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
1589 Offset = VAOffset;
1590
1591 const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1592 VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
1593 } else {
1594 assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
1595 Offset = Offsets[J];
1596 }
1597
1598 SDValue Ptr =
1599 DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
1600
1601 const MaybeAlign CurrentAlign = ExtendIntegerParam
1602 ? MaybeAlign(std::nullopt)
1603 : commonAlignment(ArgAlign, Offset);
1604
1605 SDValue Val =
1606 getBuildVectorizedValue(NumElts, dl, DAG, [&](unsigned K) {
1607 return GetStoredValue(J + K);
1608 });
1609
1610 SDValue StoreParam =
1611 DAG.getStore(ArgDeclare, dl, Val, Ptr,
1613 CallPrereqs.push_back(StoreParam);
1614
1615 J += NumElts;
1616 }
1617 }
1618 }
1619
1620 // Handle Result
1621 if (!Ins.empty()) {
1622 const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
1623 const unsigned ResultSize = DL.getTypeAllocSize(RetTy);
1624 if (shouldPassAsArray(RetTy)) {
1625 const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1626 MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
1627 } else {
1628 MakeDeclareScalarParam(RetSymbol, ResultSize);
1629 }
1630 }
1631
1632 // Set the size of the vararg param byte array if the callee is a variadic
1633 // function and the variadic part is not empty.
1634 if (VADeclareParam) {
1635 SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0),
1636 VADeclareParam.getOperand(1),
1637 VADeclareParam.getOperand(2), GetI32(VAOffset),
1638 VADeclareParam.getOperand(4)};
1639 DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
1640 VADeclareParam->getVTList(), DeclareParamOps);
1641 }
1642
1643 const auto *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
1644 // If the type of the callsite does not match that of the function, convert
1645 // the callsite to an indirect call.
1646 const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1647
1648 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1649 // between them we must rely on the call site value which is valid for
1650 // indirect calls but is always null for libcalls.
1651 const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1652
1653 if (isa<ExternalSymbolSDNode>(Callee)) {
1654 Function* CalleeFunc = nullptr;
1655
1656 // Try to find the callee in the current module.
1657 Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc);
1658 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1659
1660 // Set the "libcall callee" attribute to indicate that the function
1661 // must always have a declaration.
1662 CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
1663 }
1664
1665 if (IsIndirectCall) {
1666 // This is the indirect function call case: PTX requires a prototype of the
1667 // form
1668 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1669 // to be emitted, and the label has to be used as the last arg of the call
1670 // instruction.
1671 // The prototype is embedded in a string and put as the operand for a
1672 // CallPrototype SDNode which will print out to the value of the string.
1673 const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1674 std::string Proto =
1675 getPrototype(DL, RetTy, Args, CLI.Outs,
1676 HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
1677 UniqueCallSite);
1678 const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
1679 const SDValue PrototypeDeclare = DAG.getNode(
1680 NVPTXISD::CallPrototype, dl, MVT::Other,
1681 {StartChain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32)});
1682 CallPrereqs.push_back(PrototypeDeclare);
1683 }
1684
1685 const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
1686 const unsigned NumArgs =
1687 std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
1688 /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
1689 /// NumParams, Callee, Proto)
1690 const SDValue CallToken = DAG.getTokenFactor(dl, CallPrereqs);
1691 const SDValue Call = DAG.getNode(
1692 NVPTXISD::CALL, dl, MVT::Other,
1693 {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
1694 GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)});
1695
1696 SmallVector<SDValue, 16> LoadChains{Call};
1697 SmallVector<SDValue, 16> ProxyRegOps;
1698 if (!Ins.empty()) {
1701 ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
1702 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1703
1704 const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1705 const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
1706
1707 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1708 // 32-bits are sign extended or zero extended, depending on whether
1709 // they are signed or unsigned types.
1710 const bool ExtendIntegerRetVal =
1711 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
1712
1713 unsigned I = 0;
1714 const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
1715 for (const unsigned NumElts : VI) {
1716 const MaybeAlign CurrentAlign =
1717 ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
1718 : commonAlignment(RetAlign, Offsets[I]);
1719
1720 const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
1721 const EVT LoadVT =
1722 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1723 const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
1724 SDValue Ptr =
1725 DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
1726
1727 SDValue R =
1728 DAG.getLoad(VecVT, dl, Call, Ptr,
1730
1731 LoadChains.push_back(R.getValue(1));
1732 for (const unsigned J : llvm::seq(NumElts))
1733 ProxyRegOps.push_back(getExtractVectorizedValue(R, J, LoadVT, dl, DAG));
1734 I += NumElts;
1735 }
1736 }
1737
1738 const SDValue EndToken = DAG.getTokenFactor(dl, LoadChains);
1739 const SDValue CallEnd = DAG.getCALLSEQ_END(EndToken, UniqueCallSite,
1740 UniqueCallSite + 1, SDValue(), dl);
1741
1742 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1743 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1744 // dangling.
1745 for (const auto [I, Reg] : llvm::enumerate(ProxyRegOps)) {
1746 SDValue Proxy =
1747 DAG.getNode(NVPTXISD::ProxyReg, dl, Reg.getValueType(), {CallEnd, Reg});
1748 SDValue Ret = correctParamType(Proxy, Ins[I].VT, Ins[I].Flags, DAG, dl);
1749 InVals.push_back(Ret);
1750 }
1751
1752 // set IsTailCall to false for now, until we figure out how to express
1753 // tail call optimization in PTX
1754 CLI.IsTailCall = false;
1755 return CallEnd;
1756}
1757
1759 SelectionDAG &DAG) const {
1760
1761 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1762 const Function &Fn = DAG.getMachineFunction().getFunction();
1763
1765 Fn,
1766 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1767 "requires target sm_52.",
1768 SDLoc(Op).getDebugLoc()));
1769 auto Ops = {DAG.getConstant(0, SDLoc(), Op.getValueType()),
1770 Op.getOperand(0)};
1771 return DAG.getMergeValues(Ops, SDLoc());
1772 }
1773
1774 SDLoc DL(Op.getNode());
1775 SDValue Chain = Op.getOperand(0);
1776 SDValue Size = Op.getOperand(1);
1777 uint64_t Align = Op.getConstantOperandVal(2);
1778
1779 // The alignment on a ISD::DYNAMIC_STACKALLOC node may be 0 to indicate that
1780 // the default stack alignment should be used.
1781 if (Align == 0)
1783
1784 // The size operand of the PTX alloca instruction is 64-bit for m64 and 32-bit for m32.
1785 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
1786
1787 SDValue Alloc =
1788 DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, {LocalVT, MVT::Other},
1789 {Chain, DAG.getZExtOrTrunc(Size, DL, LocalVT),
1790 DAG.getTargetConstant(Align, DL, MVT::i32)});
1791
1792 SDValue ASC = DAG.getAddrSpaceCast(
1794
1795 return DAG.getMergeValues({ASC, SDValue(Alloc.getNode(), 1)}, DL);
1796}
1797
1799 SelectionDAG &DAG) const {
1800 SDLoc DL(Op.getNode());
1801 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1802 const Function &Fn = DAG.getMachineFunction().getFunction();
1803
1805 Fn,
1806 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
1807 ">= sm_52.",
1808 DL.getDebugLoc()));
1809 return Op.getOperand(0);
1810 }
1811
1812 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
1813 SDValue Chain = Op.getOperand(0);
1814 SDValue Ptr = Op.getOperand(1);
1815 SDValue ASC = DAG.getAddrSpaceCast(DL, LocalVT, Ptr, ADDRESS_SPACE_GENERIC,
1817 return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC});
1818}
1819
1821 SelectionDAG &DAG) const {
1822 SDLoc DL(Op.getNode());
1823 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1824 const Function &Fn = DAG.getMachineFunction().getFunction();
1825
1827 Fn,
1828 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
1829 "sm_52.",
1830 DL.getDebugLoc()));
1831 auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)};
1832 return DAG.getMergeValues(Ops, DL);
1833 }
1834
1835 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
1836 SDValue Chain = Op.getOperand(0);
1837 SDValue SS =
1838 DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain);
1839 SDValue ASC = DAG.getAddrSpaceCast(
1840 DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC);
1841 return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL);
1842}
1843
1844// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1845// (see LegalizeDAG.cpp). This is slow and uses local memory.
1846// We instead use extract/build vector, just as LegalizeOp() did in llvm 2.5.
1847SDValue
1848NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1849 SDNode *Node = Op.getNode();
1850 SDLoc dl(Node);
1852 unsigned NumOperands = Node->getNumOperands();
1853 for (unsigned i = 0; i < NumOperands; ++i) {
1854 SDValue SubOp = Node->getOperand(i);
1855 EVT VVT = SubOp.getNode()->getValueType(0);
1856 EVT EltVT = VVT.getVectorElementType();
1857 unsigned NumSubElem = VVT.getVectorNumElements();
1858 for (unsigned j = 0; j < NumSubElem; ++j) {
1859 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
1860 DAG.getIntPtrConstant(j, dl)));
1861 }
1862 }
1863 return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
1864}
1865
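// Helpers for building NVPTXISD::PRMT (the PTX prmt.b32 byte permute). The two
// source operands form an 8-byte pool, with A providing bytes 0-3 and B bytes
// 4-7; each nibble of the selector names the pool byte for the corresponding
// result byte, and in the default mode a nibble with bit 3 set replicates the
// selected byte's sign bit instead (see the PTX ISA description of prmt).
// For example, selector 0x0123 with B = 0 reverses the four bytes of A.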
1867 SelectionDAG &DAG,
1868 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1869 assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
1870 Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands");
1871 return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,
1872 {A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
1873}
1874
1876 SelectionDAG &DAG,
1877 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1878 return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
1879}
1880
1881/// Reduces the elements using the scalar operations provided. The operations
1882/// are sorted in descending order of the number of inputs they take. The flags on the
1883/// original reduction operation will be propagated to each scalar operation.
1884/// Nearby elements are grouped in tree reduction, unlike the shuffle reduction
1885/// used in ExpandReductions and SelectionDAG.
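/// For example, with Ops = {(FMAXNUM3, 3), (FMAXNUM, 2)} and four elements,
/// the first level produces fmax3(e0, e1, e2) with e3 left over; the 3-input
/// operator cannot reduce the remaining pair, so the final result is
/// fmax(fmax3(e0, e1, e2), e3).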
1887 const SmallVector<SDValue> &Elements, EVT EltTy,
1888 ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
1889 const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
1890 // Build the reduction tree at each level, starting with all the elements.
1891 SmallVector<SDValue> Level = Elements;
1892
1893 unsigned OpIdx = 0;
1894 while (Level.size() > 1) {
1895 // Try to reduce this level using the current operator.
1896 const auto [Op, NumInputs] = Ops[OpIdx];
1897
1898 // Build the next level by partially reducing all elements.
1899 SmallVector<SDValue> ReducedLevel;
1900 unsigned I = 0, E = Level.size();
1901 for (; I + NumInputs <= E; I += NumInputs) {
1902 // Reduce elements in groups of [NumInputs], as much as possible.
1903 ReducedLevel.push_back(DAG.getNode(
1904 Op, DL, EltTy, ArrayRef<SDValue>(Level).slice(I, NumInputs), Flags));
1905 }
1906
1907 if (I < E) {
1908 // Handle leftover elements.
1909
1910 if (ReducedLevel.empty()) {
1911 // We didn't reduce anything at this level. We need to pick a smaller
1912 // operator.
1913 ++OpIdx;
1914 assert(OpIdx < Ops.size() && "no smaller operators for reduction");
1915 continue;
1916 }
1917
1918 // We reduced some things but there's still more left, meaning the
1919 // operator's number of inputs doesn't evenly divide this level size. Move
1920 // these elements to the next level.
1921 for (; I < E; ++I)
1922 ReducedLevel.push_back(Level[I]);
1923 }
1924
1925 // Process the next level.
1926 Level = ReducedLevel;
1927 }
1928
1929 return *Level.begin();
1930}
1931
1932// Get scalar reduction opcode
1933static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
1934 switch (ReductionOpcode) {
1935 case ISD::VECREDUCE_FMAX:
1936 return ISD::FMAXNUM;
1937 case ISD::VECREDUCE_FMIN:
1938 return ISD::FMINNUM;
1939 case ISD::VECREDUCE_FMAXIMUM:
1940 return ISD::FMAXIMUM;
1941 case ISD::VECREDUCE_FMINIMUM:
1942 return ISD::FMINIMUM;
1943 default:
1944 llvm_unreachable("unhandled reduction opcode");
1945 }
1946}
1947
1948/// Get 3-input scalar reduction opcode
1949static std::optional<unsigned>
1950getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
1951 switch (ReductionOpcode) {
1952 case ISD::VECREDUCE_FMAX:
1953 return NVPTXISD::FMAXNUM3;
1954 case ISD::VECREDUCE_FMIN:
1955 return NVPTXISD::FMINNUM3;
1956 case ISD::VECREDUCE_FMAXIMUM:
1957 return NVPTXISD::FMAXIMUM3;
1958 case ISD::VECREDUCE_FMINIMUM:
1959 return NVPTXISD::FMINIMUM3;
1960 default:
1961 return std::nullopt;
1962 }
1963}
1964
1965/// Lower reductions to either a sequence of operations or a tree if
1966/// reassociations are allowed. This method will use larger operations like
1967/// max3/min3 when the target supports them.
1968SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
1969 SelectionDAG &DAG) const {
1970 SDLoc DL(Op);
1971 const SDNodeFlags Flags = Op->getFlags();
1972 SDValue Vector = Op.getOperand(0);
1973
1974 const unsigned Opcode = Op->getOpcode();
1975 const EVT EltTy = Vector.getValueType().getVectorElementType();
1976
1977 // Whether we can use 3-input min/max when expanding the reduction.
1978 const bool CanUseMinMax3 =
1979 EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
1980 STI.getPTXVersion() >= 88 &&
1981 (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
1982 Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
1983
1984 // A list of SDNode opcodes with equivalent semantics, sorted descending by
1985 // number of inputs they take.
1986 SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
1987
1988 if (auto Opcode3Elem = getScalar3OpcodeForReduction(Opcode);
1989 CanUseMinMax3 && Opcode3Elem)
1990 ScalarOps.push_back({*Opcode3Elem, 3});
1991 ScalarOps.push_back({getScalarOpcodeForReduction(Opcode), 2});
1992
1994 DAG.ExtractVectorElements(Vector, Elements);
1995
1996 return buildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
1997}
1998
1999SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2000 // Handle bitcasting from v2i8 without hitting the default promotion
2001 // strategy which goes through stack memory.
2002 EVT FromVT = Op->getOperand(0)->getValueType(0);
2003 if (FromVT != MVT::v2i8) {
2004 return Op;
2005 }
2006
2007 // Pack vector elements into i16 and bitcast to final type
2008 SDLoc DL(Op);
2009 SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2010 Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
2011 SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2012 Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
2013 SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
2014 SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
2015 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
2016 SDValue AsInt = DAG.getNode(
2017 ISD::OR, DL, MVT::i16,
2018 {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
2019 EVT ToVT = Op->getValueType(0);
2020 return DAG.getBitcast(ToVT, AsInt);
2021}
2022
2023// We can initialize a constant f16x2/v2i16/v4i8 with a single .b32 move. Normally
2024// it would get lowered as two constant loads and a vector-packing move.
2025// Instead we want just a constant move:
2026// mov.b32 %r2, 0x40003C00
2027SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2028 SelectionDAG &DAG) const {
2029 EVT VT = Op->getValueType(0);
2030 if (!(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector()))
2031 return Op;
2032 SDLoc DL(Op);
2033
2034 if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
2035 return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
2036 isa<ConstantFPSDNode>(Operand);
2037 })) {
2038 if (VT != MVT::v4i8)
2039 return Op;
2040 // Lower a non-const v4i8 vector as a byte-wise constructed i32, which allows
2041 // us to optimize the calculation of the constant parts.
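 // Roughly, by PRMT semantics: each selector-0x3340 PRMT packs the low bytes
 // of its two operands into result bytes 0 and 1 (bytes 2-3 are don't-care),
 // and the final selector-0x5410 PRMT merges the two byte pairs so that
 // result bytes 0..3 are {op0, op1, op2, op3}.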
2042 auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2043 uint64_t SelectionValue) -> SDValue {
2044 SDValue L = Left;
2045 SDValue R = Right;
2046 if (Cast) {
2047 L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
2048 R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
2049 }
2050 return getPRMT(L, R, SelectionValue, DL, DAG);
2051 };
2052 auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
2053 auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
2054 auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2055 return DAG.getBitcast(VT, PRMT3210);
2056 }
2057
2058 // Get the value of the Nth operand as an APInt(32). Undef values are treated as 0.
2059 auto GetOperand = [](SDValue Op, int N) -> APInt {
2060 const SDValue &Operand = Op->getOperand(N);
2061 EVT VT = Op->getValueType(0);
2062 if (Operand->isUndef())
2063 return APInt(32, 0);
2064 APInt Value;
2065 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2066 Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
2067 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2068 Value = Operand->getAsAPIntVal();
2069 else
2070 llvm_unreachable("Unsupported type");
2071 // i8 values are carried around as i16, so we need to zero out the upper bits
2072 // so they do not get in the way of combining the individual byte values.
2073 if (VT == MVT::v4i8)
2074 Value = Value.trunc(8);
2075 return Value.zext(32);
2076 };
2077
2078 // Construct a 32-bit constant by shifting into place smaller values
2079 // (elements of the vector type VT).
2080 // For example, if VT has 2 elements, then N == 2:
2081 // ShiftAmount = 32 / N = 16
2082 // Value |= Op0 (b16) << 0
2083 // Value |= Op1 (b16) << 16
2084 // If N == 4:
2085 // ShiftAmount = 32 / N = 8
2086 // Value |= Op0 (b8) << 0
2087 // Value |= Op1 (b8) << 8
2088 // Value |= Op2 (b8) << 16
2089 // Value |= Op3 (b8) << 24
2090 // ...etc
2091 APInt Value(32, 0);
2092 const unsigned NumElements = VT.getVectorNumElements();
2093 assert(32 % NumElements == 0 && "must evenly divide bit length");
2094 const unsigned ShiftAmount = 32 / NumElements;
2095 for (unsigned ElementNo : seq(NumElements))
2096 Value |= GetOperand(Op, ElementNo).shl(ElementNo * ShiftAmount);
2097 SDValue Const = DAG.getConstant(Value, DL, MVT::i32);
2098 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Const);
2099}
2100
2101SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2102 SelectionDAG &DAG) const {
2103 SDValue Index = Op->getOperand(1);
2104 SDValue Vector = Op->getOperand(0);
2105 SDLoc DL(Op);
2106 EVT VectorVT = Vector.getValueType();
2107
2108 if (VectorVT == MVT::v4i8) {
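 // The PRMT below moves byte `Index` of the vector into result byte 0 and
 // fills the upper result bytes from byte 7 of the pool, which comes from the
 // constant-zero second operand; e.g. Index = 2 gives selector 0x7772, a
 // zero-extending extract of byte 2.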
2109 SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32,
2110 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2111 DAG.getConstant(0x7770, DL, MVT::i32));
2112 SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, Vector),
2113 DAG.getConstant(0, DL, MVT::i32), Selector, DL, DAG);
2114 SDValue Ext = DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
2115 SDNodeFlags Flags;
2116 Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
2117 Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
2118 Ext->setFlags(Flags);
2119 return Ext;
2120 }
2121
2122 // Constant index will be matched by tablegen.
2123 if (isa<ConstantSDNode>(Index.getNode()))
2124 return Op;
2125
2126 // Extract individual elements and select one of them.
2127 assert(NVPTX::isPackedVectorTy(VectorVT) &&
2128 VectorVT.getVectorNumElements() == 2 && "Unexpected vector type.");
2129 EVT EltVT = VectorVT.getVectorElementType();
2130
2131 SDLoc dl(Op.getNode());
2132 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2133 DAG.getIntPtrConstant(0, dl));
2134 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2135 DAG.getIntPtrConstant(1, dl));
2136 return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
2138}
2139
2140SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2141 SelectionDAG &DAG) const {
2142 SDValue Vector = Op->getOperand(0);
2143 EVT VectorVT = Vector.getValueType();
2144
2145 if (VectorVT != MVT::v4i8)
2146 return Op;
2147 SDLoc DL(Op);
2148 SDValue Value = Op->getOperand(1);
2149 if (Value->isUndef())
2150 return Vector;
2151
2152 SDValue Index = Op->getOperand(2);
2153
2154 SDValue BFI =
2155 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2156 {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
2157 DAG.getNode(ISD::MUL, DL, MVT::i32,
2158 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2159 DAG.getConstant(8, DL, MVT::i32)),
2160 DAG.getConstant(8, DL, MVT::i32)});
2161 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
2162}
2163
2164SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2165 SelectionDAG &DAG) const {
2166 SDValue V1 = Op.getOperand(0);
2167 EVT VectorVT = V1.getValueType();
2168 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2169 return Op;
2170
2171 // Lower shuffle to PRMT instruction.
2172 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
2173 SDValue V2 = Op.getOperand(1);
2174 uint32_t Selector = 0;
2175 for (auto I : llvm::enumerate(SVN->getMask())) {
2176 if (I.value() != -1) // -1 is a placeholder for undef.
2177 Selector |= (I.value() << (I.index() * 4));
2178 }
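 // Each mask element becomes one selector nibble; mask indices 0-3 select
 // bytes of V1 and 4-7 select bytes of V2. E.g. the mask <3, 2, 1, 0> yields
 // selector 0x0123, which reverses the four bytes of V1.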
2179
2180 SDLoc DL(Op);
2181 SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, V1),
2182 DAG.getBitcast(MVT::i32, V2), Selector, DL, DAG);
2183 return DAG.getBitcast(Op.getValueType(), PRMT);
2184}
2185/// LowerShiftRightParts - Lower SRL_PARTS and SRA_PARTS, which either
2186/// 1) return two i32 values and take a 2 x i32 value to shift plus a shift
2187/// amount, or
2188/// 2) return two i64 values and take a 2 x i64 value to shift plus a shift
2189/// amount.
2190SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2191 SelectionDAG &DAG) const {
2192 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2193 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2194
2195 EVT VT = Op.getValueType();
2196 unsigned VTBits = VT.getSizeInBits();
2197 SDLoc dl(Op);
2198 SDValue ShOpLo = Op.getOperand(0);
2199 SDValue ShOpHi = Op.getOperand(1);
2200 SDValue ShAmt = Op.getOperand(2);
2201 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2202
2203 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2204 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2205 // {dHi, dLo} = {aHi, aLo} >> Amt
2206 // dHi = aHi >> Amt
2207 // dLo = shf.r.clamp aLo, aHi, Amt
2208
2209 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2210 SDValue Lo =
2211 DAG.getNode(NVPTXISD::FSHR_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2212
2213 SDValue Ops[2] = { Lo, Hi };
2214 return DAG.getMergeValues(Ops, dl);
2215 }
2216 else {
2217 // {dHi, dLo} = {aHi, aLo} >> Amt
2218 // - if (Amt>=size) then
2219 // dLo = aHi >> (Amt-size)
2220 // dHi = aHi >> Amt (this is either all 0 or all 1)
2221 // else
2222 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2223 // dHi = aHi >> Amt
2224
2225 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2226 DAG.getConstant(VTBits, dl, MVT::i32),
2227 ShAmt);
2228 SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
2229 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2230 DAG.getConstant(VTBits, dl, MVT::i32));
2231 SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
2232 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2233 SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
2234
2235 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2236 DAG.getConstant(VTBits, dl, MVT::i32),
2237 ISD::SETGE);
2238 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2239 SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2240
2241 SDValue Ops[2] = { Lo, Hi };
2242 return DAG.getMergeValues(Ops, dl);
2243 }
2244}
2245
2246/// LowerShiftLeftParts - Lower SHL_PARTS, which
2247/// 1) returns two i32 values and takes a 2 x i32 value to shift plus a shift
2248/// amount, or
2249/// 2) returns two i64 values and takes a 2 x i64 value to shift plus a shift
2250/// amount.
2251SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2252 SelectionDAG &DAG) const {
2253 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2254 assert(Op.getOpcode() == ISD::SHL_PARTS);
2255
2256 EVT VT = Op.getValueType();
2257 unsigned VTBits = VT.getSizeInBits();
2258 SDLoc dl(Op);
2259 SDValue ShOpLo = Op.getOperand(0);
2260 SDValue ShOpHi = Op.getOperand(1);
2261 SDValue ShAmt = Op.getOperand(2);
2262
2263 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2264 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2265 // {dHi, dLo} = {aHi, aLo} << Amt
2266 // dHi = shf.l.clamp aLo, aHi, Amt
2267 // dLo = aLo << Amt
2268
2269 SDValue Hi =
2270 DAG.getNode(NVPTXISD::FSHL_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2271 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2272
2273 SDValue Ops[2] = { Lo, Hi };
2274 return DAG.getMergeValues(Ops, dl);
2275 }
2276 else {
2277 // {dHi, dLo} = {aHi, aLo} << Amt
2278 // - if (Amt>=size) then
2279 // dLo = aLo << Amt (all 0)
2280 //      dHi = aLo << (Amt-size)
2281 // else
2282 // dLo = aLo << Amt
2283 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2284
2285 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2286 DAG.getConstant(VTBits, dl, MVT::i32),
2287 ShAmt);
2288 SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
2289 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2290 DAG.getConstant(VTBits, dl, MVT::i32));
2291 SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
2292 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2293 SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
2294
2295 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2296 DAG.getConstant(VTBits, dl, MVT::i32),
2297 ISD::SETGE);
2298 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2299 SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2300
2301 SDValue Ops[2] = { Lo, Hi };
2302 return DAG.getMergeValues(Ops, dl);
2303 }
2304}
2305
2306/// If the types match, convert the generic copysign to the NVPTXISD version,
2307/// otherwise bail, ensuring that mismatched cases are properly expanded.
2308SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2309 SelectionDAG &DAG) const {
2310 EVT VT = Op.getValueType();
2311 SDLoc DL(Op);
2312
2313 SDValue In1 = Op.getOperand(0);
2314 SDValue In2 = Op.getOperand(1);
2315 EVT SrcVT = In2.getValueType();
2316
2317 if (!SrcVT.bitsEq(VT))
2318 return SDValue();
2319
2320 return DAG.getNode(NVPTXISD::FCOPYSIGN, DL, VT, In1, In2);
2321}
2322
2323SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2324 EVT VT = Op.getValueType();
2325
2326 if (VT == MVT::f32)
2327 return LowerFROUND32(Op, DAG);
2328
2329 if (VT == MVT::f64)
2330 return LowerFROUND64(Op, DAG);
2331
2332 llvm_unreachable("unhandled type");
2333}
2334
2335// This is the rounding method used in CUDA libdevice, in C-like code:
2336// float roundf(float A)
2337// {
2338// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2339// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2340// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2341// }
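// E.g. roundf(2.5f) == 3.0f and roundf(-2.5f) == -3.0f (halfway cases round
// away from zero), while roundf(0.3f) == 0.0f via the abs(A) < 0.5 branch.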
2342SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2343 SelectionDAG &DAG) const {
2344 SDLoc SL(Op);
2345 SDValue A = Op.getOperand(0);
2346 EVT VT = Op.getValueType();
2347
2348 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2349
2350 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2351 SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
2352 const unsigned SignBitMask = 0x80000000;
2353 SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
2354 DAG.getConstant(SignBitMask, SL, MVT::i32));
2355 const unsigned PointFiveInBits = 0x3F000000;
2356 SDValue PointFiveWithSignRaw =
2357 DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
2358 DAG.getConstant(PointFiveInBits, SL, MVT::i32));
2359 SDValue PointFiveWithSign =
2360 DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
2361 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
2362 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2363
2364 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2365 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2366 SDValue IsLarge =
2367 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
2368 ISD::SETOGT);
2369 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2370
2371 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2372 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2373 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2374 SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
2375 return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
2376}
2377
2378// The implementation of round(double) is similar to that of round(float) in
2379// that they both separate the value range into three regions and use a method
2380// specific to the region to round the values. However, round(double) first
2381// rounds the absolute value and then adds the sign back, while round(float)
2382// directly rounds the value with its sign.
2383SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2384 SelectionDAG &DAG) const {
2385 SDLoc SL(Op);
2386 SDValue A = Op.getOperand(0);
2387 EVT VT = Op.getValueType();
2388
2389 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2390
2391 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2392 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
2393 DAG.getConstantFP(0.5, SL, VT));
2394 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2395
2396 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2397 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2398 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2399 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2400 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
2401 DAG.getConstantFP(0, SL, VT),
2402 RoundedA);
2403
2404 // Add the sign back to RoundedA.
2405 RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
2407
2408 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2409 SDValue IsLarge =
2410 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
2411 ISD::SETOGT);
2412 return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2413}
2414
2416 EVT VT = N->getValueType(0);
2417 EVT NVT = MVT::f32;
2418 if (VT.isVector()) {
2419 NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
2420 }
2421 SDLoc DL(N);
2422 SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
2423 SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
2424 SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
2425 return DAG.getFPExtendOrRound(Res, DL, VT);
2426}
2427
2428SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2429 SelectionDAG &DAG) const {
2430 if (useF32FTZ(DAG.getMachineFunction())) {
2431 return PromoteBinOpToF32(Op.getNode(), DAG);
2432 }
2433 return Op;
2434}
2435
2436SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2437 SelectionDAG &DAG) const {
2438 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2439
2440 if (Op.getValueType() == MVT::bf16) {
2441 SDLoc Loc(Op);
2442 return DAG.getNode(
2443 ISD::FP_ROUND, Loc, MVT::bf16,
2444 DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
2445 DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true));
2446 }
2447
2448 // Everything else is considered legal.
2449 return Op;
2450}
2451
2452SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2453 SelectionDAG &DAG) const {
2454 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2455
2456 if (Op.getOperand(0).getValueType() == MVT::bf16) {
2457 SDLoc Loc(Op);
2458 return DAG.getNode(
2459 Op.getOpcode(), Loc, Op.getValueType(),
2460 DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
2461 }
2462
2463 // Everything else is considered legal.
2464 return Op;
2465}
2466
2467SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2468 SelectionDAG &DAG) const {
2469 EVT NarrowVT = Op.getValueType();
2470 SDValue Wide = Op.getOperand(0);
2471 EVT WideVT = Wide.getValueType();
2472 if (NarrowVT.getScalarType() == MVT::bf16) {
2473 const TargetLowering *TLI = STI.getTargetLowering();
2474 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2475 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2476 }
2477 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2478 // This combination was the first to support f32 -> bf16.
2479 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2480 if (WideVT.getScalarType() == MVT::f32) {
2481 return Op;
2482 }
2483 if (WideVT.getScalarType() == MVT::f64) {
2484 SDLoc Loc(Op);
2485 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2486 // the hardware f32 -> bf16 instruction.
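 // Rounding to f32 with round-to-odd first avoids the double-rounding error
 // that two round-to-nearest steps (f64 -> f32 -> bf16) could introduce.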
2488 WideVT.isVector() ? WideVT.changeVectorElementType(MVT::f32)
2489 : MVT::f32,
2490 Wide, Loc, DAG);
2491 return DAG.getFPExtendOrRound(rod, Loc, NarrowVT);
2492 }
2493 }
2494 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2495 }
2496 }
2497
2498 // Everything else is considered legal.
2499 return Op;
2500}
2501
2502SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2503 SelectionDAG &DAG) const {
2504 SDValue Narrow = Op.getOperand(0);
2505 EVT NarrowVT = Narrow.getValueType();
2506 EVT WideVT = Op.getValueType();
2507 if (NarrowVT.getScalarType() == MVT::bf16) {
2508 if (WideVT.getScalarType() == MVT::f32 &&
2509 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2510 SDLoc Loc(Op);
2511 return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow);
2512 }
2513 if (WideVT.getScalarType() == MVT::f64 &&
2514 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2515 EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f32)
2516 : MVT::f32;
2517 SDLoc Loc(Op);
2518 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2519 Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow);
2520 } else {
2521 Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow);
2522 }
2523 return DAG.getNode(ISD::FP_EXTEND, Loc, WideVT, Op);
2524 }
2525 }
2526
2527 // Everything else is considered legal.
2528 return Op;
2529}
2530
2532 SDLoc DL(Op);
2533 if (Op.getValueType() != MVT::v2i16)
2534 return Op;
2535 EVT EltVT = Op.getValueType().getVectorElementType();
2536 SmallVector<SDValue> VecElements;
2537 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2538 SmallVector<SDValue> ScalarArgs;
2539 llvm::transform(Op->ops(), std::back_inserter(ScalarArgs),
2540 [&](const SDUse &O) {
2541 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
2542 O.get(), DAG.getIntPtrConstant(I, DL));
2543 });
2544 VecElements.push_back(DAG.getNode(Op.getOpcode(), DL, EltVT, ScalarArgs));
2545 }
2546 SDValue V =
2547 DAG.getNode(ISD::BUILD_VECTOR, DL, Op.getValueType(), VecElements);
2548 return V;
2549}
2550
2552 SDNode *N = Op.getNode();
2553 SDLoc DL(N);
2555
2556 // split the vector argument
2557 for (size_t I = 0; I < N->getNumOperands(); I++) {
2558 SDValue Val = N->getOperand(I);
2559 EVT ValVT = Val.getValueType();
2560 if (ValVT.isVector()) {
2561 EVT EltVT = ValVT.getVectorElementType();
2562 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2563 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2564 DAG.getIntPtrConstant(J, DL)));
2565 } else
2566 Ops.push_back(Val);
2567 }
2568
2570 SDValue Tcgen05StNode =
2571 DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, N->getVTList(), Ops,
2572 MemSD->getMemoryVT(), MemSD->getMemOperand());
2573
2574 return Tcgen05StNode;
2575}
2576
2578 SDLoc DL(Op);
2579 SDValue Src = Op.getOperand(0);
2580 EVT VT = Op.getValueType();
2581
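 // All cases below are PRMT byte permutes: selector 0x0123 reverses the four
 // bytes of a 32-bit value, 0x2301 swaps the bytes within each 16-bit half,
 // and 0x7701 swaps the two low bytes while zeroing the upper half (the i16
 // result is then truncated back). The i64 case byte-reverses each 32-bit
 // half and swaps the halves when rebuilding the vector.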
2582 switch (VT.getSimpleVT().SimpleTy) {
2583 case MVT::i16: {
2584 SDValue Extended = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Src);
2585 SDValue Swapped =
2586 getPRMT(Extended, DAG.getConstant(0, DL, MVT::i32), 0x7701, DL, DAG);
2587 return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Swapped);
2588 }
2589 case MVT::i32: {
2590 return getPRMT(Src, DAG.getConstant(0, DL, MVT::i32), 0x0123, DL, DAG);
2591 }
2592 case MVT::v2i16: {
2593 SDValue Converted = DAG.getBitcast(MVT::i32, Src);
2594 SDValue Swapped =
2595 getPRMT(Converted, DAG.getConstant(0, DL, MVT::i32), 0x2301, DL, DAG);
2596 return DAG.getNode(ISD::BITCAST, DL, MVT::v2i16, Swapped);
2597 }
2598 case MVT::i64: {
2599 SDValue UnpackSrc =
2600 DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, Src);
2601 SDValue SwappedLow =
2602 getPRMT(UnpackSrc.getValue(0), DAG.getConstant(0, DL, MVT::i32), 0x0123,
2603 DL, DAG);
2604 SDValue SwappedHigh =
2605 getPRMT(UnpackSrc.getValue(1), DAG.getConstant(0, DL, MVT::i32), 0x0123,
2606 DL, DAG);
2607 return DAG.getNode(NVPTXISD::BUILD_VECTOR, DL, MVT::i64,
2608 {SwappedHigh, SwappedLow});
2609 }
2610 default:
2611 llvm_unreachable("unsupported type for bswap");
2612 }
2613}
2614
2615static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
2616 switch (IID) {
2617 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
2618 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG1;
2619 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
2620 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG2;
2621 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
2622 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2623 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
2624 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2625 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
2626 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2627 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
2628 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2629 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
2630 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2631 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
2632 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2633 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
2634 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2635 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
2636 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2637 case Intrinsic::
2638 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
2639 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2640 case Intrinsic::
2641 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
2642 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2643 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
2644 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG1;
2645 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
2646 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG2;
2647 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
2648 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2649 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
2650 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2651 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
2652 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2653 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
2654 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2655 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
2656 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2657 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
2658 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2659 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
2660 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2661 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
2662 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2663 case Intrinsic::
2664 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
2665 return NVPTXISD::
2666 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2667 case Intrinsic::
2668 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
2669 return NVPTXISD::
2670 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2671 };
2672 llvm_unreachable("unhandled tcgen05.mma.disable_output_lane intrinsic");
2673}
2674
2676 SDNode *N = Op.getNode();
2677 SDLoc DL(N);
2678 unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
2679
2681 // split the vector argument
2682 for (size_t I = 0; I < N->getNumOperands(); I++) {
2683 if (I == 1)
2684 continue; // skip IID
2685 SDValue Val = N->getOperand(I);
2686 EVT ValVT = Val.getValueType();
2687 if (ValVT.isVector()) {
2688 EVT EltVT = ValVT.getVectorElementType();
2689 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2690 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2691 DAG.getIntPtrConstant(J, DL)));
2692 } else
2693 Ops.push_back(Val);
2694 }
2695
2697 SDValue Tcgen05MMANode = DAG.getMemIntrinsicNode(
2698 getTcgen05MMADisableOutputLane(IID), DL, N->getVTList(), Ops,
2699 MemSD->getMemoryVT(), MemSD->getMemOperand());
2700
2701 return Tcgen05MMANode;
2702}
2703
2704// Lower vector return type of tcgen05.ld intrinsics
2705static std::optional<std::pair<SDValue, SDValue>>
2706lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
2707 SDLoc DL(N);
2708 EVT ResVT = N->getValueType(0);
2709 if (!ResVT.isVector())
2710 return {}; // already legalized.
2711
2712 const unsigned NumElts = ResVT.getVectorNumElements();
2713
2714 // Create the return type of the instructions
2715 SmallVector<EVT, 5> ListVTs;
2716 for (unsigned i = 0; i < NumElts; ++i)
2717 ListVTs.push_back(MVT::i32);
2718
2719 ListVTs.push_back(N->getValueType(1)); // Chain
2720
2721 SDVTList ResVTs = DAG.getVTList(ListVTs);
2722
2723 SmallVector<SDValue, 8> Ops{N->getOperand(0), N->getOperand(1),
2724 N->getOperand(2)};
2725
2726 if (HasOffset) {
2727 Ops.push_back(N->getOperand(3)); // offset
2728 Ops.push_back(N->getOperand(4)); // Pack flag
2729 } else
2730 Ops.push_back(N->getOperand(3)); // Pack flag
2731
2733 SDValue NewNode =
2735 MemSD->getMemoryVT(), MemSD->getMemOperand());
2736
2737 // split the vector result
2738 SmallVector<SDValue, 4> ScalarRes;
2739 for (unsigned i = 0; i < NumElts; ++i) {
2740 SDValue Res = NewNode.getValue(i);
2741 ScalarRes.push_back(Res);
2742 }
2743
2744 SDValue Chain = NewNode.getValue(NumElts);
2745 SDValue BuildVector = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
2746 return {{BuildVector, Chain}};
2747}
2748
2750 SDNode *N = Op.getNode();
2751 SDValue Intrin = N->getOperand(1);
2752
2753 // Get the intrinsic ID
2754 unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
2755 switch (IntrinNo) {
2756 default:
2757 break;
2758 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
2759 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
2760 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
2761 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
2762 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
2763 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
2764 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
2765 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
2766 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
2767 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
2768 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
2769 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
2770 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
2771 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
2772 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
2773 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
2774 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
2775 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
2776 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
2777 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
2778 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1:
2779 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2:
2780 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4:
2781 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8:
2782 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16:
2783 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32:
2784 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64:
2785 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128:
2786 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
2787 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
2788 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
2789 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
2790 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
2791 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
2792 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
2793 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
2794 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
2795 return lowerTcgen05St(Op, DAG);
2796 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
2797 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
2798 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
2799 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
2800 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
2801 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
2802 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
2803 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
2804 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
2805 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
2806 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
2807 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
2808 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
2809 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
2810 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
2811 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
2812 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
2813 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
2814 case Intrinsic::
2815 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
2816 case Intrinsic::
2817 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
2818 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
2819 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
2820 case Intrinsic::
2821 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
2822 case Intrinsic::
2823 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
2825 }
2826 return Op;
2827}
2828
2830 SelectionDAG &DAG) {
2831
2832 SDNode *N = Op.getNode();
2833 if (N->getOperand(1).getValueType() != MVT::i128) {
2834 // Return if the operand is already lowered.
2835 return SDValue();
2836 }
2837
2838 unsigned IID =
2839 cast<ConstantSDNode>(N->getOperand(0).getNode())->getZExtValue();
2840 auto Opcode = [&]() {
2841 switch (IID) {
2842 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
2843 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED;
2844 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
2845 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X;
2846 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
2847 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y;
2848 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
2849 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z;
2850 default:
2851 llvm_unreachable("unsupported/unhandled intrinsic");
2852 }
2853 }();
2854
2855 SDLoc DL(N);
2856 SDValue TryCancelResponse = N->getOperand(1);
2857 SDValue Cast = DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, TryCancelResponse);
2858 SDValue TryCancelResponse0 =
2859 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
2860 DAG.getIntPtrConstant(0, DL));
2861 SDValue TryCancelResponse1 =
2862 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
2863 DAG.getIntPtrConstant(1, DL));
2864
2865 return DAG.getNode(Opcode, DL, N->getVTList(),
2866 {TryCancelResponse0, TryCancelResponse1});
2867}
2868
2870 SDNode *N = Op.getNode();
2871 SDLoc DL(N);
2872 SDValue F32Vec = N->getOperand(1);
2873 SDValue RBits = N->getOperand(2);
2874
2875 unsigned IntrinsicID = N->getConstantOperandVal(0);
2876
2877 // Extract the 4 float elements from the vector
2879 for (unsigned i = 0; i < 4; ++i)
2880 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2881 DAG.getIntPtrConstant(i, DL)));
2882
2884
2885 auto [OpCode, RetTy, CvtModeFlag] =
2886 [&]() -> std::tuple<unsigned, MVT::SimpleValueType, uint32_t> {
2887 switch (IntrinsicID) {
2888 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
2889 return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
2890 CvtMode::RS | CvtMode::RELU_FLAG};
2891 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
2892 return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2893 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
2894 return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8,
2895 CvtMode::RS | CvtMode::RELU_FLAG};
2896 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
2897 return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2898 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
2899 return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8,
2900 CvtMode::RS | CvtMode::RELU_FLAG};
2901 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
2902 return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2903 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
2904 return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8,
2905 CvtMode::RS | CvtMode::RELU_FLAG};
2906 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
2907 return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2908 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
2909 return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16,
2910 CvtMode::RS | CvtMode::RELU_FLAG};
2911 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
2912 return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS};
2913 default:
2914 llvm_unreachable("unsupported/unhandled intrinsic");
2915 }
2916 }();
2917
2918 Ops.push_back(RBits);
2919 Ops.push_back(DAG.getConstant(CvtModeFlag, DL, MVT::i32));
2920
2921 return DAG.getNode(OpCode, DL, RetTy, Ops);
2922}
2923
2925 const unsigned Mode = [&]() {
2926 switch (Op->getConstantOperandVal(0)) {
2927 case Intrinsic::nvvm_prmt:
2929 case Intrinsic::nvvm_prmt_b4e:
2931 case Intrinsic::nvvm_prmt_ecl:
2933 case Intrinsic::nvvm_prmt_ecr:
2935 case Intrinsic::nvvm_prmt_f4e:
2937 case Intrinsic::nvvm_prmt_rc16:
2939 case Intrinsic::nvvm_prmt_rc8:
2941 default:
2942 llvm_unreachable("unsupported/unhandled intrinsic");
2943 }
2944 }();
2945 SDLoc DL(Op);
2946 SDValue A = Op->getOperand(1);
2947 SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(2)
2948 : DAG.getConstant(0, DL, MVT::i32);
2949 SDValue Selector = (Op->op_end() - 1)->get();
2950 return getPRMT(A, B, Selector, DL, DAG, Mode);
2951}
2952
2954 switch (Op->getConstantOperandVal(1)) {
2955 default:
2956 return Op;
2957
2958 // These tcgen05 intrinsics return a v2i32, which is legal, so we have to
2959 // lower them through LowerOperation() instead of ReplaceNodeResults().
2960 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
2961 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
2962 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
2963 if (auto Res = lowerTcgen05Ld(Op.getNode(), DAG))
2964 return DAG.getMergeValues({Res->first, Res->second}, SDLoc(Op));
2965 return SDValue();
2966
2967 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
2968 if (auto Res = lowerTcgen05Ld(Op.getNode(), DAG, /*HasOffset=*/true))
2969 return DAG.getMergeValues({Res->first, Res->second}, SDLoc(Op));
2970 return SDValue();
2971 }
2972}
2973
2975 switch (Op->getConstantOperandVal(0)) {
2976 default:
2977 return Op;
2978 case Intrinsic::nvvm_prmt:
2979 case Intrinsic::nvvm_prmt_b4e:
2980 case Intrinsic::nvvm_prmt_ecl:
2981 case Intrinsic::nvvm_prmt_ecr:
2982 case Intrinsic::nvvm_prmt_f4e:
2983 case Intrinsic::nvvm_prmt_rc16:
2984 case Intrinsic::nvvm_prmt_rc8:
2985 return lowerPrmtIntrinsic(Op, DAG);
2986 case Intrinsic::nvvm_internal_addrspace_wrap:
2987 return Op.getOperand(1);
2988 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
2989 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
2990 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
2991 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
2993 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
2994 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
2995 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
2996 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
2997 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
2998 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
2999 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
3000 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
3001 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
3002 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
3003 return lowerCvtRSIntrinsics(Op, DAG);
3004 }
3005}
3006
3007// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
3008// Lower these into a node returning a 32-bit result, which is then
3009// zero-extended back to i64.
3011 SDValue V = Op->getOperand(0);
3012 assert(V.getValueType() == MVT::i64 &&
3013 "Unexpected CTLZ/CTPOP type to legalize");
3014
3015 SDLoc DL(Op);
3016 SDValue CT = DAG.getNode(Op->getOpcode(), DL, MVT::i32, V);
3017 return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CT, SDNodeFlags::NonNeg);
3018}
3019
3021 unsigned Opcode, SelectionDAG &DAG) {
3022 assert(A.getValueType() == MVT::i64 && B.getValueType() == MVT::i64);
3023
3024 const auto *AmtConst = dyn_cast<ConstantSDNode>(ShiftAmount);
3025 if (!AmtConst)
3026 return SDValue();
3027 const auto Amt = AmtConst->getZExtValue() & 63;
3028
3029 SDValue UnpackA =
3030 DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, A);
3031 SDValue UnpackB =
3032 DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, B);
3033
3034 // The arch is little endian: 0 = low bits, 1 = high bits
3035 SDValue ALo = UnpackA.getValue(0);
3036 SDValue AHi = UnpackA.getValue(1);
3037 SDValue BLo = UnpackB.getValue(0);
3038 SDValue BHi = UnpackB.getValue(1);
3039
3040 // The bitfield consists of { AHi : ALo : BHi : BLo }
3041 //
3042 // * FSHL, Amt < 32 - The window will contain { AHi : ALo : BHi }
3043 // * FSHL, Amt >= 32 - The window will contain { ALo : BHi : BLo }
3044 // * FSHR, Amt < 32 - The window will contain { ALo : BHi : BLo }
3045 // * FSHR, Amt >= 32 - The window will contain { AHi : ALo : BHi }
3046 //
3047 // Note that Amt = 0 and Amt = 32 are special cases where 32-bit funnel shifts
3048 // are not needed at all. Amt = 0 is a no-op producing either A or B depending
3049 // on the direction. Amt = 32 can be implemented by a packing and unpacking
3050 // move to select and arrange the 32-bit values. For simplicity, these cases
3051 // are not handled here explicitly and instead we rely on DAGCombiner to
3052 // remove the no-op funnel shifts we insert.
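 // E.g. FSHL with Amt = 40 (>= 32) uses the window { ALo : BHi : BLo } and a
 // 32-bit shift amount of Amt & 31 = 8, giving RHi = fshl(ALo, BHi, 8) and
 // RLo = fshl(BHi, BLo, 8).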
3053 auto [High, Mid, Low] = ((Opcode == ISD::FSHL) == (Amt < 32))
3054 ? std::make_tuple(AHi, ALo, BHi)
3055 : std::make_tuple(ALo, BHi, BLo);
3056
3057 SDValue NewAmt = DAG.getConstant(Amt & 31, DL, MVT::i32);
3058 SDValue RHi = DAG.getNode(Opcode, DL, MVT::i32, {High, Mid, NewAmt});
3059 SDValue RLo = DAG.getNode(Opcode, DL, MVT::i32, {Mid, Low, NewAmt});
3060
3061 return DAG.getNode(NVPTXISD::BUILD_VECTOR, DL, MVT::i64, {RLo, RHi});
3062}
3063
3065 return expandFSH64(Op->getOperand(0), Op->getOperand(1), Op->getOperand(2),
3066 SDLoc(Op), Op->getOpcode(), DAG);
3067}
3068
3070 unsigned Opcode = Op->getOpcode() == ISD::ROTL ? ISD::FSHL : ISD::FSHR;
3071 return expandFSH64(Op->getOperand(0), Op->getOperand(0), Op->getOperand(1),
3072 SDLoc(Op), Opcode, DAG);
3073}
3074
3076 // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
3077 // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
3078 // the semantics of LLVM's frem.
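 // E.g. frem(5.5, 2.0): trunc(5.5 / 2.0) = 2.0, so the result is
 // 5.5 - 2.0 * 2.0 = 1.5.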
3079 SDLoc DL(Op);
3080 SDValue X = Op->getOperand(0);
3081 SDValue Y = Op->getOperand(1);
3082 EVT Ty = Op.getValueType();
3083 SDNodeFlags Flags = Op->getFlags();
3084
3085 SDValue Div = DAG.getNode(ISD::FDIV, DL, Ty, X, Y, Flags);
3086 SDValue Trunc = DAG.getNode(ISD::FTRUNC, DL, Ty, Div, Flags);
3087 SDValue Mul = DAG.getNode(ISD::FMUL, DL, Ty, Trunc, Y,
3089 SDValue Sub = DAG.getNode(ISD::FSUB, DL, Ty, X, Mul,
3091
3092 if (Flags.hasNoInfs())
3093 return Sub;
3094
3095 // If Y is infinite, return X
3096 SDValue AbsY = DAG.getNode(ISD::FABS, DL, Ty, Y);
3097 SDValue Inf =
3098 DAG.getConstantFP(APFloat::getInf(Ty.getFltSemantics()), DL, Ty);
3099 SDValue IsInf = DAG.getSetCC(DL, MVT::i1, AbsY, Inf, ISD::SETEQ);
3100 return DAG.getSelect(DL, Ty, IsInf, X, Sub);
3101}
3102
3104 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
3105
3106 SDValue Cond = Op->getOperand(0);
3107 SDValue TrueVal = Op->getOperand(1);
3108 SDValue FalseVal = Op->getOperand(2);
3109 SDLoc DL(Op);
3110
3111 // If both operands are truncated, we push the select through the truncates.
3112 if (TrueVal.getOpcode() == ISD::TRUNCATE &&
3113 FalseVal.getOpcode() == ISD::TRUNCATE) {
3114 TrueVal = TrueVal.getOperand(0);
3115 FalseVal = FalseVal.getOperand(0);
3116
3117 EVT VT = TrueVal.getSimpleValueType().bitsLE(FalseVal.getSimpleValueType())
3118 ? TrueVal.getValueType()
3119 : FalseVal.getValueType();
3120 TrueVal = DAG.getAnyExtOrTrunc(TrueVal, DL, VT);
3121 FalseVal = DAG.getAnyExtOrTrunc(FalseVal, DL, VT);
3122 SDValue Select = DAG.getSelect(DL, VT, Cond, TrueVal, FalseVal);
3123 return DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
3124 }
3125
3126 // Otherwise, expand the select into a series of logical operations. These
3127 // often can be folded into other operations either by us or ptxas.
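 // I.e. select(c, t, f) == (c & freeze(t)) | (~c & freeze(f)) for i1 values.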
3128 TrueVal = DAG.getFreeze(TrueVal);
3129 FalseVal = DAG.getFreeze(FalseVal);
3130 SDValue And1 = DAG.getNode(ISD::AND, DL, MVT::i1, Cond, TrueVal);
3131 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
3132 SDValue And2 = DAG.getNode(ISD::AND, DL, MVT::i1, NotCond, FalseVal);
3133 SDValue Or = DAG.getNode(ISD::OR, DL, MVT::i1, And1, And2);
3134 return Or;
3135}
3136
3138 SDNode *N = Op.getNode();
3139
3140 SDValue Chain = N->getOperand(0);
3141 SDValue Val = N->getOperand(1);
3142 SDValue BasePtr = N->getOperand(2);
3143 SDValue Offset = N->getOperand(3);
3144 SDValue Mask = N->getOperand(4);
3145
3146 SDLoc DL(N);
3147 EVT ValVT = Val.getValueType();
3148 MemSDNode *MemSD = cast<MemSDNode>(N);
3149 assert(ValVT.isVector() && "Masked vector store must have vector type");
3150 assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
3151 "Unexpected alignment for masked store");
3152
3153 unsigned Opcode = 0;
3154 switch (ValVT.getSimpleVT().SimpleTy) {
3155 default:
3156 llvm_unreachable("Unexpected masked vector store type");
3157 case MVT::v4i64:
3158 case MVT::v4f64: {
3159 Opcode = NVPTXISD::StoreV4;
3160 break;
3161 }
3162 case MVT::v8i32:
3163 case MVT::v8f32: {
3164 Opcode = NVPTXISD::StoreV8;
3165 break;
3166 }
3167 }
3168
3170
3171 // Construct the new SDNode. First operand is the chain.
3172 Ops.push_back(Chain);
3173
3174 // The next N operands are the values to store. Encode the mask into the
3175 // values using the sentinel register 0 to represent a masked-off element.
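 // E.g. a v4f64 store of <a, b, c, d> with mask <1, 0, 1, 1> becomes, roughly,
 // a StoreV4 whose value operands are {a, <sentinel 0>, c, d}.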
3176 assert(Mask.getValueType().isVector() &&
3177 Mask.getValueType().getVectorElementType() == MVT::i1 &&
3178 "Mask must be a vector of i1");
3179 assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
3180 "Mask expected to be a BUILD_VECTOR");
3181 assert(Mask.getValueType().getVectorNumElements() ==
3182 ValVT.getVectorNumElements() &&
3183 "Mask size must be the same as the vector size");
3184 for (auto [I, Op] : enumerate(Mask->ops())) {
3185 // Mask elements must be constants.
3186 if (Op.getNode()->getAsZExtVal() == 0) {
3187 // Append a sentinel register 0 to the Ops vector to represent a
3188 // masked-off element; this will be handled in tablegen.
3190 ValVT.getVectorElementType()));
3191 } else {
3192 // Extract the element from the vector to store
3193 SDValue ExtVal =
3195 Val, DAG.getIntPtrConstant(I, DL));
3196 Ops.push_back(ExtVal);
3197 }
3198 }
3199
3200 // Next, the pointer operand.
3201 Ops.push_back(BasePtr);
3202
3203 // Finally, the offset operand. We expect this to always be undef, and it will
3204 // be ignored in lowering, but to mirror the handling of the other vector
3205 // store instructions we include it in the new SDNode.
3206 assert(Offset.getOpcode() == ISD::UNDEF &&
3207 "Offset operand expected to be undef");
3208 Ops.push_back(Offset);
3209
3210 SDValue NewSt =
3211 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
3212 MemSD->getMemoryVT(), MemSD->getMemOperand());
3213
3214 return NewSt;
3215}
3216
3217SDValue
3218 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3219 switch (Op.getOpcode()) {
3220 case ISD::RETURNADDR:
3221 return SDValue();
3222 case ISD::FRAMEADDR:
3223 return SDValue();
3224 case ISD::ADDRSPACECAST:
3225 return LowerADDRSPACECAST(Op, DAG);
3226 case ISD::INTRINSIC_W_CHAIN:
3227 return lowerIntrinsicWChain(Op, DAG);
3228 case ISD::INTRINSIC_WO_CHAIN:
3229 return lowerIntrinsicWOChain(Op, DAG);
3230 case ISD::INTRINSIC_VOID:
3231 return lowerIntrinsicVoid(Op, DAG);
3232 case ISD::BUILD_VECTOR:
3233 return LowerBUILD_VECTOR(Op, DAG);
3234 case ISD::BITCAST:
3235 return LowerBITCAST(Op, DAG);
3236 case ISD::EXTRACT_SUBVECTOR:
3237 return Op;
3238 case ISD::EXTRACT_VECTOR_ELT:
3239 return LowerEXTRACT_VECTOR_ELT(Op, DAG);
3240 case ISD::INSERT_VECTOR_ELT:
3241 return LowerINSERT_VECTOR_ELT(Op, DAG);
3242 case ISD::VECTOR_SHUFFLE:
3243 return LowerVECTOR_SHUFFLE(Op, DAG);
3244 case ISD::CONCAT_VECTORS:
3245 return LowerCONCAT_VECTORS(Op, DAG);
3246 case ISD::VECREDUCE_FMAX:
3247 case ISD::VECREDUCE_FMIN:
3248 case ISD::VECREDUCE_FMAXIMUM:
3249 case ISD::VECREDUCE_FMINIMUM:
3250 return LowerVECREDUCE(Op, DAG);
3251 case ISD::STORE:
3252 return LowerSTORE(Op, DAG);
3253 case ISD::MSTORE: {
3254 assert(STI.has256BitVectorLoadStore(
3255 cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
3256 "Masked store vector not supported on subtarget.");
3257 return lowerMSTORE(Op, DAG);
3258 }
3259 case ISD::LOAD:
3260 return LowerLOAD(Op, DAG);
3261 case ISD::MLOAD:
3262 return LowerMLOAD(Op, DAG);
3263 case ISD::SHL_PARTS:
3264 return LowerShiftLeftParts(Op, DAG);
3265 case ISD::SRA_PARTS:
3266 case ISD::SRL_PARTS:
3267 return LowerShiftRightParts(Op, DAG);
3268 case ISD::SELECT:
3269 return lowerSELECT(Op, DAG);
3270 case ISD::FROUND:
3271 return LowerFROUND(Op, DAG);
3272 case ISD::FCOPYSIGN:
3273 return LowerFCOPYSIGN(Op, DAG);
3274 case ISD::SINT_TO_FP:
3275 case ISD::UINT_TO_FP:
3276 return LowerINT_TO_FP(Op, DAG);
3277 case ISD::FP_TO_SINT:
3278 case ISD::FP_TO_UINT:
3279 return LowerFP_TO_INT(Op, DAG);
3280 case ISD::FP_ROUND:
3281 return LowerFP_ROUND(Op, DAG);
3282 case ISD::FP_EXTEND:
3283 return LowerFP_EXTEND(Op, DAG);
3284 case ISD::BR_JT:
3285 return LowerBR_JT(Op, DAG);
3286 case ISD::VAARG:
3287 return LowerVAARG(Op, DAG);
3288 case ISD::VASTART:
3289 return LowerVASTART(Op, DAG);
3290 case ISD::FSHL:
3291 case ISD::FSHR:
3292 return lowerFSH(Op, DAG);
3293 case ISD::ROTL:
3294 case ISD::ROTR:
3295 return lowerROT(Op, DAG);
3296 case ISD::ABS:
3297 case ISD::SMIN:
3298 case ISD::SMAX:
3299 case ISD::UMIN:
3300 case ISD::UMAX:
3301 case ISD::ADD:
3302 case ISD::SUB:
3303 case ISD::MUL:
3304 case ISD::SHL:
3305 case ISD::SREM:
3306 case ISD::UREM:
3307 return LowerVectorArith(Op, DAG);
3308 case ISD::DYNAMIC_STACKALLOC:
3309 return LowerDYNAMIC_STACKALLOC(Op, DAG);
3310 case ISD::STACKRESTORE:
3311 return LowerSTACKRESTORE(Op, DAG);
3312 case ISD::STACKSAVE:
3313 return LowerSTACKSAVE(Op, DAG);
3314 case ISD::CopyToReg:
3315 return LowerCopyToReg_128(Op, DAG);
3316 case ISD::FADD:
3317 case ISD::FSUB:
3318 case ISD::FMUL:
3319 // Used only for bf16 on SM80, where we select fma for non-ftz operation
3320 return PromoteBinOpIfF32FTZ(Op, DAG);
3321 case ISD::CTPOP:
3322 case ISD::CTLZ:
3323 return lowerCTLZCTPOP(Op, DAG);
3324 case ISD::FREM:
3325 return lowerFREM(Op, DAG);
3326 case ISD::BSWAP:
3327 return lowerBSWAP(Op, DAG);
3328 default:
3329 llvm_unreachable("Custom lowering not defined for operation");
3330 }
3331}
3332
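// Lower BR_JT to the BrxStart/BrxItem/BrxEnd pseudo sequence, which is later
// selected into an inline PTX branch-target list followed by an indexed
// branch, roughly (illustrative):
//   $L_brx_0: .branchtargets $L_bb0, $L_bb1, ... ;
//   brx.idx %index, $L_brx_0;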
3333SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
3334 SDLoc DL(Op);
3335 SDValue Chain = Op.getOperand(0);
3336 const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
3337 SDValue Index = Op.getOperand(2);
3338
3339 unsigned JId = JT->getIndex();
3340 const MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
3341 ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
3342
3343 SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
3344
3345 // Generate BrxStart node
3346 SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
3347 Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
3348
3349 // Generate BrxItem nodes
3350 assert(!MBBs.empty());
3351 for (MachineBasicBlock *MBB : MBBs.drop_back())
3352 Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
3353 DAG.getBasicBlock(MBB), Chain.getValue(1));
3354
3355 // Generate BrxEnd nodes
3356 SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
3357 IdV, Chain.getValue(1)};
3358 SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, MVT::Other, EndOps);
3359
3360 return BrxEnd;
3361}
3362
3363// This will prevent AsmPrinter from trying to print the jump tables itself.
3367
3368SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
3369 SelectionDAG &DAG) const {
3370 AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Op.getNode());
3371 unsigned SrcAS = N->getSrcAddressSpace();
3372 unsigned DestAS = N->getDestAddressSpace();
3373 if (SrcAS != llvm::ADDRESS_SPACE_GENERIC &&
3374 DestAS != llvm::ADDRESS_SPACE_GENERIC) {
3375 // Shared and SharedCluster can be converted to each other through generic
3376 // space
3377 if ((SrcAS == llvm::ADDRESS_SPACE_SHARED &&
3378 DestAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER) ||
3379 (SrcAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER &&
3380 DestAS == llvm::ADDRESS_SPACE_SHARED)) {
3381 SDLoc DL(Op.getNode());
3382 const MVT GenerictVT =
3383 getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_GENERIC);
3384 SDValue GenericConversion = DAG.getAddrSpaceCast(
3385 DL, GenerictVT, Op.getOperand(0), SrcAS, ADDRESS_SPACE_GENERIC);
3386 SDValue SharedClusterConversion =
3387 DAG.getAddrSpaceCast(DL, Op.getValueType(), GenericConversion,
3388 ADDRESS_SPACE_GENERIC, DestAS);
3389 return SharedClusterConversion;
3390 }
3391
3392 return DAG.getUNDEF(Op.getValueType());
3393 }
3394
3395 return Op;
3396}
3397
3398// This function is almost a copy of SelectionDAG::expandVAArg().
3399// The only diff is that this one produces loads from local address space.
3400SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
3401 const TargetLowering *TLI = STI.getTargetLowering();
3402 SDLoc DL(Op);
3403
3404 SDNode *Node = Op.getNode();
3405 const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
3406 EVT VT = Node->getValueType(0);
3407 auto *Ty = VT.getTypeForEVT(*DAG.getContext());
3408 SDValue Tmp1 = Node->getOperand(0);
3409 SDValue Tmp2 = Node->getOperand(1);
3410 const MaybeAlign MA(Node->getConstantOperandVal(3));
3411
3412 SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
3413 Tmp1, Tmp2, MachinePointerInfo(V));
3414 SDValue VAList = VAListLoad;
3415
3416 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
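// Round VAList up to the required alignment (illustrative):
//   VAList = (VAList + Align - 1) & ~(Align - 1)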
3417 VAList = DAG.getNode(
3418 ISD::ADD, DL, VAList.getValueType(), VAList,
3419 DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
3420
3421 VAList = DAG.getNode(ISD::AND, DL, VAList.getValueType(), VAList,
3422 DAG.getSignedConstant(-(int64_t)MA->value(), DL,
3423 VAList.getValueType()));
3424 }
3425
3426 // Increment the pointer, VAList, to the next vaarg
3427 Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
3428 DAG.getConstant(DAG.getDataLayout().getTypeAllocSize(Ty),
3429 DL, VAList.getValueType()));
3430
3431 // Store the incremented VAList to the legalized pointer
3432 Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
3433 MachinePointerInfo(V));
3434
3435 const Value *SrcV = Constant::getNullValue(
3436 PointerType::get(Ty, ADDRESS_SPACE_LOCAL));
3437 
3438 // Load the actual argument out of the pointer VAList
3439 return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
3440}
3441
3442SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
3443 const TargetLowering *TLI = STI.getTargetLowering();
3444 SDLoc DL(Op);
3445 EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
3446
3447 // Store the address of unsized array <function>_vararg[] in the ap object.
3448 SDValue VAReg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
3449
3450 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
3451 return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
3452 MachinePointerInfo(SV));
3453}
3454
3455static std::pair<MemSDNode *, uint32_t>
3457 const NVPTXSubtarget &STI) {
3458 SDValue Chain = N->getOperand(0);
3459 SDValue BasePtr = N->getOperand(1);
3460 SDValue Mask = N->getOperand(3);
3461 [[maybe_unused]] SDValue Passthru = N->getOperand(4);
3462
3463 SDLoc DL(N);
3464 EVT ResVT = N->getValueType(0);
3465 assert(ResVT.isVector() && "Masked vector load must have vector type");
3466 // While we only expect poison passthru vectors as an input to the backend,
3467 // when the legalization framework splits a poison vector in half, it creates
3468 // two undef vectors, so we can technically expect those too.
3469 assert((Passthru.getOpcode() == ISD::POISON ||
3470 Passthru.getOpcode() == ISD::UNDEF) &&
3471 "Passthru operand expected to be poison or undef");
3472
3473 // Extract the mask and convert it to a uint32_t representing the used bytes
3474 // of the entire vector load
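// For example (illustrative), a masked load of <8 x float> with mask
// <1,1,0,0,1,1,1,1> produces UsedBytesMask == 0xFFFF00FF (element 0 occupies
// the least-significant byte lanes).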
3475 uint32_t UsedBytesMask = 0;
3476 uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
3477 assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
3478 uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
3479 uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
3480
3481 for (SDValue Op : reverse(Mask->ops())) {
3482 // We only need this shift on iterations after the first, but in the
3483 // first iteration UsedBytesMask is 0, so the shift is a no-op there
3484 // anyway.
3485 UsedBytesMask <<= ElementSizeInBytes;
3486
3487 // Mask elements must be constants.
3488 if (Op->getAsZExtVal() != 0)
3489 UsedBytesMask |= ElementMask;
3490 }
3491
3492 assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
3493 "Unexpected masked load with elements masked all on or all off");
3494
3495 // Create a new load sd node to be handled normally by ReplaceLoadVector.
3496 MemSDNode *NewLD = cast<MemSDNode>(
3497 DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
3498
3499 // If our subtarget does not support the used bytes mask pragma, "drop" the
3500 // mask by setting it to UINT32_MAX
3501 if (!STI.hasUsedBytesMaskPragma())
3502 UsedBytesMask = UINT32_MAX;
3503
3504 return {NewLD, UsedBytesMask};
3505}
3506
3507/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
3508static std::optional<std::pair<SDValue, SDValue>>
3509 replaceLoadVector(MemSDNode *LD, SelectionDAG &DAG,
3510 const NVPTXSubtarget &STI) {
3511 const EVT ResVT = LD->getValueType(0);
3512 const EVT MemVT = LD->getMemoryVT();
3513
3514 // If we're doing sign/zero extension as part of the load, avoid lowering to
3515 // a LoadV node. TODO: consider relaxing this restriction.
3516 if (ResVT != MemVT)
3517 return std::nullopt;
3518
3519 const auto NumEltsAndEltVT =
3520 getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
3521 if (!NumEltsAndEltVT)
3522 return std::nullopt;
3523 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3524
3525 Align Alignment = LD->getAlign();
3526 const auto &TD = DAG.getDataLayout();
3527 Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext()));
3528 if (Alignment < PrefAlign) {
3529 // This load is not sufficiently aligned, so bail out and let this vector
3530 // load be scalarized. Note that we may still be able to emit smaller
3531 // vector loads. For example, if we are loading a <4 x float> with an
3532 // alignment of 8, this check will fail but the legalizer will try again
3533 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3534 return std::nullopt;
3535 }
3536
3537 // If we have a masked load, convert it to a normal load now
3538 std::optional<uint32_t> UsedBytesMask = std::nullopt;
3539 if (LD->getOpcode() == ISD::MLOAD)
3540 std::tie(LD, UsedBytesMask) =
3542
3543 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
3544 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
3545 // loaded type to i16 and propagate the "real" type as the memory type.
3546 const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
3547
3548 unsigned Opcode;
3549 switch (NumElts) {
3550 default:
3551 return std::nullopt;
3552 case 2:
3553 Opcode = NVPTXISD::LoadV2;
3554 break;
3555 case 4:
3556 Opcode = NVPTXISD::LoadV4;
3557 break;
3558 case 8:
3559 Opcode = NVPTXISD::LoadV8;
3560 break;
3561 }
3562 auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
3563 ListVTs.push_back(MVT::Other);
3564 SDVTList LdResVTs = DAG.getVTList(ListVTs);
3565
3566 SDLoc DL(LD);
3567
3568 // Copy regular operands
3569 SmallVector<SDValue, 8> OtherOps(LD->ops());
3570
3571 OtherOps.push_back(
3572 DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
3573
3574 // The select routine does not have access to the LoadSDNode instance, so
3575 // pass along the extension information
3576 OtherOps.push_back(
3577 DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
3578
3579 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
3580 LD->getMemOperand());
3581
3582 SmallVector<SDValue> ScalarRes;
3583 if (EltVT.isVector()) {
3585 assert(NumElts * EltVT.getVectorNumElements() ==
3586 ResVT.getVectorNumElements());
3587 // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
3588 // into individual elements.
3589 for (const unsigned I : llvm::seq(NumElts)) {
3590 SDValue SubVector = NewLD.getValue(I);
3591 DAG.ExtractVectorElements(SubVector, ScalarRes);
3592 }
3593 } else {
3594 for (const unsigned I : llvm::seq(NumElts)) {
3595 SDValue Res = NewLD.getValue(I);
3596 if (LoadEltVT != EltVT)
3597 Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
3598 ScalarRes.push_back(Res);
3599 }
3600 }
3601
3602 SDValue LoadChain = NewLD.getValue(NumElts);
3603
3604 const MVT BuildVecVT =
3605 MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size());
3606 SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
3607 SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec);
3608
3609 return {{LoadValue, LoadChain}};
3610}
3611
3614 const NVPTXSubtarget &STI) {
3615 if (auto Res = replaceLoadVector(N, DAG, STI))
3616 Results.append({Res->first, Res->second});
3617}
3618
3620 const NVPTXSubtarget &STI) {
3621 if (auto Res = replaceLoadVector(N, DAG, STI))
3622 return DAG.getMergeValues({Res->first, Res->second}, SDLoc(N));
3623 return SDValue();
3624}
3625
3626// v = ld i1* addr
3627// =>
3628// v1 = ld i8* addr (-> i16)
3629// v = trunc i16 to i1
3630 static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
3631 SDLoc dl(LD);
3632 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
3633 assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
3634 SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
3635 LD->getBasePtr(), LD->getPointerInfo(),
3636 MVT::i8, LD->getAlign(),
3637 LD->getMemOperand()->getFlags());
3638 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
3639 // The legalizer (the caller) is expecting two values from the legalized
3640 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
3641 // in LegalizeDAG.cpp which also uses MergeValues.
3642 return DAG.getMergeValues({result, LD->getChain()}, dl);
3643}
3644
3645SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3646 LoadSDNode *LD = cast<LoadSDNode>(Op);
3647
3648 if (Op.getValueType() == MVT::i1)
3649 return lowerLOADi1(LD, DAG);
3650
3651 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
3652 // how they'll be lowered in ISel anyway, and by doing this a little earlier
3653 // we allow for more DAG combine opportunities.
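// That is (illustrative), an any-extending "i32 <- i8" load is rewritten as a
// zero-extending one.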
3654 if (LD->getExtensionType() == ISD::EXTLOAD) {
3655 assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
3656 "Unexpected fpext-load");
3657 return DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Op), Op.getValueType(),
3658 LD->getChain(), LD->getBasePtr(), LD->getMemoryVT(),
3659 LD->getMemOperand());
3660 }
3661
3662 llvm_unreachable("Unexpected custom lowering for load");
3663}
3664
3665SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
3666 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on the legalizer to
3667 // handle masked loads of these types; we have to handle them here.
3668 // v2f32 also needs to be handled here if the subtarget has f32x2
3669 // instructions, making it legal.
3670 //
3671 // Note: misaligned masked loads should never reach this point
3672 // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
3673 // will validate alignment. Therefore, we do not need to special case handle
3674 // them here.
3675 EVT VT = Op.getValueType();
3676 if (NVPTX::isPackedVectorTy(VT)) {
3678 cast<MemSDNode>(Op.getNode()), DAG, STI);
3679 MemSDNode *LD = std::get<0>(Result);
3680 uint32_t UsedBytesMask = std::get<1>(Result);
3681
3682 SDLoc DL(LD);
3683
3684 // Copy regular operands
3685 SmallVector<SDValue, 8> OtherOps(LD->ops());
3686
3687 OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
3688
3689 // We currently are not lowering extending loads, but pass the extension
3690 // type anyway as later handling expects it.
3691 OtherOps.push_back(
3692 DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
3693 SDValue NewLD =
3694 DAG.getMemIntrinsicNode(NVPTXISD::MLoad, DL, LD->getVTList(), OtherOps,
3695 LD->getMemoryVT(), LD->getMemOperand());
3696 return NewLD;
3697 }
3698 return SDValue();
3699}
3700
3701 static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
3702 const NVPTXSubtarget &STI) {
3703 MemSDNode *N = cast<MemSDNode>(Op.getNode());
3704 SDValue Val = N->getOperand(1);
3705 SDLoc DL(N);
3706 const EVT ValVT = Val.getValueType();
3707 const EVT MemVT = N->getMemoryVT();
3708
3709 // If we're truncating as part of the store, avoid lowering to a StoreV node.
3710 // TODO: consider relaxing this restriction.
3711 if (ValVT != MemVT)
3712 return SDValue();
3713
3714 const auto NumEltsAndEltVT =
3715 getVectorLoweringShape(ValVT, STI, N->getAddressSpace());
3716 if (!NumEltsAndEltVT)
3717 return SDValue();
3718 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3719
3720 const DataLayout &TD = DAG.getDataLayout();
3721
3722 Align Alignment = N->getAlign();
3723 Align PrefAlign = TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
3724 if (Alignment < PrefAlign) {
3725 // This store is not sufficiently aligned, so bail out and let this vector
3726 // store be scalarized. Note that we may still be able to emit smaller
3727 // vector stores. For example, if we are storing a <4 x float> with an
3728 // alignment of 8, this check will fail but the legalizer will try again
3729 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3730 return SDValue();
3731 }
3732
3733 unsigned Opcode;
3734 switch (NumElts) {
3735 default:
3736 return SDValue();
3737 case 2:
3738 Opcode = NVPTXISD::StoreV2;
3739 break;
3740 case 4:
3741 Opcode = NVPTXISD::StoreV4;
3742 break;
3743 case 8:
3744 Opcode = NVPTXISD::StoreV8;
3745 break;
3746 }
3747
3749 SmallVector<SDValue, 8> Ops;
3750 // First is the chain
3751 Ops.push_back(N->getOperand(0));
3752
3753 // Then the split values
3754 if (EltVT.isVector()) {
3756 assert(NumElts * EltVT.getVectorNumElements() ==
3757 ValVT.getVectorNumElements());
3758 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
3759 // stored as b32s
3760 const unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
3761 for (const unsigned I : llvm::seq(NumElts)) {
3762 SmallVector<SDValue, 4> SubVectorElts;
3763 DAG.ExtractVectorElements(Val, SubVectorElts, I * NumEltsPerSubVector,
3764 NumEltsPerSubVector);
3765 Ops.push_back(DAG.getBuildVector(EltVT, DL, SubVectorElts));
3766 }
3767 } else {
3768 SDValue V = DAG.getBitcast(MVT::getVectorVT(EltVT, NumElts), Val);
3769 for (const unsigned I : llvm::seq(NumElts)) {
3770 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, V,
3771 DAG.getIntPtrConstant(I, DL));
3772
3773 // Since StoreV2 is a target node, we cannot rely on DAG type
3774 // legalization. Therefore, we must ensure the type is legal. For i1 and
3775 // i8, we set the stored type to i16 and propagate the "real" type as the
3776 // memory type.
3777 if (EltVT.getSizeInBits() < 16)
3778 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
3779 Ops.push_back(ExtVal);
3780 }
3781 }
3782
3783 // Then any remaining arguments
3784 Ops.append(N->op_begin() + 2, N->op_end());
3785
3786 SDValue NewSt =
3787 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
3788 N->getMemoryVT(), N->getMemOperand());
3789
3790 // return DCI.CombineTo(N, NewSt, true);
3791 return NewSt;
3792}
3793
3794SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
3795 StoreSDNode *Store = cast<StoreSDNode>(Op);
3796 EVT VT = Store->getMemoryVT();
3797
3798 if (VT == MVT::i1)
3799 return LowerSTOREi1(Op, DAG);
3800
3801 // Lower store of any other vector type, including v2f32 as we want to break
3802 // it apart since this is not a widely-supported type.
3803 return lowerSTOREVector(Op, DAG, STI);
3804}
3805
3806// st i1 v, addr
3807// =>
3808// v1 = zxt v to i16
3809// st.u8 i16, addr
3810SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3811 SDNode *Node = Op.getNode();
3812 SDLoc dl(Node);
3813 StoreSDNode *ST = cast<StoreSDNode>(Node);
3814 SDValue Tmp1 = ST->getChain();
3815 SDValue Tmp2 = ST->getBasePtr();
3816 SDValue Tmp3 = ST->getValue();
3817 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3818 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
3819 SDValue Result =
3820 DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i8,
3821 ST->getAlign(), ST->getMemOperand()->getFlags());
3822 return Result;
3823}
3824
3825SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3826 SelectionDAG &DAG) const {
3827 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3828 // operand so that it can pass the legalization.
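// That is (illustrative): CopyToReg(chain, reg, i128 %v [, glue])
//   --> CopyToReg(chain, reg, lo64(%v), hi64(%v) [, glue])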
3829
3830 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3831 "Custom lowering for 128-bit CopyToReg only");
3832
3833 SDNode *Node = Op.getNode();
3834 SDLoc DL(Node);
3835
3836 SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
3837 SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3838 DAG.getIntPtrConstant(0, DL));
3839 SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3840 DAG.getIntPtrConstant(1, DL));
3841
3842 SmallVector<SDValue, 5> NewOps(Op.getNumOperands() + 1);
3843 SmallVector<EVT, 3> ResultsType(Node->values());
3844
3845 NewOps[0] = Op->getOperand(0); // Chain
3846 NewOps[1] = Op->getOperand(1); // Dst Reg
3847 NewOps[2] = Lo; // Lower 64-bit
3848 NewOps[3] = Hi; // Higher 64-bit
3849 if (Op.getNumOperands() == 4)
3850 NewOps[4] = Op->getOperand(3); // Glue if exists
3851
3852 return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
3853}
3854
3855unsigned NVPTXTargetLowering::getNumRegisters(
3856 LLVMContext &Context, EVT VT,
3857 std::optional<MVT> RegisterVT = std::nullopt) const {
3858 if (VT == MVT::i128 && RegisterVT == MVT::i128)
3859 return 1;
3860 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
3861}
3862
3863bool NVPTXTargetLowering::splitValueIntoRegisterParts(
3864 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
3865 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
3866 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
3867 Parts[0] = Val;
3868 return true;
3869 }
3870 return false;
3871}
3872
3873// This creates target external symbol for a function parameter.
3874// Name of the symbol is composed from its index and the function name.
3875// Negative index corresponds to special parameter (unsized array) used for
3876// passing variable arguments.
3877SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
3878 EVT T) const {
3879 StringRef SavedStr = nvTM->getStrPool().save(
3881 return DAG.getExternalSymbol(SavedStr.data(), T);
3882}
3883
3884SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
3885 EVT T) const {
3886 const StringRef SavedStr = nvTM->getStrPool().save("param" + Twine(I));
3887 return DAG.getExternalSymbol(SavedStr.data(), T);
3888}
3889
3890 SDValue NVPTXTargetLowering::LowerFormalArguments(
3891 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
3892 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
3893 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
3894 const DataLayout &DL = DAG.getDataLayout();
3895 LLVMContext &Ctx = *DAG.getContext();
3896 auto PtrVT = getPointerTy(DAG.getDataLayout());
3897
3898 const Function &F = DAG.getMachineFunction().getFunction();
3899
3900 SDValue Root = DAG.getRoot();
3901 SmallVector<SDValue, 16> OutChains;
3902
3903 // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
3904 // Ins.size() will be larger
3905 // * if there is an aggregate argument with multiple fields (each field
3906 // showing up separately in Ins)
3907 // * if there is a vector argument with more than typical vector-length
3908 // elements (generally if more than 4) where each vector element is
3909 // individually present in Ins.
3910 // So a different index should be used for indexing into Ins.
3911 // See similar issue in LowerCall.
3912
3913 auto AllIns = ArrayRef(Ins);
3914 for (const auto &Arg : F.args()) {
3915 const auto ArgIns = AllIns.take_while(
3916 [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
3917 AllIns = AllIns.drop_front(ArgIns.size());
3918
3919 Type *Ty = Arg.getType();
3920
3921 if (ArgIns.empty())
3922 report_fatal_error("Empty parameter types are not supported");
3923
3924 if (Arg.use_empty()) {
3925 // argument is dead
3926 for (const auto &In : ArgIns) {
3927 assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
3928 InVals.push_back(DAG.getUNDEF(In.VT));
3929 }
3930 continue;
3931 }
3932
3933 SDValue ArgSymbol = getParamSymbol(DAG, Arg.getArgNo(), PtrVT);
3934
3935 // In the following cases, assign a node order of "i+1"
3936 // to newly created nodes. The SDNodes for params have to
3937 // appear in the same order as their order of appearance
3938 // in the original function. "i+1" holds that order.
3939 if (Arg.hasByValAttr()) {
3940 // Param has ByVal attribute
3941 // Return MoveParam(param symbol).
3942 // Ideally, the param symbol can be returned directly,
3943 // but when SDNode builder decides to use it in a CopyToReg(),
3944 // machine instruction fails because TargetExternalSymbol
3945 // (not lowered) is target dependent, and CopyToReg assumes
3946 // the source is lowered.
3947 assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
3948 const auto &ByvalIn = ArgIns[0];
3949 assert(getValueType(DL, Ty) == ByvalIn.VT &&
3950 "Ins type did not match function type");
3951 assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
3952
3953 SDValue P;
3954 if (isKernelFunction(F)) {
3955 P = ArgSymbol;
3956 P.getNode()->setIROrder(Arg.getArgNo() + 1);
3957 } else {
3958 P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol);
3959 P.getNode()->setIROrder(Arg.getArgNo() + 1);
3960 P = DAG.getAddrSpaceCast(dl, ByvalIn.VT, P, ADDRESS_SPACE_LOCAL,
3961 ADDRESS_SPACE_GENERIC);
3962 }
3963 InVals.push_back(P);
3964 } else {
3965 SmallVector<EVT, 16> VTs;
3966 SmallVector<uint64_t, 16> Offsets;
3967 ComputePTXValueVTs(*this, DL, Ctx, CallConv, Ty, VTs, Offsets);
3968 assert(VTs.size() == ArgIns.size() && "Size mismatch");
3969 assert(VTs.size() == Offsets.size() && "Size mismatch");
3970
3971 const Align ArgAlign = getFunctionArgumentAlignment(
3972 &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
3973
3974 unsigned I = 0;
3975 const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
3976 for (const unsigned NumElts : VI) {
3977 // i1 is loaded/stored as i8
3978 const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
3979 const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
3980
3981 SDValue VecAddr = DAG.getObjectPtrOffset(
3982 dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
3983
3984 const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
3985 SDValue P =
3986 DAG.getLoad(VecVT, dl, Root, VecAddr,
3987 MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
3988 MachineMemOperand::MODereferenceable |
3989 MachineMemOperand::MOInvariant);
3990 P.getNode()->setIROrder(Arg.getArgNo() + 1);
3991 for (const unsigned J : llvm::seq(NumElts)) {
3992 SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG);
3993
3994 Elt = correctParamType(Elt, ArgIns[I + J].VT, ArgIns[I + J].Flags,
3995 DAG, dl);
3996 InVals.push_back(Elt);
3997 }
3998 I += NumElts;
3999 }
4000 }
4001 }
4002
4003 if (!OutChains.empty())
4004 DAG.setRoot(DAG.getTokenFactor(dl, OutChains));
4005
4006 return Chain;
4007}
4008
4009SDValue
4010 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
4011 bool isVarArg,
4012 const SmallVectorImpl<ISD::OutputArg> &Outs,
4013 const SmallVectorImpl<SDValue> &OutVals,
4014 const SDLoc &dl, SelectionDAG &DAG) const {
4015 const Function &F = DAG.getMachineFunction().getFunction();
4016 Type *RetTy = F.getReturnType();
4017
4018 if (RetTy->isVoidTy()) {
4019 assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
4020 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
4021 }
4022
4023 const DataLayout &DL = DAG.getDataLayout();
4024 LLVMContext &Ctx = *DAG.getContext();
4025
4026 const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
4027 const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
4028
4029 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
4030 // 32-bits are sign extended or zero extended, depending on whether
4031 // they are signed or unsigned types.
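// For example (illustrative), an i8 return value is widened to i32 before
// being stored to func_retval0.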
4032 const bool ExtendIntegerRetVal =
4033 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
4034
4037 ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, Offsets);
4038 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
4039
4040 const auto GetRetVal = [&](unsigned I) -> SDValue {
4041 SDValue RetVal = OutVals[I];
4042 assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
4043 RetVal.getValueType() &&
4044 "OutVal type should always be legal");
4045
4046 const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
4047 const EVT StoreVT =
4048 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
4049 return correctParamType(RetVal, StoreVT, Outs[I].Flags, DAG, dl);
4050 };
4051
4052 unsigned I = 0;
4053 const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
4054 for (const unsigned NumElts : VI) {
4055 const MaybeAlign CurrentAlign = ExtendIntegerRetVal
4056 ? MaybeAlign(std::nullopt)
4057 : commonAlignment(RetAlign, Offsets[I]);
4058
4060 NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); });
4061
4062 SDValue Ptr =
4063 DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
4064
4065 Chain = DAG.getStore(Chain, dl, Val, Ptr,
4066 MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
4067 
4068 I += NumElts;
4069 }
4070
4071 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
4072}
4073
4074 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
4075 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
4076 SelectionDAG &DAG) const {
4077 if (Constraint.size() > 1)
4078 return;
4079 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
4080 }
4081
4082// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
4083// TgtMemIntrinsic
4084// because we need the information that is only available in the "Value" type
4085// of destination
4086// pointer. In particular, the address space information.
4087 bool NVPTXTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
4088 const CallBase &I,
4089 MachineFunction &MF,
4090 unsigned Intrinsic) const {
4091 switch (Intrinsic) {
4092 default:
4093 return false;
4094 case Intrinsic::nvvm_match_all_sync_i32p:
4095 case Intrinsic::nvvm_match_all_sync_i64p:
4096 Info.opc = ISD::INTRINSIC_W_CHAIN;
4097 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
4098 // in order to model data exchange with other threads, but perform no real
4099 // memory accesses.
4100 Info.memVT = MVT::i1;
4101
4102 // Our result depends on both our and other thread's arguments.
4103 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4104 return true;
4105 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
4106 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
4107 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
4108 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
4109 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
4110 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
4111 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
4112 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
4113 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
4114 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
4115 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
4116 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
4117 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
4118 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
4119 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
4120 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
4121 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
4122 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
4123 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
4124 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
4125 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
4126 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
4127 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
4128 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
4129 Info.opc = ISD::INTRINSIC_W_CHAIN;
4130 Info.memVT = MVT::v8f16;
4131 Info.ptrVal = I.getArgOperand(0);
4132 Info.offset = 0;
4133 Info.flags = MachineMemOperand::MOLoad;
4134 Info.align = Align(16);
4135 return true;
4136 }
4137 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
4138 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
4139 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
4140 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
4141 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
4142 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
4143 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
4144 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
4145 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
4146 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
4147 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
4148 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
4149 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
4150 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
4151 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
4152 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
4153 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
4154 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
4155 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
4156 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
4157 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
4158 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
4159 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
4160 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
4161 Info.opc = ISD::INTRINSIC_W_CHAIN;
4162 Info.memVT = MVT::v2i32;
4163 Info.ptrVal = I.getArgOperand(0);
4164 Info.offset = 0;
4165 Info.flags = MachineMemOperand::MOLoad;
4166 Info.align = Align(8);
4167 return true;
4168 }
4169
4170 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
4171 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
4172 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
4173 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
4174 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
4175 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
4176 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
4177 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
4178 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
4179 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
4180 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
4181 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
4182 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
4183 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
4184 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
4185 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
4186
4187 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
4188 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
4189 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
4190 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
4191 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
4192 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
4193 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
4194 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
4195 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
4196 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
4197 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
4198 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
4199 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
4200 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
4201 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
4202 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
4203 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
4204 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
4205 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
4206 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
4207 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
4208 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
4209 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
4210 Info.opc = ISD::INTRINSIC_W_CHAIN;
4211 Info.memVT = MVT::v4i32;
4212 Info.ptrVal = I.getArgOperand(0);
4213 Info.offset = 0;
4214 Info.flags = MachineMemOperand::MOLoad;
4215 Info.align = Align(16);
4216 return true;
4217 }
4218
4219 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
4220 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
4221 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
4222 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
4223 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
4224 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
4225 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
4226 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
4227
4228 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
4229 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
4230 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
4231 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
4232 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
4233 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
4234 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
4235 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
4236 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
4237 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
4238 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
4239 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
4240 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
4241 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
4242 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
4243 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
4244 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
4245 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
4246 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
4247 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
4248 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
4249 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
4250 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
4251 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
4252 Info.opc = ISD::INTRINSIC_W_CHAIN;
4253 Info.memVT = MVT::i32;
4254 Info.ptrVal = I.getArgOperand(0);
4255 Info.offset = 0;
4256 Info.flags = MachineMemOperand::MOLoad;
4257 Info.align = Align(4);
4258 return true;
4259 }
4260
4261 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
4262 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
4263 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
4264 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
4265 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
4266 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
4267 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
4268 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
4269 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
4270 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
4271 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
4272 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
4273 Info.opc = ISD::INTRINSIC_W_CHAIN;
4274 Info.memVT = MVT::v4f16;
4275 Info.ptrVal = I.getArgOperand(0);
4276 Info.offset = 0;
4277 Info.flags = MachineMemOperand::MOLoad;
4278 Info.align = Align(16);
4279 return true;
4280 }
4281
4282 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
4283 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
4284 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
4285 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
4286 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
4287 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
4288 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
4289 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
4290 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
4291 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
4292 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
4293 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
4294 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
4295 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
4296 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
4297 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
4298 Info.opc = ISD::INTRINSIC_W_CHAIN;
4299 Info.memVT = MVT::v8f32;
4300 Info.ptrVal = I.getArgOperand(0);
4301 Info.offset = 0;
4302 Info.flags = MachineMemOperand::MOLoad;
4303 Info.align = Align(16);
4304 return true;
4305 }
4306
4307 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
4308 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
4309 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
4310 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
4311
4312 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
4313 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
4314 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
4315 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
4316
4317 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
4318 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
4319 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
4320 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
4321 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
4322 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
4323 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
4324 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
4325 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
4326 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
4327 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
4328 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
4329 Info.opc = ISD::INTRINSIC_W_CHAIN;
4330 Info.memVT = MVT::v8i32;
4331 Info.ptrVal = I.getArgOperand(0);
4332 Info.offset = 0;
4333 Info.flags = MachineMemOperand::MOLoad;
4334 Info.align = Align(16);
4335 return true;
4336 }
4337
4338 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
4339 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
4340 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
4341 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
4342 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
4343 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
4344 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
4345 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
4346 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
4347 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
4348 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
4349 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
4350 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
4351 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
4352 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
4353 Info.opc = ISD::INTRINSIC_W_CHAIN;
4354 Info.memVT = MVT::v2i32;
4355 Info.ptrVal = I.getArgOperand(0);
4356 Info.offset = 0;
4357 Info.flags = MachineMemOperand::MOLoad;
4358 Info.align = Align(8);
4359 return true;
4360 }
4361
4362 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
4363 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
4364 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
4365 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
4366
4367 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
4368 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
4369 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
4370 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
4371 Info.opc = ISD::INTRINSIC_W_CHAIN;
4372 Info.memVT = MVT::f64;
4373 Info.ptrVal = I.getArgOperand(0);
4374 Info.offset = 0;
4375 Info.flags = MachineMemOperand::MOLoad;
4376 Info.align = Align(8);
4377 return true;
4378 }
4379
4380 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
4381 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
4382 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
4383 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
4384 Info.opc = ISD::INTRINSIC_W_CHAIN;
4385 Info.memVT = MVT::v2f64;
4386 Info.ptrVal = I.getArgOperand(0);
4387 Info.offset = 0;
4388 Info.flags = MachineMemOperand::MOLoad;
4389 Info.align = Align(16);
4390 return true;
4391 }
4392
4393 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
4394 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
4395 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
4396 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
4397 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
4398 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
4399 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
4400 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
4401 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
4402 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
4403 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
4404 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
4405 Info.opc = ISD::INTRINSIC_VOID;
4406 Info.memVT = MVT::v4f16;
4407 Info.ptrVal = I.getArgOperand(0);
4408 Info.offset = 0;
4409 Info.flags = MachineMemOperand::MOStore;
4410 Info.align = Align(16);
4411 return true;
4412 }
4413
4414 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
4415 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
4416 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
4417 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
4418 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
4419 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
4420 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
4421 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
4422 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
4423 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
4424 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
4425 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
4426 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
4427 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
4428 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
4429 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
4430 Info.opc = ISD::INTRINSIC_VOID;
4431 Info.memVT = MVT::v8f32;
4432 Info.ptrVal = I.getArgOperand(0);
4433 Info.offset = 0;
4434 Info.flags = MachineMemOperand::MOStore;
4435 Info.align = Align(16);
4436 return true;
4437 }
4438
4439 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
4440 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
4441 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
4442 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
4443 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
4444 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
4445 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
4446 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
4447 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
4448 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
4449 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
4450 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
4451 Info.opc = ISD::INTRINSIC_VOID;
4452 Info.memVT = MVT::v8i32;
4453 Info.ptrVal = I.getArgOperand(0);
4454 Info.offset = 0;
4455 Info.flags = MachineMemOperand::MOStore;
4456 Info.align = Align(16);
4457 return true;
4458 }
4459
4460 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
4461 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
4462 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
4463 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
4464 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
4465 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
4466 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
4467 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
4468 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
4469 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
4470 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
4471 Info.opc = ISD::INTRINSIC_VOID;
4472 Info.memVT = MVT::v2i32;
4473 Info.ptrVal = I.getArgOperand(0);
4474 Info.offset = 0;
4475 Info.flags = MachineMemOperand::MOStore;
4476 Info.align = Align(8);
4477 return true;
4478 }
4479
4480 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
4481 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
4482 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
4483 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
4484 Info.opc = ISD::INTRINSIC_VOID;
4485 Info.memVT = MVT::v2f64;
4486 Info.ptrVal = I.getArgOperand(0);
4487 Info.offset = 0;
4488 Info.flags = MachineMemOperand::MOStore;
4489 Info.align = Align(16);
4490 return true;
4491 }
4492
4493 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
4494 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
4495 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
4496 Info.opc = ISD::INTRINSIC_VOID;
4497 Info.memVT = MVT::i32;
4498 Info.ptrVal = I.getArgOperand(0);
4499 Info.offset = 0;
4500 Info.flags = MachineMemOperand::MOStore;
4501 Info.align = Align(4);
4502 return true;
4503 }
4504
4505 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
4506 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
4507 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
4508 Info.opc = ISD::INTRINSIC_VOID;
4509 Info.memVT = MVT::v4i32;
4510 Info.ptrVal = I.getArgOperand(0);
4511 Info.offset = 0;
4512 Info.flags = MachineMemOperand::MOStore;
4513 Info.align = Align(16);
4514 return true;
4515 }
4516
4517 case Intrinsic::nvvm_atomic_add_gen_f_cta:
4518 case Intrinsic::nvvm_atomic_add_gen_f_sys:
4519 case Intrinsic::nvvm_atomic_add_gen_i_cta:
4520 case Intrinsic::nvvm_atomic_add_gen_i_sys:
4521 case Intrinsic::nvvm_atomic_and_gen_i_cta:
4522 case Intrinsic::nvvm_atomic_and_gen_i_sys:
4523 case Intrinsic::nvvm_atomic_cas_gen_i_cta:
4524 case Intrinsic::nvvm_atomic_cas_gen_i_sys:
4525 case Intrinsic::nvvm_atomic_dec_gen_i_cta:
4526 case Intrinsic::nvvm_atomic_dec_gen_i_sys:
4527 case Intrinsic::nvvm_atomic_inc_gen_i_cta:
4528 case Intrinsic::nvvm_atomic_inc_gen_i_sys:
4529 case Intrinsic::nvvm_atomic_max_gen_i_cta:
4530 case Intrinsic::nvvm_atomic_max_gen_i_sys:
4531 case Intrinsic::nvvm_atomic_min_gen_i_cta:
4532 case Intrinsic::nvvm_atomic_min_gen_i_sys:
4533 case Intrinsic::nvvm_atomic_or_gen_i_cta:
4534 case Intrinsic::nvvm_atomic_or_gen_i_sys:
4535 case Intrinsic::nvvm_atomic_exch_gen_i_cta:
4536 case Intrinsic::nvvm_atomic_exch_gen_i_sys:
4537 case Intrinsic::nvvm_atomic_xor_gen_i_cta:
4538 case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
4539 auto &DL = I.getDataLayout();
4540 Info.opc = ISD::INTRINSIC_W_CHAIN;
4541 Info.memVT = getValueType(DL, I.getType());
4542 Info.ptrVal = I.getArgOperand(0);
4543 Info.offset = 0;
4544 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4545 Info.align.reset();
4546 return true;
4547 }
4548
4549 case Intrinsic::nvvm_prefetch_tensormap: {
4550 auto &DL = I.getDataLayout();
4551 Info.opc = ISD::INTRINSIC_VOID;
4552 Info.memVT = getPointerTy(DL);
4553 Info.ptrVal = I.getArgOperand(0);
4554 Info.offset = 0;
4555 Info.flags =
4557 Info.align.reset();
4558 return true;
4559 }
4560
4561 case Intrinsic::nvvm_ldu_global_i:
4562 case Intrinsic::nvvm_ldu_global_f:
4563 case Intrinsic::nvvm_ldu_global_p: {
4564 Info.opc = ISD::INTRINSIC_W_CHAIN;
4565 Info.memVT = getValueType(I.getDataLayout(), I.getType());
4566 Info.ptrVal = I.getArgOperand(0);
4567 Info.offset = 0;
4568 Info.flags = MachineMemOperand::MOLoad;
4569 Info.align = cast<ConstantInt>(I.getArgOperand(1))->getMaybeAlignValue();
4570
4571 return true;
4572 }
4573 case Intrinsic::nvvm_tex_1d_v4f32_s32:
4574 case Intrinsic::nvvm_tex_1d_v4f32_f32:
4575 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
4576 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
4577 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
4578 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
4579 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
4580 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
4581 case Intrinsic::nvvm_tex_2d_v4f32_s32:
4582 case Intrinsic::nvvm_tex_2d_v4f32_f32:
4583 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
4584 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
4585 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
4586 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
4587 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
4588 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
4589 case Intrinsic::nvvm_tex_3d_v4f32_s32:
4590 case Intrinsic::nvvm_tex_3d_v4f32_f32:
4591 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
4592 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
4593 case Intrinsic::nvvm_tex_cube_v4f32_f32:
4594 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
4595 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
4596 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
4597 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
4598 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
4599 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
4600 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
4601 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
4602 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
4603 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
4604 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
4605 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
4606 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
4607 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
4608 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
4609 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
4610 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
4611 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
4612 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
4613 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
4614 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
4615 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
4616 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
4617 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
4618 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
4619 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
4620 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
4621 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
4622 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
4623 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
4624 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
4625 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
4626 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
4627 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
4628 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
4629 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
4630 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
4631 Info.opc = ISD::INTRINSIC_W_CHAIN;
4632 Info.memVT = MVT::v4f32;
4633 Info.ptrVal = nullptr;
4634 Info.offset = 0;
4635 Info.flags = MachineMemOperand::MOLoad;
4636 Info.align = Align(16);
4637 return true;
4638
4639 case Intrinsic::nvvm_tex_1d_v4s32_s32:
4640 case Intrinsic::nvvm_tex_1d_v4s32_f32:
4641 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
4642 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
4643 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
4644 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
4645 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
4646 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
4647 case Intrinsic::nvvm_tex_2d_v4s32_s32:
4648 case Intrinsic::nvvm_tex_2d_v4s32_f32:
4649 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
4650 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
4651 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
4652 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
4653 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
4654 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
4655 case Intrinsic::nvvm_tex_3d_v4s32_s32:
4656 case Intrinsic::nvvm_tex_3d_v4s32_f32:
4657 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
4658 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
4659 case Intrinsic::nvvm_tex_cube_v4s32_f32:
4660 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
4661 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
4662 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
4663 case Intrinsic::nvvm_tex_cube_v4u32_f32:
4664 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
4665 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
4666 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
4667 case Intrinsic::nvvm_tex_1d_v4u32_s32:
4668 case Intrinsic::nvvm_tex_1d_v4u32_f32:
4669 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
4670 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
4671 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
4672 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
4673 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
4674 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
4675 case Intrinsic::nvvm_tex_2d_v4u32_s32:
4676 case Intrinsic::nvvm_tex_2d_v4u32_f32:
4677 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
4678 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
4679 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
4680 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
4681 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
4682 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
4683 case Intrinsic::nvvm_tex_3d_v4u32_s32:
4684 case Intrinsic::nvvm_tex_3d_v4u32_f32:
4685 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
4686 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
4687 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
4688 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
4689 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
4690 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
4691 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
4692 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
4693 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
4694 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
4695 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
4696 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
4697 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
4698 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
4699 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
4700 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
4701 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
4702 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
4703 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
4704 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
4705 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
4706 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
4707 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
4708 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
4709 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
4710 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
4711 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
4712 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
4713 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
4714 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
4715 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
4716 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
4717 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
4718 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
4719 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4720 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4721 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4722 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4723 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4724 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4725 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4726 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4727 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4728 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4729 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4730 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4731 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4732 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4733 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4734 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4735 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4736 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4737 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4738 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4739 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4740 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4741 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4742 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4743 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4744 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4745 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4746 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4747 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4748 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4749 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4750 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4751 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4752 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4753 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4754 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4755 Info.opc = ISD::INTRINSIC_W_CHAIN;
4756 Info.memVT = MVT::v4i32;
4757 Info.ptrVal = nullptr;
4758 Info.offset = 0;
4759 Info.flags = MachineMemOperand::MOLoad;
4760 Info.align = Align(16);
4761 return true;
4762
4763 case Intrinsic::nvvm_suld_1d_i8_clamp:
4764 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4765 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4766 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4767 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4768 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4769 case Intrinsic::nvvm_suld_2d_i8_clamp:
4770 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4771 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4772 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4773 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4774 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4775 case Intrinsic::nvvm_suld_3d_i8_clamp:
4776 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4777 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4778 case Intrinsic::nvvm_suld_1d_i8_trap:
4779 case Intrinsic::nvvm_suld_1d_v2i8_trap:
4780 case Intrinsic::nvvm_suld_1d_v4i8_trap:
4781 case Intrinsic::nvvm_suld_1d_array_i8_trap:
4782 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
4783 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
4784 case Intrinsic::nvvm_suld_2d_i8_trap:
4785 case Intrinsic::nvvm_suld_2d_v2i8_trap:
4786 case Intrinsic::nvvm_suld_2d_v4i8_trap:
4787 case Intrinsic::nvvm_suld_2d_array_i8_trap:
4788 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
4789 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
4790 case Intrinsic::nvvm_suld_3d_i8_trap:
4791 case Intrinsic::nvvm_suld_3d_v2i8_trap:
4792 case Intrinsic::nvvm_suld_3d_v4i8_trap:
4793 case Intrinsic::nvvm_suld_1d_i8_zero:
4794 case Intrinsic::nvvm_suld_1d_v2i8_zero:
4795 case Intrinsic::nvvm_suld_1d_v4i8_zero:
4796 case Intrinsic::nvvm_suld_1d_array_i8_zero:
4797 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
4798 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
4799 case Intrinsic::nvvm_suld_2d_i8_zero:
4800 case Intrinsic::nvvm_suld_2d_v2i8_zero:
4801 case Intrinsic::nvvm_suld_2d_v4i8_zero:
4802 case Intrinsic::nvvm_suld_2d_array_i8_zero:
4803 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
4804 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
4805 case Intrinsic::nvvm_suld_3d_i8_zero:
4806 case Intrinsic::nvvm_suld_3d_v2i8_zero:
4807 case Intrinsic::nvvm_suld_3d_v4i8_zero:
4808 Info.opc = ISD::INTRINSIC_W_CHAIN;
4809 Info.memVT = MVT::i8;
4810 Info.ptrVal = nullptr;
4811 Info.offset = 0;
4812 Info.flags = MachineMemOperand::MOLoad;
4813 Info.align = Align(16);
4814 return true;
4815
4816 case Intrinsic::nvvm_suld_1d_i16_clamp:
4817 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
4818 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
4819 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
4820 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
4821 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
4822 case Intrinsic::nvvm_suld_2d_i16_clamp:
4823 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
4824 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
4825 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
4826 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
4827 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
4828 case Intrinsic::nvvm_suld_3d_i16_clamp:
4829 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
4830 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
4831 case Intrinsic::nvvm_suld_1d_i16_trap:
4832 case Intrinsic::nvvm_suld_1d_v2i16_trap:
4833 case Intrinsic::nvvm_suld_1d_v4i16_trap:
4834 case Intrinsic::nvvm_suld_1d_array_i16_trap:
4835 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
4836 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
4837 case Intrinsic::nvvm_suld_2d_i16_trap:
4838 case Intrinsic::nvvm_suld_2d_v2i16_trap:
4839 case Intrinsic::nvvm_suld_2d_v4i16_trap:
4840 case Intrinsic::nvvm_suld_2d_array_i16_trap:
4841 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
4842 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
4843 case Intrinsic::nvvm_suld_3d_i16_trap:
4844 case Intrinsic::nvvm_suld_3d_v2i16_trap:
4845 case Intrinsic::nvvm_suld_3d_v4i16_trap:
4846 case Intrinsic::nvvm_suld_1d_i16_zero:
4847 case Intrinsic::nvvm_suld_1d_v2i16_zero:
4848 case Intrinsic::nvvm_suld_1d_v4i16_zero:
4849 case Intrinsic::nvvm_suld_1d_array_i16_zero:
4850 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
4851 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
4852 case Intrinsic::nvvm_suld_2d_i16_zero:
4853 case Intrinsic::nvvm_suld_2d_v2i16_zero:
4854 case Intrinsic::nvvm_suld_2d_v4i16_zero:
4855 case Intrinsic::nvvm_suld_2d_array_i16_zero:
4856 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
4857 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
4858 case Intrinsic::nvvm_suld_3d_i16_zero:
4859 case Intrinsic::nvvm_suld_3d_v2i16_zero:
4860 case Intrinsic::nvvm_suld_3d_v4i16_zero:
4861 Info.opc = ISD::INTRINSIC_W_CHAIN;
4862 Info.memVT = MVT::i16;
4863 Info.ptrVal = nullptr;
4864 Info.offset = 0;
4865 Info.flags = MachineMemOperand::MOLoad;
4866 Info.align = Align(16);
4867 return true;
4868
4869 case Intrinsic::nvvm_suld_1d_i32_clamp:
4870 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
4871 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
4872 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
4873 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
4874 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
4875 case Intrinsic::nvvm_suld_2d_i32_clamp:
4876 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
4877 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
4878 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
4879 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
4880 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
4881 case Intrinsic::nvvm_suld_3d_i32_clamp:
4882 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
4883 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
4884 case Intrinsic::nvvm_suld_1d_i32_trap:
4885 case Intrinsic::nvvm_suld_1d_v2i32_trap:
4886 case Intrinsic::nvvm_suld_1d_v4i32_trap:
4887 case Intrinsic::nvvm_suld_1d_array_i32_trap:
4888 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
4889 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
4890 case Intrinsic::nvvm_suld_2d_i32_trap:
4891 case Intrinsic::nvvm_suld_2d_v2i32_trap:
4892 case Intrinsic::nvvm_suld_2d_v4i32_trap:
4893 case Intrinsic::nvvm_suld_2d_array_i32_trap:
4894 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
4895 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
4896 case Intrinsic::nvvm_suld_3d_i32_trap:
4897 case Intrinsic::nvvm_suld_3d_v2i32_trap:
4898 case Intrinsic::nvvm_suld_3d_v4i32_trap:
4899 case Intrinsic::nvvm_suld_1d_i32_zero:
4900 case Intrinsic::nvvm_suld_1d_v2i32_zero:
4901 case Intrinsic::nvvm_suld_1d_v4i32_zero:
4902 case Intrinsic::nvvm_suld_1d_array_i32_zero:
4903 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
4904 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
4905 case Intrinsic::nvvm_suld_2d_i32_zero:
4906 case Intrinsic::nvvm_suld_2d_v2i32_zero:
4907 case Intrinsic::nvvm_suld_2d_v4i32_zero:
4908 case Intrinsic::nvvm_suld_2d_array_i32_zero:
4909 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
4910 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
4911 case Intrinsic::nvvm_suld_3d_i32_zero:
4912 case Intrinsic::nvvm_suld_3d_v2i32_zero:
4913 case Intrinsic::nvvm_suld_3d_v4i32_zero:
4914 Info.opc = ISD::INTRINSIC_W_CHAIN;
4915 Info.memVT = MVT::i32;
4916 Info.ptrVal = nullptr;
4917 Info.offset = 0;
4918 Info.flags = MachineMemOperand::MOLoad;
4919 Info.align = Align(16);
4920 return true;
4921
4922 case Intrinsic::nvvm_suld_1d_i64_clamp:
4923 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
4924 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
4925 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
4926 case Intrinsic::nvvm_suld_2d_i64_clamp:
4927 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
4928 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
4929 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
4930 case Intrinsic::nvvm_suld_3d_i64_clamp:
4931 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
4932 case Intrinsic::nvvm_suld_1d_i64_trap:
4933 case Intrinsic::nvvm_suld_1d_v2i64_trap:
4934 case Intrinsic::nvvm_suld_1d_array_i64_trap:
4935 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
4936 case Intrinsic::nvvm_suld_2d_i64_trap:
4937 case Intrinsic::nvvm_suld_2d_v2i64_trap:
4938 case Intrinsic::nvvm_suld_2d_array_i64_trap:
4939 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
4940 case Intrinsic::nvvm_suld_3d_i64_trap:
4941 case Intrinsic::nvvm_suld_3d_v2i64_trap:
4942 case Intrinsic::nvvm_suld_1d_i64_zero:
4943 case Intrinsic::nvvm_suld_1d_v2i64_zero:
4944 case Intrinsic::nvvm_suld_1d_array_i64_zero:
4945 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
4946 case Intrinsic::nvvm_suld_2d_i64_zero:
4947 case Intrinsic::nvvm_suld_2d_v2i64_zero:
4948 case Intrinsic::nvvm_suld_2d_array_i64_zero:
4949 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
4950 case Intrinsic::nvvm_suld_3d_i64_zero:
4951 case Intrinsic::nvvm_suld_3d_v2i64_zero:
4952 Info.opc = ISD::INTRINSIC_W_CHAIN;
4953 Info.memVT = MVT::i64;
4954 Info.ptrVal = nullptr;
4955 Info.offset = 0;
4956 Info.flags = MachineMemOperand::MOLoad;
4957 Info.align = Align(16);
4958 return true;
4959
4960 case Intrinsic::nvvm_tcgen05_ld_16x64b_x1:
4961 case Intrinsic::nvvm_tcgen05_ld_32x32b_x1:
4962 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x1: {
4963 Info.opc = ISD::INTRINSIC_W_CHAIN;
4964 Info.memVT = MVT::v1i32;
4965 Info.ptrVal = I.getArgOperand(0);
4966 Info.offset = 0;
4967 Info.flags = MachineMemOperand::MOLoad;
4968 Info.align.reset();
4969 return true;
4970 }
4971
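// Illustrative note (sketch): the _xN suffix of tcgen05.ld selects how many
// 32-bit results are returned, so memVT simply scales with it:
//   16x64b_x1 / 32x32b_x1 / 16x32bx2_x1  -> v1i32 (handled above)
//   16x64b_x2, 16x128b_x1, ...           -> v2i32
//   ...
//   16x64b_x128, 16x128b_x64, ...        -> v128i32
// The cases below follow this doubling pattern.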
4972 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
4973 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
4974 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
4975 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2: {
4976 Info.opc = ISD::INTRINSIC_W_CHAIN;
4977 Info.memVT = MVT::v2i32;
4978 Info.ptrVal = I.getArgOperand(0);
4979 Info.offset = 0;
4980 Info.flags = MachineMemOperand::MOLoad;
4981 Info.align.reset();
4982 return true;
4983 }
4984
4985 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
4986 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
4987 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
4988 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
4989 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4: {
4990 Info.opc = ISD::INTRINSIC_W_CHAIN;
4991 Info.memVT = MVT::v4i32;
4992 Info.ptrVal = I.getArgOperand(0);
4993 Info.offset = 0;
4994 Info.flags = MachineMemOperand::MOLoad;
4995 Info.align.reset();
4996 return true;
4997 }
4998
4999 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
5000 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
5001 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
5002 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
5003 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8: {
5004 Info.opc = ISD::INTRINSIC_W_CHAIN;
5005 Info.memVT = MVT::v8i32;
5006 Info.ptrVal = I.getArgOperand(0);
5007 Info.offset = 0;
5008 Info.flags = MachineMemOperand::MOLoad;
5009 Info.align.reset();
5010 return true;
5011 }
5012
5013 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
5014 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
5015 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
5016 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
5017 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16: {
5018 Info.opc = ISD::INTRINSIC_W_CHAIN;
5019 Info.memVT = MVT::v16i32;
5020 Info.ptrVal = I.getArgOperand(0);
5021 Info.offset = 0;
5022 Info.flags = MachineMemOperand::MOLoad;
5023 Info.align.reset();
5024 return true;
5025 }
5026
5027 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
5028 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
5029 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
5030 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
5031 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32: {
5032 Info.opc = ISD::INTRINSIC_W_CHAIN;
5033 Info.memVT = MVT::v32i32;
5034 Info.ptrVal = I.getArgOperand(0);
5035 Info.offset = 0;
5036 Info.flags = MachineMemOperand::MOLoad;
5037 Info.align.reset();
5038 return true;
5039 }
5040
5041 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
5042 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
5043 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
5044 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
5045 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64: {
5046 Info.opc = ISD::INTRINSIC_W_CHAIN;
5047 Info.memVT = MVT::v64i32;
5048 Info.ptrVal = I.getArgOperand(0);
5049 Info.offset = 0;
5050 Info.flags = MachineMemOperand::MOLoad;
5051 Info.align.reset();
5052 return true;
5053 }
5054
5055 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
5056 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
5057 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
5058 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
5059 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128: {
5060 Info.opc = ISD::INTRINSIC_W_CHAIN;
5061 Info.memVT = MVT::v128i32;
5062 Info.ptrVal = I.getArgOperand(0);
5063 Info.offset = 0;
5064 Info.flags = MachineMemOperand::MOLoad;
5065 Info.align.reset();
5066 return true;
5067 }
5068
5069 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
5070 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
5071 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1: {
5072 Info.opc = ISD::INTRINSIC_VOID;
5073 Info.memVT = MVT::i32;
5074 Info.ptrVal = I.getArgOperand(0);
5075 Info.offset = 0;
5076 Info.flags = MachineMemOperand::MOStore;
5077 Info.align.reset();
5078 return true;
5079 }
5080
5081 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
5082 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
5083 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
5084 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2: {
5085 Info.opc = ISD::INTRINSIC_VOID;
5086 Info.memVT = MVT::v2i32;
5087 Info.ptrVal = I.getArgOperand(0);
5088 Info.offset = 0;
5089 Info.flags = MachineMemOperand::MOStore;
5090 Info.align.reset();
5091 return true;
5092 }
5093
5094 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
5095 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
5096 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
5097 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
5098 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4: {
5099 Info.opc = ISD::INTRINSIC_VOID;
5100 Info.memVT = MVT::v4i32;
5101 Info.ptrVal = I.getArgOperand(0);
5102 Info.offset = 0;
5103 Info.flags = MachineMemOperand::MOStore;
5104 Info.align.reset();
5105 return true;
5106 }
5107
5108 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
5109 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
5110 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
5111 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
5112 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8: {
5113 Info.opc = ISD::INTRINSIC_VOID;
5114 Info.memVT = MVT::v8i32;
5115 Info.ptrVal = I.getArgOperand(0);
5116 Info.offset = 0;
5117 Info.flags = MachineMemOperand::MOStore;
5118 Info.align.reset();
5119 return true;
5120 }
5121
5122 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
5123 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
5124 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
5125 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
5126 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16: {
5127 Info.opc = ISD::INTRINSIC_VOID;
5128 Info.memVT = MVT::v16i32;
5129 Info.ptrVal = I.getArgOperand(0);
5130 Info.offset = 0;
5131 Info.flags = MachineMemOperand::MOStore;
5132 Info.align.reset();
5133 return true;
5134 }
5135
5136 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
5137 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
5138 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
5139 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
5140 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32: {
5141 Info.opc = ISD::INTRINSIC_VOID;
5142 Info.memVT = MVT::v32i32;
5143 Info.ptrVal = I.getArgOperand(0);
5144 Info.offset = 0;
5145 Info.flags = MachineMemOperand::MOStore;
5146 Info.align.reset();
5147 return true;
5148 }
5149
5150 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
5151 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
5152 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
5153 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
5154 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64: {
5155 Info.opc = ISD::INTRINSIC_VOID;
5156 Info.memVT = MVT::v64i32;
5157 Info.ptrVal = I.getArgOperand(0);
5158 Info.offset = 0;
5159 Info.flags = MachineMemOperand::MOStore;
5160 Info.align.reset();
5161 return true;
5162 }
5163
5164 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
5165 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
5166 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
5167 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
5168 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128: {
5169 Info.opc = ISD::INTRINSIC_VOID;
5170 Info.memVT = MVT::v128i32;
5171 Info.ptrVal = I.getArgOperand(0);
5172 Info.offset = 0;
5173 Info.flags = MachineMemOperand::MOStore;
5174 Info.align.reset();
5175 return true;
5176 }
5177 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
5178 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
5179 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
5180 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
5181 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
5182 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
5183 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
5184 case Intrinsic::
5185 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
5186 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
5187 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
5188 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
5189 case Intrinsic::
5190 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift: {
5191 // We are reading and writing back to TMem
5192 Info.opc = ISD::INTRINSIC_VOID;
5193 Info.memVT = MVT::v4i32;
5194 Info.ptrVal = I.getArgOperand(0);
5195 Info.offset = 0;
5196 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5197 Info.align = Align(16);
5198 return true;
5199 }
5200
5201 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
5202 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
5203 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
5204 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
5205 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
5206 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
5207 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
5208 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
5209 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
5210 case Intrinsic::
5211 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
5212 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
5213 case Intrinsic::
5214 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift: {
5215 // We are reading and writing back to TMem
5216 Info.opc = ISD::INTRINSIC_VOID;
5217 Info.memVT = MVT::v8i32;
5218 Info.ptrVal = I.getArgOperand(0);
5219 Info.offset = 0;
5220 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5221 Info.align = Align(16);
5222 return true;
5223 }
5224 }
5225 return false;
5226}
5227
5228/// getFunctionParamOptimizedAlign - since function arguments are passed via
5229/// .param space, we may want to increase their alignment in a way that
5230/// ensures that we can effectively vectorize their loads & stores. We can
5231 /// increase the alignment only if the function has internal or private
5232 /// linkage, as for other linkage types callers may already rely on the default
5233 /// alignment. To allow using 128-bit vectorized loads/stores, this function
5234/// ensures that alignment is 16 or greater.
5235 Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
5236 const Function *F, Type *ArgTy, const DataLayout &DL) const {
5237 // Capping the alignment to 128 bytes as that is the maximum alignment
5238 // supported by PTX.
5239 const Align ABITypeAlign = std::min(Align(128), DL.getABITypeAlign(ArgTy));
5240
5241 // If a function has linkage different from internal or private, we
5242 // must use default ABI alignment as external users rely on it. Same
5243 // for a function that may be called from a function pointer.
5244 if (!F || !F->hasLocalLinkage() ||
5245 F->hasAddressTaken(/*Users=*/nullptr,
5246 /*IgnoreCallbackUses=*/false,
5247 /*IgnoreAssumeLikeCalls=*/true,
5248 /*IgnoreLLVMUsed=*/true))
5249 return ABITypeAlign;
5250
5251 assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
5252 return std::max(Align(16), ABITypeAlign);
5253}
5254
5255/// Helper for computing alignment of a device function byval parameter.
5256 Align NVPTXTargetLowering::getFunctionByValParamAlign(
5257 const Function *F, Type *ArgTy, Align InitialAlign,
5258 const DataLayout &DL) const {
5259 Align ArgAlign = InitialAlign;
5260 // Try to increase alignment to enhance vectorization options.
5261 if (F)
5262 ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
5263
5264 // Old ptx versions have a bug: when PTX code takes the address of a
5265 // byval parameter with alignment < 4, ptxas generates code to
5266 // spill the argument into memory. Alas, on sm_50+ ptxas generates
5267 // SASS code that fails with a misaligned access. To work around
5268 // the problem, make sure that we align byval parameters to at
5269 // least 4 bytes. This bug seems to be fixed at least starting from
5270 // ptxas > 9.0.
5271 // TODO: remove this after verifying the bug is not reproduced
5272 // on non-deprecated ptxas versions.
5273 if (ForceMinByValParamAlign)
5274 ArgAlign = std::max(ArgAlign, Align(4));
5275
5276 return ArgAlign;
5277}
5278
5279// Helper for getting a function parameter name. Name is composed from
5280// its index and the function name. Negative index corresponds to special
5281// parameter (unsized array) used for passing variable arguments.
5282 std::string NVPTXTargetLowering::getParamName(const Function *F,
5283 int Idx) const {
5284 std::string ParamName;
5285 raw_string_ostream ParamStr(ParamName);
5286
5287 ParamStr << getTargetMachine().getSymbol(F)->getName();
5288 if (Idx < 0)
5289 ParamStr << "_vararg";
5290 else
5291 ParamStr << "_param_" << Idx;
5292
5293 return ParamName;
5294}
5295
5296/// isLegalAddressingMode - Return true if the addressing mode represented
5297/// by AM is legal for this target, for a load/store of the specified type.
5298/// Used to guide target specific optimizations, like loop strength reduction
5299/// (LoopStrengthReduce.cpp) and memory optimization for address mode
5300/// (CodeGenPrepare.cpp)
5301 bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
5302 const AddrMode &AM, Type *Ty,
5303 unsigned AS, Instruction *I) const {
5304 // AddrMode - This represents an addressing mode of:
5305 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
5306 //
5307 // The legal address modes are
5308 // - [avar]
5309 // - [areg]
5310 // - [areg+immoff]
5311 // - [immAddr]
5312
5313 // immoff must fit in a signed 32-bit int
5314 if (!APInt(64, AM.BaseOffs).isSignedIntN(32))
5315 return false;
5316
5317 if (AM.BaseGV)
5318 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
5319
5320 switch (AM.Scale) {
5321 case 0: // "r", "r+i" or "i" is allowed
5322 break;
5323 case 1:
5324 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
5325 return false;
5326 // Otherwise we have r+i.
5327 break;
5328 default:
5329 // No scale > 1 is allowed
5330 return false;
5331 }
5332 return true;
5333}
5334
5335//===----------------------------------------------------------------------===//
5336// NVPTX Inline Assembly Support
5337//===----------------------------------------------------------------------===//
5338
5339/// getConstraintType - Given a constraint letter, return the type of
5340/// constraint it is for this target.
5341 NVPTXTargetLowering::ConstraintType
5342 NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
5343 if (Constraint.size() == 1) {
5344 switch (Constraint[0]) {
5345 default:
5346 break;
5347 case 'b':
5348 case 'r':
5349 case 'h':
5350 case 'c':
5351 case 'l':
5352 case 'f':
5353 case 'd':
5354 case 'q':
5355 case '0':
5356 case 'N':
5357 return C_RegisterClass;
5358 }
5359 }
5360 return TargetLowering::getConstraintType(Constraint);
5361}
5362
5363std::pair<unsigned, const TargetRegisterClass *>
5364 NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
5365 StringRef Constraint,
5366 MVT VT) const {
5367 if (Constraint.size() == 1) {
5368 switch (Constraint[0]) {
5369 case 'b':
5370 return std::make_pair(0U, &NVPTX::B1RegClass);
5371 case 'c':
5372 case 'h':
5373 return std::make_pair(0U, &NVPTX::B16RegClass);
5374 case 'r':
5375 case 'f':
5376 return std::make_pair(0U, &NVPTX::B32RegClass);
5377 case 'l':
5378 case 'N':
5379 case 'd':
5380 return std::make_pair(0U, &NVPTX::B64RegClass);
5381 case 'q': {
5382 if (STI.getSmVersion() < 70)
5383 report_fatal_error("Inline asm with 128 bit operands is only "
5384 "supported for sm_70 and higher!");
5385 return std::make_pair(0U, &NVPTX::B128RegClass);
5386 }
5387 }
5388 }
5389 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
5390}
5391
5392//===----------------------------------------------------------------------===//
5393// NVPTX DAG Combining
5394//===----------------------------------------------------------------------===//
5395
5396 bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
5397 CodeGenOptLevel OptLevel) const {
5398 // Always honor command-line argument
5399 if (FMAContractLevelOpt.getNumOccurrences() > 0)
5400 return FMAContractLevelOpt > 0;
5401
5402 // Do not contract if we're not optimizing the code.
5403 if (OptLevel == CodeGenOptLevel::None)
5404 return false;
5405
5406 // Honor TargetOptions flags that explicitly say fusion is okay.
5407 if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
5408 return true;
5409
5410 return false;
5411}
5412
5413static bool isConstZero(const SDValue &Operand) {
5414 const auto *Const = dyn_cast<ConstantSDNode>(Operand);
5415 return Const && Const->getZExtValue() == 0;
5416}
5417
5418/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
5419/// operands N0 and N1. This is a helper for PerformADDCombine that is
5420/// called with the default operands, and if that fails, with commuted
5421/// operands.
5422static SDValue
5423 PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5424 TargetLowering::DAGCombinerInfo &DCI) {
5425 EVT VT = N0.getValueType();
5426
5427 // Since integer multiply-add costs the same as integer multiply
5428 // but is more costly than integer add, do the fusion only when
5429 // the mul is only used in the add.
5430 // TODO: this may not be true for later architectures, consider relaxing this
5431 if (!N0.getNode()->hasOneUse())
5432 return SDValue();
5433
5434 // fold (add (select cond, 0, (mul a, b)), c)
5435 // -> (select cond, c, (add (mul a, b), c))
5436 //
5437 if (N0.getOpcode() == ISD::SELECT) {
5438 unsigned ZeroOpNum;
5439 if (isConstZero(N0->getOperand(1)))
5440 ZeroOpNum = 1;
5441 else if (isConstZero(N0->getOperand(2)))
5442 ZeroOpNum = 2;
5443 else
5444 return SDValue();
5445
5446 SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
5447 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
5448 return SDValue();
5449
5450 SDLoc DL(N);
5451 SDValue Mul =
5452 DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
5453 SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
5454 return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
5455 ((ZeroOpNum == 1) ? N1 : MAD),
5456 ((ZeroOpNum == 1) ? MAD : N1));
5457 }
5458
5459 return SDValue();
5460}
5461
5462static SDValue
5463 PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5464 TargetLowering::DAGCombinerInfo &DCI,
5465 CodeGenOptLevel OptLevel) {
5466 EVT VT = N0.getValueType();
5467 if (N0.getOpcode() == ISD::FMUL) {
5468 const auto *TLI = static_cast<const NVPTXTargetLowering *>(
5469 &DCI.DAG.getTargetLoweringInfo());
5470 if (!(TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel) ||
5471 (N->getFlags().hasAllowContract() &&
5472 N0->getFlags().hasAllowContract())))
5473 return SDValue();
5474
5475 // For floating point:
5476 // Do the fusion only when the mul has fewer than 5 uses and all
5477 // of them are adds.
5478 // The heuristic is that if a use is not an add, then that use
5479 // cannot be fused into an fma, so the mul is still needed anyway.
5480 // If there are more than 4 uses, even if they are all adds, fusing
5481 // them will increase register pressure.
5482 //
5483 int numUses = 0;
5484 int nonAddCount = 0;
5485 for (const SDNode *User : N0.getNode()->users()) {
5486 numUses++;
5487 if (User->getOpcode() != ISD::FADD)
5488 ++nonAddCount;
5489 if (numUses >= 5)
5490 return SDValue();
5491 }
5492 if (nonAddCount) {
5493 int orderNo = N->getIROrder();
5494 int orderNo2 = N0.getNode()->getIROrder();
5495 // Simple heuristic here for considering potential register
5496 // pressure: the difference is used to measure the distance
5497 // between def and use; the longer the distance, the more likely
5498 // it is to cause register pressure.
5499 if (orderNo - orderNo2 < 500)
5500 return SDValue();
5501
5502 // Now, check if at least one of the FMUL's operands is live beyond the
5503 // node N, which guarantees that the FMA will not increase register
5504 // pressure at node N.
5505 bool opIsLive = false;
5506 const SDNode *left = N0.getOperand(0).getNode();
5507 const SDNode *right = N0.getOperand(1).getNode();
5508
5509 if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
5510 opIsLive = true;
5511
5512 if (!opIsLive)
5513 for (const SDNode *User : left->users()) {
5514 int orderNo3 = User->getIROrder();
5515 if (orderNo3 > orderNo) {
5516 opIsLive = true;
5517 break;
5518 }
5519 }
5520
5521 if (!opIsLive)
5522 for (const SDNode *User : right->users()) {
5523 int orderNo3 = User->getIROrder();
5524 if (orderNo3 > orderNo) {
5525 opIsLive = true;
5526 break;
5527 }
5528 }
5529
5530 if (!opIsLive)
5531 return SDValue();
5532 }
5533
5534 return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
5535 N0.getOperand(1), N1);
5536 }
5537
5538 return SDValue();
5539}
5540
5541/// Fold unpacking movs into a load by increasing the number of return values.
5542///
5543/// ex:
5544/// L: v2f16,ch = load <p>
5545/// a: f16 = extractelt L:0, 0
5546/// b: f16 = extractelt L:0, 1
5547/// use(a, b)
5548///
5549/// ...is turned into...
5550///
5551/// L: f16,f16,ch = LoadV2 <p>
5552/// use(L:0, L:1)
5553static SDValue
5554 combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5555 // Don't run this optimization before the legalizer
5556 if (!DCI.isAfterLegalizeDAG())
5557 return SDValue();
5558
5559 EVT ElementVT = N->getValueType(0);
5560 // Avoid non-packed types and v4i8
5561 if (!NVPTX::isPackedVectorTy(ElementVT) || ElementVT == MVT::v4i8)
5562 return SDValue();
5563
5564 // Check whether all outputs are either used by an extractelt or are
5565 // glue/chain nodes
5566 if (!all_of(N->uses(), [&](SDUse &U) {
5567 // Skip glue, chain nodes
5568 if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
5569 return true;
5570 if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
5571 if (N->getOpcode() != ISD::LOAD)
5572 return true;
5573 // Since this is an ISD::LOAD, check all extractelts are used. If
5574 // any are not used, we don't want to defeat another optimization that
5575 // will narrow the load.
5576 //
5577 // For example:
5578 //
5579 // L: v2f16,ch = load <p>
5580 // e0: f16 = extractelt L:0, 0
5581 // e1: f16 = extractelt L:0, 1 <-- unused
5582 // store e0
5583 //
5584 // Can be optimized by DAGCombiner to:
5585 //
5586 // L: f16,ch = load <p>
5587 // store L:0
5588 return !U.getUser()->use_empty();
5589 }
5590
5591 // Otherwise, this use prevents us from splitting a value.
5592 return false;
5593 }))
5594 return SDValue();
5595
5596 auto *LD = cast<MemSDNode>(N);
5597 SDLoc DL(LD);
5598
5599 // the new opcode after we double the number of operands
5600 unsigned Opcode;
5601 SmallVector<SDValue> Operands(LD->ops());
5602 unsigned OldNumOutputs; // non-glue, non-chain outputs
5603 switch (LD->getOpcode()) {
5604 case ISD::LOAD:
5605 OldNumOutputs = 1;
5606 // Any packed type is legal, so the legalizer will not have lowered
5607 // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
5608 // here.
5609 Opcode = NVPTXISD::LoadV2;
5610 // append a "full" used bytes mask operand right before the extension type
5611 // operand, signifying that all bytes are used.
5612 Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
5613 Operands.push_back(DCI.DAG.getIntPtrConstant(
5614 cast<LoadSDNode>(LD)->getExtensionType(), DL));
5615 break;
5616 case NVPTXISD::LoadV2:
5617 OldNumOutputs = 2;
5618 Opcode = NVPTXISD::LoadV4;
5619 break;
5620 case NVPTXISD::LoadV4:
5621 // V8 is only supported for f32/i32. Don't forget, we're not changing the
5622 // load size here. This is already a 256-bit load.
5623 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
5624 return SDValue();
5625 OldNumOutputs = 4;
5626 Opcode = NVPTXISD::LoadV8;
5627 break;
5628 case NVPTXISD::LoadV8:
5629 // PTX doesn't support the next doubling of outputs
5630 return SDValue();
5631 }
5632
5633 // the non-glue, non-chain outputs in the new load
5634 const unsigned NewNumOutputs = OldNumOutputs * 2;
5635 SmallVector<EVT> NewVTs(NewNumOutputs, ElementVT.getVectorElementType());
5636 // add remaining chain and glue values
5637 NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end());
5638
5639 // Create the new load
5640 SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
5641 Opcode, DL, DCI.DAG.getVTList(NewVTs), Operands, LD->getMemoryVT(),
5642 LD->getMemOperand());
5643
5644 // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
5645 // the outputs the same. These nodes will be optimized away in later
5646 // DAGCombiner iterations.
5647 SmallVector<SDValue> Results;
5648 for (unsigned I : seq(OldNumOutputs))
5649 Results.push_back(DCI.DAG.getBuildVector(
5650 ElementVT, DL, {NewLoad.getValue(I * 2), NewLoad.getValue(I * 2 + 1)}));
5651 // Add remaining chain and glue nodes
5652 for (unsigned I : seq(NewLoad->getNumValues() - NewNumOutputs))
5653 Results.push_back(NewLoad.getValue(NewNumOutputs + I));
5654
5655 return DCI.DAG.getMergeValues(Results, DL);
5656}
5657
5658/// Fold packing movs into a store.
5659///
5660/// ex:
5661/// v1: v2f16 = BUILD_VECTOR a:f16, b:f16
5662/// v2: v2f16 = BUILD_VECTOR c:f16, d:f16
5663/// StoreV2 v1, v2
5664///
5665/// ...is turned into...
5666///
5667/// StoreV4 a, b, c, d
5668 static SDValue combinePackingMovIntoStore(SDNode *N,
5669 TargetLowering::DAGCombinerInfo &DCI,
5670 unsigned Front, unsigned Back) {
5671 // We want to run this as late as possible since other optimizations may
5672 // eliminate the BUILD_VECTORs.
5673 if (!DCI.isAfterLegalizeDAG())
5674 return SDValue();
5675
5676 // Get the type of the operands being stored.
5677 EVT ElementVT = N->getOperand(Front).getValueType();
5678
5679 // Avoid non-packed types and v4i8
5680 if (!NVPTX::isPackedVectorTy(ElementVT) || ElementVT == MVT::v4i8)
5681 return SDValue();
5682
5683 auto *ST = cast<MemSDNode>(N);
5684
5685 // The new opcode after we double the number of operands.
5686 unsigned Opcode;
5687 switch (N->getOpcode()) {
5688 case ISD::STORE:
5689 // Any packed type is legal, so the legalizer will not have lowered
5690 // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
5691 // it here.
5692 Opcode = NVPTXISD::StoreV2;
5693 break;
5694 case NVPTXISD::StoreV2:
5695 Opcode = NVPTXISD::StoreV4;
5696 break;
5697 case NVPTXISD::StoreV4:
5698 // V8 is only supported for f32/i32. Don't forget, we're not changing the
5699 // store size here. This is already a 256-bit store.
5700 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
5701 return SDValue();
5702 Opcode = NVPTXISD::StoreV8;
5703 break;
5704 case NVPTXISD::StoreV8:
5705 // PTX doesn't support the next doubling of operands
5706 return SDValue();
5707 default:
5708 llvm_unreachable("Unhandled store opcode");
5709 }
5710
5711 // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
5712 // their elements.
5713 SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
5714 for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
5715 if (BV.getOpcode() != ISD::BUILD_VECTOR)
5716 return SDValue();
5717
5718 // If the operand has multiple uses, this optimization can increase register
5719 // pressure.
5720 if (!BV.hasOneUse())
5721 return SDValue();
5722
5723 // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
5724 // any signs they may be folded by some other pattern or rule.
5725 for (SDValue Op : BV->ops()) {
5726 // Peek through bitcasts
5727 if (Op.getOpcode() == ISD::BITCAST)
5728 Op = Op.getOperand(0);
5729
5730 // This may be folded into a PRMT.
5731 if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
5732 Op->getOperand(0).getValueType() == MVT::i32)
5733 return SDValue();
5734
5735 // This may be folded into cvt.bf16x2
5736 if (Op.getOpcode() == ISD::FP_ROUND)
5737 return SDValue();
5738 }
5739 Operands.append({BV.getOperand(0), BV.getOperand(1)});
5740 }
5741 Operands.append(N->op_end() - Back, N->op_end());
5742
5743 // Now we replace the store
5744 return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands,
5745 ST->getMemoryVT(), ST->getMemOperand());
5746}
5747
5749 const NVPTXSubtarget &STI) {
5750
5751 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) {
5752 // Here is our chance to custom lower a store with a non-simple type.
5753 // Unfortunately, we can't do this in the legalizer because there is no
5754 // way to setOperationAction for a non-simple type.
5755 StoreSDNode *ST = cast<StoreSDNode>(N);
5756 if (!ST->getValue().getValueType().isSimple())
5757 return lowerSTOREVector(SDValue(ST, 0), DCI.DAG, STI);
5758 }
5759
5760 return combinePackingMovIntoStore(N, DCI, 1, 2);
5761}
5762
5764 const NVPTXSubtarget &STI) {
5765 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) {
5766 // Here is our chance to custom lower a load with a non-simple type.
5767 // Unfortunately, we can't do this in the legalizer because there is no
5768 // way to setOperationAction for a non-simple type.
5769 if (!N->getValueType(0).isSimple())
5770 return lowerLoadVector(N, DCI.DAG, STI);
5771 }
5772
5773 return combineUnpackingMovIntoLoad(N, DCI);
5774}
5775
5776/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
5777///
5778 static SDValue PerformADDCombine(SDNode *N,
5779 TargetLowering::DAGCombinerInfo &DCI,
5780 CodeGenOptLevel OptLevel) {
5781 if (OptLevel == CodeGenOptLevel::None)
5782 return SDValue();
5783
5784 SDValue N0 = N->getOperand(0);
5785 SDValue N1 = N->getOperand(1);
5786
5787 // Skip non-integer, non-scalar case
5788 EVT VT = N0.getValueType();
5789 if (VT.isVector() || VT != MVT::i32)
5790 return SDValue();
5791
5792 // First try with the default operand order.
5793 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
5794 return Result;
5795
5796 // If that didn't work, try again with the operands commuted.
5797 return PerformADDCombineWithOperands(N, N1, N0, DCI);
5798}
5799
5800/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
5801///
5802 static SDValue PerformFADDCombine(SDNode *N,
5803 TargetLowering::DAGCombinerInfo &DCI,
5804 CodeGenOptLevel OptLevel) {
5805 SDValue N0 = N->getOperand(0);
5806 SDValue N1 = N->getOperand(1);
5807
5808 EVT VT = N0.getValueType();
5809 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
5810 return SDValue();
5811
5812 // First try with the default operand order.
5813 if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
5814 return Result;
5815
5816 // If that didn't work, try again with the operands commuted.
5817 return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
5818}
5819
5820/// Get 3-input version of a 2-input min/max opcode
5821static unsigned getMinMax3Opcode(unsigned MinMax2Opcode) {
5822 switch (MinMax2Opcode) {
5823 case ISD::FMAXNUM:
5824 case ISD::FMAXIMUMNUM:
5825 return NVPTXISD::FMAXNUM3;
5826 case ISD::FMINNUM:
5827 case ISD::FMINIMUMNUM:
5828 return NVPTXISD::FMINNUM3;
5829 case ISD::FMAXIMUM:
5830 return NVPTXISD::FMAXIMUM3;
5831 case ISD::FMINIMUM:
5832 return NVPTXISD::FMINIMUM3;
5833 default:
5834 llvm_unreachable("Invalid 2-input min/max opcode");
5835 }
5836}
5837
5838/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
5839/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
5840 static SDValue PerformFMinMaxCombine(SDNode *N,
5841 TargetLowering::DAGCombinerInfo &DCI,
5842 unsigned PTXVersion, unsigned SmVersion) {
5843
5844 // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
5845 EVT VT = N->getValueType(0);
5846 if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
5847 return SDValue();
5848
5849 SDValue Op0 = N->getOperand(0);
5850 SDValue Op1 = N->getOperand(1);
5851 unsigned MinMaxOp2 = N->getOpcode();
5852 unsigned MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
5853
5854 if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
5855 // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
5856 SDValue A = Op0.getOperand(0);
5857 SDValue B = Op0.getOperand(1);
5858 SDValue C = Op1;
5859 return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
5860 } else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
5861 // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
5862 SDValue A = Op0;
5863 SDValue B = Op1.getOperand(0);
5864 SDValue C = Op1.getOperand(1);
5865 return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
5866 }
5867 return SDValue();
5868}
5869
5870 static SDValue PerformREMCombine(SDNode *N,
5871 TargetLowering::DAGCombinerInfo &DCI,
5872 CodeGenOptLevel OptLevel) {
5873 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
5874
5875 // Don't do anything at less than -O2.
5876 if (OptLevel < CodeGenOptLevel::Default)
5877 return SDValue();
5878
5879 SelectionDAG &DAG = DCI.DAG;
5880 SDLoc DL(N);
5881 EVT VT = N->getValueType(0);
5882 bool IsSigned = N->getOpcode() == ISD::SREM;
5883 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
5884
5885 const SDValue &Num = N->getOperand(0);
5886 const SDValue &Den = N->getOperand(1);
5887
5888 for (const SDNode *U : Num->users()) {
5889 if (U->getOpcode() == DivOpc && U->getOperand(0) == Num &&
5890 U->getOperand(1) == Den) {
5891 // Num % Den -> Num - (Num / Den) * Den
5892 return DAG.getNode(ISD::SUB, DL, VT, Num,
5893 DAG.getNode(ISD::MUL, DL, VT,
5894 DAG.getNode(DivOpc, DL, VT, Num, Den),
5895 Den));
5896 }
5897 }
5898 return SDValue();
5899}
5900
5901// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
5903 CodeGenOptLevel OptLevel) {
5904 if (OptLevel == CodeGenOptLevel::None)
5905 return SDValue();
5906
5907 SDValue Op = N->getOperand(0);
5908 if (!Op.hasOneUse())
5909 return SDValue();
5910 EVT ToVT = N->getValueType(0);
5911 EVT FromVT = Op.getValueType();
5912 if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
5913 (ToVT == MVT::i64 && FromVT == MVT::i32)))
5914 return SDValue();
5915 if (!(Op.getOpcode() == ISD::MUL ||
5916 (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Op.getOperand(1)))))
5917 return SDValue();
5918
5919 SDLoc DL(N);
5920 unsigned ExtOpcode = N->getOpcode();
5921 unsigned Opcode = 0;
5922 if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap())
5923 Opcode = NVPTXISD::MUL_WIDE_SIGNED;
5924 else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap())
5925 Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
5926 else
5927 return SDValue();
5928 SDValue RHS = Op.getOperand(1);
5929 if (Op.getOpcode() == ISD::SHL) {
5930 const auto ShiftAmt = Op.getConstantOperandVal(1);
5931 const auto MulVal = APInt(ToVT.getSizeInBits(), 1) << ShiftAmt;
5932 RHS = DCI.DAG.getConstant(MulVal, DL, ToVT);
5933 }
5934 return DCI.DAG.getNode(Opcode, DL, ToVT, Op.getOperand(0), RHS);
5935}
5936
5937 enum OperandSignedness {
5938 Signed,
5939 Unsigned,
5940 Unknown
5941 };
5942
5943/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
5944/// that can be demoted to \p OptSize bits without loss of information. The
5945/// signedness of the operand, if determinable, is placed in \p S.
5946 static bool IsMulWideOperandDemotable(SDValue Op,
5947 unsigned OptSize,
5948 OperandSignedness &S) {
5949 S = Unknown;
5950
5951 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
5952 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
5953 EVT OrigVT = Op.getOperand(0).getValueType();
5954 if (OrigVT.getFixedSizeInBits() <= OptSize) {
5955 S = Signed;
5956 return true;
5957 }
5958 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
5959 EVT OrigVT = Op.getOperand(0).getValueType();
5960 if (OrigVT.getFixedSizeInBits() <= OptSize) {
5961 S = Unsigned;
5962 return true;
5963 }
5964 }
5965
5966 return false;
5967}
5968
5969/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
5970/// be demoted to \p OptSize bits without loss of information. If the operands
5971/// contain a constant, it should appear as the RHS operand. The signedness of
5972/// the operands is placed in \p IsSigned.
5973 static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
5974 unsigned OptSize,
5975 bool &IsSigned) {
5976 OperandSignedness LHSSign;
5977
5978 // The LHS operand must be a demotable op
5979 if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign))
5980 return false;
5981
5982 // We should have been able to determine the signedness from the LHS
5983 if (LHSSign == Unknown)
5984 return false;
5985
5986 IsSigned = (LHSSign == Signed);
5987
5988 // The RHS can be a demotable op or a constant
5989 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(RHS)) {
5990 const APInt &Val = CI->getAPIntValue();
5991 if (LHSSign == Unsigned) {
5992 return Val.isIntN(OptSize);
5993 } else {
5994 return Val.isSignedIntN(OptSize);
5995 }
5996 } else {
5997 OperandSignedness RHSSign;
5998 if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign))
5999 return false;
6000
6001 return LHSSign == RHSSign;
6002 }
6003}
6004
6005/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
6006/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
6007/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
6008/// amount.
6009 static SDValue TryMULWIDECombine(SDNode *N,
6010 TargetLowering::DAGCombinerInfo &DCI) {
6011 EVT MulType = N->getValueType(0);
6012 if (MulType != MVT::i32 && MulType != MVT::i64) {
6013 return SDValue();
6014 }
6015
6016 SDLoc DL(N);
6017 unsigned OptSize = MulType.getSizeInBits() >> 1;
6018 SDValue LHS = N->getOperand(0);
6019 SDValue RHS = N->getOperand(1);
6020
6021 // Canonicalize the multiply so the constant (if any) is on the right
6022 if (N->getOpcode() == ISD::MUL) {
6023 if (isa<ConstantSDNode>(LHS)) {
6024 std::swap(LHS, RHS);
6025 }
6026 }
6027
6028 // If we have a SHL, determine the actual multiply amount
6029 if (N->getOpcode() == ISD::SHL) {
6030 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(RHS);
6031 if (!ShlRHS) {
6032 return SDValue();
6033 }
6034
6035 APInt ShiftAmt = ShlRHS->getAPIntValue();
6036 unsigned BitWidth = MulType.getSizeInBits();
6037 if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) {
6038 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
6039 RHS = DCI.DAG.getConstant(MulVal, DL, MulType);
6040 } else {
6041 return SDValue();
6042 }
6043 }
6044
6045 bool Signed;
6046 // Verify that our operands are demotable
6047 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) {
6048 return SDValue();
6049 }
6050
6051 EVT DemotedVT;
6052 if (MulType == MVT::i32) {
6053 DemotedVT = MVT::i16;
6054 } else {
6055 DemotedVT = MVT::i32;
6056 }
6057
6058 // Truncate the operands to the correct size. Note that these are just for
6059 // type consistency and will (likely) be eliminated in later phases.
6060 SDValue TruncLHS =
6061 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, LHS);
6062 SDValue TruncRHS =
6063 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, RHS);
6064
6065 unsigned Opc;
6066 if (Signed) {
6067 Opc = NVPTXISD::MUL_WIDE_SIGNED;
6068 } else {
6069 Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
6070 }
6071
6072 return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
6073}
6074
6075static bool isConstOne(const SDValue &Operand) {
6076 const auto *Const = dyn_cast<ConstantSDNode>(Operand);
6077 return Const && Const->getZExtValue() == 1;
6078}
6079
6080 static SDValue matchMADConstOnePattern(SDValue Add) {
6081 if (Add->getOpcode() != ISD::ADD)
6082 return SDValue();
6083
6084 if (isConstOne(Add->getOperand(0)))
6085 return Add->getOperand(1);
6086
6087 if (isConstOne(Add->getOperand(1)))
6088 return Add->getOperand(0);
6089
6090 return SDValue();
6091}
6092
6093 static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
6094 TargetLowering::DAGCombinerInfo &DCI) {
6095
6096 if (SDValue Y = matchMADConstOnePattern(Add)) {
6097 SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
6098 return DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, X);
6099 }
6100
6101 return SDValue();
6102}
6103
6104 static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
6105 SDLoc DL,
6107 if (Select->getOpcode() != ISD::SELECT)
6108 return SDValue();
6109
6110 SDValue Cond = Select->getOperand(0);
6111
6112 unsigned ConstOpNo;
6113 if (isConstOne(Select->getOperand(1)))
6114 ConstOpNo = 1;
6115 else if (isConstOne(Select->getOperand(2)))
6116 ConstOpNo = 2;
6117 else
6118 return SDValue();
6119
6120 SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);
6121
6122 // Do not combine if the resulting sequence is not obviously profitable.
6123 if (!matchMADConstOnePattern(Y))
6124 return SDValue();
6125
6126 SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
6127
6128 return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond,
6129 (ConstOpNo == 1) ? X : NewMul,
6130 (ConstOpNo == 1) ? NewMul : X);
6131}
6132
6133static SDValue
6134 PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
6135 TargetLowering::DAGCombinerInfo &DCI) {
6136
6137 EVT VT = N0.getValueType();
6138 if (VT.isVector())
6139 return SDValue();
6140
6141 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6142 return SDValue();
6143
6144 SDLoc DL(N);
6145
6146 // (mul x, (add y, 1)) -> (add (mul x, y), x)
6147 if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
6148 return Res;
6149 if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
6150 return Res;
6151
6152 // (mul x, (select y, 1)) -> (select (mul x, y), x)
6153 if (SDValue Res = combineMulSelectConstOne(N0, N1, VT, DL, DCI))
6154 return Res;
6155 if (SDValue Res = combineMulSelectConstOne(N1, N0, VT, DL, DCI))
6156 return Res;
6157
6158 return SDValue();
6159}
6160
6161/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
6162 static SDValue PerformMULCombine(SDNode *N,
6163 TargetLowering::DAGCombinerInfo &DCI,
6164 CodeGenOptLevel OptLevel) {
6165 if (OptLevel == CodeGenOptLevel::None)
6166 return SDValue();
6167
6168 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6169 return Ret;
6170
6171 SDValue N0 = N->getOperand(0);
6172 SDValue N1 = N->getOperand(1);
6173 return PerformMULCombineWithOperands(N, N0, N1, DCI);
6174}
6175
6176/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
6177 static SDValue PerformSHLCombine(SDNode *N,
6178 TargetLowering::DAGCombinerInfo &DCI,
6179 CodeGenOptLevel OptLevel) {
6180 if (OptLevel > CodeGenOptLevel::None) {
6181 // Try mul.wide combining at OptLevel > 0
6182 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6183 return Ret;
6184 }
6185
6186 return SDValue();
6187}
6188
6191 unsigned int SmVersion) {
6192 EVT CCType = N->getValueType(0);
6193 SDValue A = N->getOperand(0);
6194 SDValue B = N->getOperand(1);
6195
6196 EVT AType = A.getValueType();
6197 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
6198 return SDValue();
6199
6200 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
6201 return SDValue();
6202
6203 SDLoc DL(N);
6204 // setp.f16x2 returns two scalar predicates, which we need to
6205 // convert back to v2i1. The returned result will be scalarized by
6206 // the legalizer, but the comparison will remain a single vector
6207 // instruction.
6208 SDValue CCNode = DCI.DAG.getNode(
6209 A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
6210 : NVPTXISD::SETP_BF16X2,
6211 DL, DCI.DAG.getVTList(MVT::i1, MVT::i1), {A, B, N->getOperand(2)});
6212 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0),
6213 CCNode.getValue(1));
6214}
6215
6218 SDValue Vector = N->getOperand(0);
6219 if (Vector->getOpcode() == ISD::FREEZE)
6220 Vector = Vector->getOperand(0);
6221 SDLoc DL(N);
6222 EVT VectorVT = Vector.getValueType();
6223 if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
6224 IsPTXVectorType(VectorVT.getSimpleVT()))
6225 return SDValue(); // Native vector loads already combine nicely w/
6226 // extract_vector_elt.
6227 // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
6228 // we already handle them OK.
6229 if (VectorVT.getVectorNumElements() == 1 ||
6230 NVPTX::isPackedVectorTy(VectorVT) || VectorVT == MVT::v8i8)
6231 return SDValue();
6232
6233 // Don't mess with undef values as sra may be simplified to 0, not undef.
6234 if (Vector->isUndef() || ISD::allOperandsUndef(Vector.getNode()))
6235 return SDValue();
6236
6237 uint64_t VectorBits = VectorVT.getSizeInBits();
6238 // We only handle the types we can extract in-register.
6239 if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
6240 return SDValue();
6241
6242 ConstantSDNode *Index = dyn_cast<ConstantSDNode>(N->getOperand(1));
6243 // Index == 0 is handled by generic DAG combiner.
6244 if (!Index || Index->getZExtValue() == 0)
6245 return SDValue();
6246
6247 MVT IVT = MVT::getIntegerVT(VectorBits);
6248 EVT EltVT = VectorVT.getVectorElementType();
6249 EVT EltIVT = EltVT.changeTypeToInteger();
6250 uint64_t EltBits = EltVT.getScalarSizeInBits();
6251
6252 SDValue Result = DCI.DAG.getNode(
6253 ISD::TRUNCATE, DL, EltIVT,
6254 DCI.DAG.getNode(
6255 ISD::SRA, DL, IVT, DCI.DAG.getNode(ISD::BITCAST, DL, IVT, Vector),
6256 DCI.DAG.getConstant(Index->getZExtValue() * EltBits, DL, IVT)));
6257
6258 // If element has non-integer type, bitcast it back to the expected type.
6259 if (EltVT != EltIVT)
6260 Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
6261  // Past legalizer, we may need to extend i8 -> i16 to match the register type.
6262 if (EltVT != N->getValueType(0))
6263 Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, N->getValueType(0), Result);
6264
6265 return Result;
6266}
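// Illustrative example: extracting element 1 of a v2i8 value V becomes
//   (trunc i8 (sra i16 (bitcast i16 V), 8))
// i.e. the vector is moved into an integer register, shifted right by the
// element offset in bits, and truncated, possibly followed by an any_extend
// back to the legal register type.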
6267
6268static SDValue PerformVSELECTCombine(SDNode *N,
6269                                     TargetLowering::DAGCombinerInfo &DCI) {
6270  SDValue VA = N->getOperand(1);
6271 EVT VectorVT = VA.getValueType();
6272 if (VectorVT != MVT::v4i8)
6273 return SDValue();
6274
6275  // We need to split vselect into individual per-element operations. Because
6276  // we use the BFE/BFI instructions for byte extraction/insertion, we end up
6277  // with 32-bit values, so we may as well do the comparison as i32 to avoid
6278  // the conversions to/from i16 normally used for i8 values.
6279  SmallVector<SDValue, 4> E;
6280  SDLoc DL(N);
6281 SDValue VCond = N->getOperand(0);
6282 SDValue VB = N->getOperand(2);
6283 for (int I = 0; I < 4; ++I) {
6284 SDValue C = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i1, VCond,
6285 DCI.DAG.getConstant(I, DL, MVT::i32));
6286 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
6287 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VA,
6288 DCI.DAG.getConstant(I, DL, MVT::i32)),
6289 DL, MVT::i32);
6290 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
6291 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VB,
6292 DCI.DAG.getConstant(I, DL, MVT::i32)),
6293 DL, MVT::i32);
6294 E.push_back(DCI.DAG.getAnyExtOrTrunc(
6295 DCI.DAG.getNode(ISD::SELECT, DL, MVT::i32, C, EA, EB), DL, MVT::i8));
6296 }
6297 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
6298}
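// Illustrative example: (vselect v4i1 C, v4i8 A, v4i8 B) is expanded into four
// (select i32 C[i], ext(A[i]), ext(B[i])) nodes whose results are truncated
// back to i8 and recombined with a BUILD_VECTOR.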
6299
6300static SDValue
6301PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
6302  auto VT = N->getValueType(0);
6303 if (!DCI.isAfterLegalizeDAG() ||
6304 // only process v2*16 types
6305 !(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector() &&
6306 VT.getVectorNumElements() == 2))
6307 return SDValue();
6308
6309 auto Op0 = N->getOperand(0);
6310 auto Op1 = N->getOperand(1);
6311
6312 // Start out by assuming we want to take the lower 2 bytes of each i32
6313 // operand.
6314 uint64_t Op0Bytes = 0x10;
6315 uint64_t Op1Bytes = 0x54;
6316
6317 std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
6318 {&Op1, &Op1Bytes}};
6319
6320 // Check that each operand is an i16, truncated from an i32 operand. We'll
6321 // select individual bytes from those original operands. Optionally, fold in a
6322 // shift right of that original operand.
6323 for (auto &[Op, OpBytes] : OpData) {
6324 // Eat up any bitcast
6325 if (Op->getOpcode() == ISD::BITCAST)
6326 *Op = Op->getOperand(0);
6327
6328 if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
6329 Op->getOperand(0).getValueType() == MVT::i32))
6330 return SDValue();
6331
6332 // If the truncate has multiple uses, this optimization can increase
6333 // register pressure
6334 if (!Op->hasOneUse())
6335 return SDValue();
6336
6337 *Op = Op->getOperand(0);
6338
6339 // Optionally, fold in a shift-right of the original operand and let permute
6340 // pick the two higher bytes of the original value directly.
6341 if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) {
6342 if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) {
6343 // Shift the PRMT byte selector to pick upper bytes from each respective
6344 // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
6345 assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
6346 "PRMT selector values out of range");
6347 *OpBytes += 0x22;
6348 *Op = Op->getOperand(0);
6349 }
6350 }
6351 }
6352
6353 SDLoc DL(N);
6354 auto &DAG = DCI.DAG;
6355
6356 auto PRMT =
6357 getPRMT(DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1),
6358 (Op1Bytes << 8) | Op0Bytes, DL, DAG);
6359 return DAG.getBitcast(VT, PRMT);
6360}
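// Illustrative example: a BUILD_VECTOR of (trunc i16 (srl X, 16)) and
// (trunc i16 Y), with X and Y being i32, becomes PRMT(X, Y, 0x5432), which
// places X's upper 16 bits in the low half of the result and Y's lower
// 16 bits in the high half.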
6361
6362static SDValue combineADDRSPACECAST(SDNode *N,
6363                                    TargetLowering::DAGCombinerInfo &DCI) {
6364  auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
6365
6366 if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(ASCN1->getOperand(0))) {
6367 assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
6368
6369 // Fold asc[B -> A](asc[A -> B](x)) -> x
6370 if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
6371 return ASCN2->getOperand(0);
6372 }
6373
6374 return SDValue();
6375}
6376
6377// Given a constant selector value and a prmt mode, return the selector value
6378// normalized to the generic prmt mode. See the PTX ISA documentation for more
6379// details:
6380// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
6381static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
6382 assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6383
6384  if (Mode == NVPTX::PTXPrmtMode::NONE)
6385    return Selector;
6386
6387 const unsigned V = Selector.trunc(2).getZExtValue();
6388
6389 const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
6390 unsigned S3) {
6391 return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
6392 };
6393
6394  switch (Mode) {
6395  case NVPTX::PTXPrmtMode::F4E:
6396    return GetSelector(V, V + 1, V + 2, V + 3);
6397  case NVPTX::PTXPrmtMode::B4E:
6398    return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
6399  case NVPTX::PTXPrmtMode::RC8:
6400    return GetSelector(V, V, V, V);
6401  case NVPTX::PTXPrmtMode::ECL:
6402    return GetSelector(V, std::max(V, 1U), std::max(V, 2U), 3U);
6403  case NVPTX::PTXPrmtMode::ECR:
6404    return GetSelector(0, std::min(V, 1U), std::min(V, 2U), V);
6405  case NVPTX::PTXPrmtMode::RC16: {
6406    unsigned V1 = (V & 1) << 1;
6407    return GetSelector(V1, V1 + 1, V1, V1 + 1);
6408  }
6409 default:
6410 llvm_unreachable("Invalid PRMT mode");
6411 }
6412}
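// Worked example: in F4E mode with a selector whose low two bits are 1, the
// normalized generic selector is GetSelector(1, 2, 3, 4) == 0x4321, i.e. the
// four result bytes are taken from byte positions 1..4 of {b, a}.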
6413
6414static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
6415 assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
6416 Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6417 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6418 APInt BitField = B.concat(A);
6419 APInt SelectorVal = getPRMTSelector(Selector, Mode);
6420 APInt Result(32, 0);
6421 for (unsigned I : llvm::seq(4U)) {
6422 APInt Sel = SelectorVal.extractBits(4, I * 4);
6423 unsigned Idx = Sel.getLoBits(3).getZExtValue();
6424 unsigned Sign = Sel.getHiBits(1).getZExtValue();
6425 APInt Byte = BitField.extractBits(8, Idx * 8);
6426 if (Sign)
6427 Byte = Byte.ashr(8);
6428 Result.insertBits(Byte, I * 8);
6429 }
6430 return Result;
6431}
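// Worked example (generic mode): with A = 0x03020100, B = 0x07060504 and
// selector 0x7531, the bytes at positions 1, 3, 5 and 7 of {B, A} are
// selected, giving the result 0x07050301.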
6432
6433static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6434                           CodeGenOptLevel OptLevel) {
6435 if (OptLevel == CodeGenOptLevel::None)
6436 return SDValue();
6437
6438 // Constant fold PRMT
6439 if (isa<ConstantSDNode>(N->getOperand(0)) &&
6440 isa<ConstantSDNode>(N->getOperand(1)) &&
6441 isa<ConstantSDNode>(N->getOperand(2)))
6442 return DCI.DAG.getConstant(computePRMT(N->getConstantOperandAPInt(0),
6443 N->getConstantOperandAPInt(1),
6444 N->getConstantOperandAPInt(2),
6445 N->getConstantOperandVal(3)),
6446 SDLoc(N), N->getValueType(0));
6447 return SDValue();
6448}
6449
6450// During call lowering we wrap the return values in a ProxyReg node which
6451// depends on the chain value produced by the completed call. This ensures that
6452// the full call is emitted in cases where libcalls are used to legalize
6453// operations. To improve the functioning of other DAG combines we pull all
6454// operations we can through one of these nodes, ensuring that the ProxyReg
6455// directly wraps a load. That is:
6456//
6457// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0)))
6458//
6459static SDValue sinkProxyReg(SDValue R, SDValue Chain,
6460                            TargetLowering::DAGCombinerInfo &DCI) {
6461  switch (R.getOpcode()) {
6462 case ISD::TRUNCATE:
6463 case ISD::ANY_EXTEND:
6464 case ISD::SIGN_EXTEND:
6465 case ISD::ZERO_EXTEND:
6466 case ISD::BITCAST: {
6467 if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
6468 return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), V);
6469 return SDValue();
6470 }
6471 case ISD::SHL:
6472 case ISD::SRL:
6473 case ISD::SRA:
6474 case ISD::OR: {
6475 if (SDValue A = sinkProxyReg(R.getOperand(0), Chain, DCI))
6476 if (SDValue B = sinkProxyReg(R.getOperand(1), Chain, DCI))
6477 return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), A, B);
6478 return SDValue();
6479 }
6480 case ISD::Constant:
6481 return R;
6482 case ISD::LOAD:
6483 case NVPTXISD::LoadV2:
6484 case NVPTXISD::LoadV4: {
6485 return DCI.DAG.getNode(NVPTXISD::ProxyReg, SDLoc(R), R.getValueType(),
6486 {Chain, R});
6487 }
6488 case ISD::BUILD_VECTOR: {
6489 if (DCI.isBeforeLegalize())
6490 return SDValue();
6491
6492    SmallVector<SDValue, 16> Ops;
6493    for (auto &Op : R->ops()) {
6494 SDValue V = sinkProxyReg(Op, Chain, DCI);
6495 if (!V)
6496 return SDValue();
6497 Ops.push_back(V);
6498 }
6499 return DCI.DAG.getNode(ISD::BUILD_VECTOR, SDLoc(R), R.getValueType(), Ops);
6500 }
6501  case ISD::EXTRACT_VECTOR_ELT: {
6502    if (DCI.isBeforeLegalize())
6503 return SDValue();
6504
6505 if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
6506      return DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(R),
6507                             R.getValueType(), V, R.getOperand(1));
6508 return SDValue();
6509 }
6510 default:
6511 return SDValue();
6512 }
6513}
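// Illustrative example: (ProxyReg chain, (or (load retval0), 1)) is rewritten
// to (or (ProxyReg chain, (load retval0)), 1), so the ProxyReg ends up
// wrapping the load directly while the constant operand is left in place.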
6514
6515static SDValue combineProxyReg(SDNode *N,
6516                               TargetLowering::DAGCombinerInfo &DCI) {
6517
6518 SDValue Chain = N->getOperand(0);
6519 SDValue Reg = N->getOperand(1);
6520
6521 // If the ProxyReg is not wrapping a load, try to pull the operations through
6522 // the ProxyReg.
6523 if (Reg.getOpcode() != ISD::LOAD) {
6524 if (SDValue V = sinkProxyReg(Reg, Chain, DCI))
6525 return V;
6526 }
6527
6528 return SDValue();
6529}
6530
6531SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
6532 DAGCombinerInfo &DCI) const {
6533  CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
6534  switch (N->getOpcode()) {
6535 default:
6536 break;
6537 case ISD::ADD:
6538 return PerformADDCombine(N, DCI, OptLevel);
6539 case ISD::ADDRSPACECAST:
6540 return combineADDRSPACECAST(N, DCI);
6541 case ISD::SIGN_EXTEND:
6542 case ISD::ZERO_EXTEND:
6543 return combineMulWide(N, DCI, OptLevel);
6544 case ISD::BUILD_VECTOR:
6545 return PerformBUILD_VECTORCombine(N, DCI);
6546  case ISD::EXTRACT_VECTOR_ELT:
6547    return PerformEXTRACTCombine(N, DCI);
6548 case ISD::FADD:
6549 return PerformFADDCombine(N, DCI, OptLevel);
6550 case ISD::FMAXNUM:
6551 case ISD::FMINNUM:
6552 case ISD::FMAXIMUM:
6553 case ISD::FMINIMUM:
6554 case ISD::FMAXIMUMNUM:
6555 case ISD::FMINIMUMNUM:
6556 return PerformFMinMaxCombine(N, DCI, STI.getPTXVersion(),
6557 STI.getSmVersion());
6558 case ISD::LOAD:
6559 case NVPTXISD::LoadV2:
6560 case NVPTXISD::LoadV4:
6561 return combineLOAD(N, DCI, STI);
6562 case ISD::MUL:
6563 return PerformMULCombine(N, DCI, OptLevel);
6564 case NVPTXISD::PRMT:
6565 return combinePRMT(N, DCI, OptLevel);
6566 case NVPTXISD::ProxyReg:
6567 return combineProxyReg(N, DCI);
6568 case ISD::SETCC:
6569 return PerformSETCCCombine(N, DCI, STI.getSmVersion());
6570 case ISD::SHL:
6571 return PerformSHLCombine(N, DCI, OptLevel);
6572 case ISD::SREM:
6573 case ISD::UREM:
6574 return PerformREMCombine(N, DCI, OptLevel);
6575 case ISD::STORE:
6576 case NVPTXISD::StoreV2:
6577 case NVPTXISD::StoreV4:
6578 return combineSTORE(N, DCI, STI);
6579 case ISD::VSELECT:
6580 return PerformVSELECTCombine(N, DCI);
6581 }
6582 return SDValue();
6583}
6584
6585static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
6586                           SmallVectorImpl<SDValue> &Results) {
6587  // Handle bitcasting to v2i8 without hitting the default promotion
6588 // strategy which goes through stack memory.
6589 SDValue Op(Node, 0);
6590 EVT ToVT = Op->getValueType(0);
6591 if (ToVT != MVT::v2i8) {
6592 return;
6593 }
6594
6595 // Bitcast to i16 and unpack elements into a vector
6596 SDLoc DL(Node);
6597 SDValue AsInt = DAG.getBitcast(MVT::i16, Op->getOperand(0));
6598 SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
6599 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
6600 SDValue Vec1 =
6601 DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
6602 DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
6603 Results.push_back(
6604 DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
6605}
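// Illustrative example: bitcasting a 16-bit value whose bits are 0xAABB to
// v2i8 produces BUILD_VECTOR(0xBB, 0xAA): element 0 is the truncated low byte
// and element 1 is the truncation of the value shifted right by 8.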
6606
6607static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
6608                                     SmallVectorImpl<SDValue> &Results) {
6609  SDValue Chain = N->getOperand(0);
6610 SDValue Intrin = N->getOperand(1);
6611 SDLoc DL(N);
6612
6613 // Get the intrinsic ID
6614 unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
6615 switch (IntrinNo) {
6616 default:
6617 return;
6618 case Intrinsic::nvvm_ldu_global_i:
6619 case Intrinsic::nvvm_ldu_global_f:
6620 case Intrinsic::nvvm_ldu_global_p: {
6621 EVT ResVT = N->getValueType(0);
6622
6623 if (ResVT.isVector()) {
6624 // Vector LDG/LDU
6625
6626 unsigned NumElts = ResVT.getVectorNumElements();
6627 EVT EltVT = ResVT.getVectorElementType();
6628
6629 // Since LDU/LDG are target nodes, we cannot rely on DAG type
6630 // legalization.
6631 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
6632 // loaded type to i16 and propagate the "real" type as the memory type.
6633 bool NeedTrunc = false;
6634 if (EltVT.getSizeInBits() < 16) {
6635 EltVT = MVT::i16;
6636 NeedTrunc = true;
6637 }
6638
6639 unsigned Opcode = 0;
6640 SDVTList LdResVTs;
6641
6642 switch (NumElts) {
6643 default:
6644 return;
6645 case 2:
6646 Opcode = NVPTXISD::LDUV2;
6647 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
6648 break;
6649 case 4: {
6650 Opcode = NVPTXISD::LDUV4;
6651 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
6652 LdResVTs = DAG.getVTList(ListVTs);
6653 break;
6654 }
6655 }
6656
6657 SmallVector<SDValue, 8> OtherOps;
6658
6659 // Copy regular operands
6660
6661 OtherOps.push_back(Chain); // Chain
6662 // Skip operand 1 (intrinsic ID)
6663 // Others
6664 OtherOps.append(N->op_begin() + 2, N->op_end());
6665
6666      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
6667
6668 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
6669 MemSD->getMemoryVT(),
6670 MemSD->getMemOperand());
6671
6672 SmallVector<SDValue, 4> ScalarRes;
6673
6674 for (unsigned i = 0; i < NumElts; ++i) {
6675 SDValue Res = NewLD.getValue(i);
6676 if (NeedTrunc)
6677 Res =
6678 DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
6679 ScalarRes.push_back(Res);
6680 }
6681
6682 SDValue LoadChain = NewLD.getValue(NumElts);
6683
6684 SDValue BuildVec =
6685 DAG.getBuildVector(ResVT, DL, ScalarRes);
6686
6687 Results.push_back(BuildVec);
6688 Results.push_back(LoadChain);
6689 } else {
6690 // i8 LDG/LDU
6691 assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
6692 "Custom handling of non-i8 ldu/ldg?");
6693
6694 // Just copy all operands as-is
6695      SmallVector<SDValue, 4> Ops(N->op_begin(), N->op_end());
6696
6697 // Force output to i16
6698 SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
6699
6700      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
6701
6702 // We make sure the memory type is i8, which will be used during isel
6703 // to select the proper instruction.
6704      SDValue NewLD =
6705          DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, Ops,
6706                                  MVT::i8, MemSD->getMemOperand());
6707
6708 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
6709 NewLD.getValue(0)));
6710 Results.push_back(NewLD.getValue(1));
6711 }
6712 return;
6713 }
6714
6715 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
6716 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
6717 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
6718 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
6719 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
6720 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
6721 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
6722 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
6723 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
6724 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
6725 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
6726 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
6727 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
6728 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
6729 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
6730 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
6731 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
6732 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
6733 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
6734 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
6735 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
6736 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
6737 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
6738 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
6739 if (auto Res = lowerTcgen05Ld(N, DAG)) {
6740 Results.push_back(Res->first);
6741 Results.push_back(Res->second);
6742 }
6743 return;
6744
6745 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
6746 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
6747 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
6748 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
6749 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
6750 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
6751 if (auto Res = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
6752 Results.push_back(Res->first);
6753 Results.push_back(Res->second);
6754 }
6755 return;
6756 }
6757}
6758
6759static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
6760                                   SmallVectorImpl<SDValue> &Results) {
6761  // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
6762 // result so that it can pass the legalization
6763 SDLoc DL(N);
6764 SDValue Chain = N->getOperand(0);
6765 SDValue Reg = N->getOperand(1);
6766 SDValue Glue = N->getOperand(2);
6767
6768 assert(Reg.getValueType() == MVT::i128 &&
6769 "Custom lowering for CopyFromReg with 128-bit reg only");
6770 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(1),
6771 N->getValueType(2)};
6772 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
6773
6774 SDValue NewValue = DAG.getNode(ISD::CopyFromReg, DL, ResultsType, NewOps);
6775 SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
6776 {NewValue.getValue(0), NewValue.getValue(1)});
6777
6778 Results.push_back(Pair);
6779 Results.push_back(NewValue.getValue(2));
6780 Results.push_back(NewValue.getValue(3));
6781}
6782
6783static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
6784                            const TargetLowering &TLI,
6785                            SmallVectorImpl<SDValue> &Results) {
6786 SDValue Chain = N->getOperand(0);
6787 SDValue Reg = N->getOperand(1);
6788
6789 MVT VT = TLI.getRegisterType(*DAG.getContext(), Reg.getValueType());
6790
6791 SDValue NewReg = DAG.getAnyExtOrTrunc(Reg, SDLoc(N), VT);
6792 SDValue NewProxy =
6793 DAG.getNode(NVPTXISD::ProxyReg, SDLoc(N), VT, {Chain, NewReg});
6794 SDValue Res = DAG.getAnyExtOrTrunc(NewProxy, SDLoc(N), N->getValueType(0));
6795
6796 Results.push_back(Res);
6797}
6798
6799static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG,
6800                                 const NVPTXSubtarget &STI,
6801                                 SmallVectorImpl<SDValue> &Results) {
6802 assert(N->getValueType(0) == MVT::i128 &&
6803 "Custom lowering for atomic128 only supports i128");
6804
6805  AtomicSDNode *AN = cast<AtomicSDNode>(N);
6806  SDLoc dl(N);
6807
6808 if (!STI.hasAtomSwap128()) {
6809    DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
6810        DAG.getMachineFunction().getFunction(),
6811        "Support for b128 atomics introduced in PTX ISA version 8.3 and "
6812 "requires target sm_90.",
6813 dl.getDebugLoc()));
6814
6815 Results.push_back(DAG.getUNDEF(MVT::i128));
6816 Results.push_back(AN->getOperand(0)); // Chain
6817 return;
6818 }
6819
6820  SmallVector<SDValue, 6> Ops;
6821  Ops.push_back(AN->getOperand(0)); // Chain
6822 Ops.push_back(AN->getOperand(1)); // Ptr
6823 for (const auto &Op : AN->ops().drop_front(2)) {
6824 // Low part
6825 Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op,
6826 DAG.getIntPtrConstant(0, dl)));
6827 // High part
6828 Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op,
6829 DAG.getIntPtrConstant(1, dl)));
6830 }
6831 unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP
6834 SDVTList Tys = DAG.getVTList(MVT::i64, MVT::i64, MVT::Other);
6835 SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, Tys, Ops, MVT::i128,
6836 AN->getMemOperand());
6837 Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i128,
6838 {Result.getValue(0), Result.getValue(1)}));
6839 Results.push_back(Result.getValue(2));
6840}
6841
6842void NVPTXTargetLowering::ReplaceNodeResults(
6843    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
6844  switch (N->getOpcode()) {
6845 default:
6846 report_fatal_error("Unhandled custom legalization");
6847 case ISD::BITCAST:
6848 ReplaceBITCAST(N, DAG, Results);
6849 return;
6850 case ISD::LOAD:
6851 case ISD::MLOAD:
6852 replaceLoadVector(N, DAG, Results, STI);
6853 return;
6854  case ISD::INTRINSIC_W_CHAIN:
6855    ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
6856    return;
6857 case ISD::CopyFromReg:
6858    ReplaceCopyFromReg_128(N, DAG, Results);
6859    return;
6860 case NVPTXISD::ProxyReg:
6861 replaceProxyReg(N, DAG, *this, Results);
6862 return;
6863 case ISD::ATOMIC_CMP_SWAP:
6864 case ISD::ATOMIC_SWAP:
6865 replaceAtomicSwap128(N, DAG, STI, Results);
6866 return;
6867 }
6868}
6869
6870NVPTXTargetLowering::AtomicExpansionKind
6871NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
6872  Type *Ty = AI->getValOperand()->getType();
6873
6874  if (AI->isFloatingPointOperation()) {
6875    if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
6876      if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
6877          STI.getPTXVersion() >= 63)
6878        return AtomicExpansionKind::None;
6879      if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
6880          STI.getPTXVersion() >= 78)
6881        return AtomicExpansionKind::None;
6882      if (Ty->isFloatTy())
6883        return AtomicExpansionKind::None;
6884      if (Ty->isDoubleTy() && STI.hasAtomAddF64())
6885        return AtomicExpansionKind::None;
6886    }
6887    return AtomicExpansionKind::CmpXChg;
6888  }
6889
6890 assert(Ty->isIntegerTy() && "Ty should be integer at this point");
6891 const unsigned BitWidth = cast<IntegerType>(Ty)->getBitWidth();
6892
6893 switch (AI->getOperation()) {
6894 default:
6897 if (BitWidth == 128)
6899 [[fallthrough]];
6903 switch (BitWidth) {
6904 case 8:
6905 case 16:
6907 case 32:
6909 case 64:
6910 if (STI.hasAtomBitwise64())
6913 case 128:
6915 default:
6916 llvm_unreachable("unsupported width encountered");
6917 }
6924 switch (BitWidth) {
6925 case 8:
6926 case 16:
6928 case 32:
6930 case 64:
6931 if (STI.hasAtomMinMax64())
6934 case 128:
6936 default:
6937 llvm_unreachable("unsupported width encountered");
6938 }
6941 switch (BitWidth) {
6942 case 32:
6944 case 8:
6945 case 16:
6946 case 64:
6947 case 128:
6949 default:
6950 llvm_unreachable("unsupported width encountered");
6951 }
6952 }
6953
6955}
6956
6957bool NVPTXTargetLowering::shouldInsertFencesForAtomic(
6958    const Instruction *I) const {
6959 auto *CI = dyn_cast<AtomicCmpXchgInst>(I);
6960 // When CAS bitwidth is not supported on the hardware, the CAS is emulated
6961 // using a retry loop that uses a higher-bitwidth monotonic CAS. We enforce
6962 // the memory order using explicit fences around the retry loop.
6963 // The memory order of natively supported CAS operations can be enforced
6964 // by lowering to an atom.cas with the right memory synchronizing effect.
6965 // However, atom.cas only supports relaxed, acquire, release and acq_rel.
6966 // So we also use explicit fences for enforcing memory order for
6967 // seq_cast CAS with natively-supported bitwidths.
6968 return CI &&
6969 (cast<IntegerType>(CI->getCompareOperand()->getType())->getBitWidth() <
6970 STI.getMinCmpXchgSizeInBits() ||
6971 CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent);
6972}
6973
6974AtomicOrdering NVPTXTargetLowering::atomicOperationOrderAfterFenceSplit(
6975    const Instruction *I) const {
6976 auto *CI = dyn_cast<AtomicCmpXchgInst>(I);
6977 bool BitwidthSupportedAndIsSeqCst =
6978 CI && CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent &&
6979 cast<IntegerType>(CI->getCompareOperand()->getType())->getBitWidth() >=
6980 STI.getMinCmpXchgSizeInBits();
6981 return BitwidthSupportedAndIsSeqCst ? AtomicOrdering::Acquire
6982                                      : AtomicOrdering::Monotonic;
6983}
6984
6985Instruction *NVPTXTargetLowering::emitLeadingFence(IRBuilderBase &Builder,
6986                                                   Instruction *Inst,
6987 AtomicOrdering Ord) const {
6988 if (!isa<AtomicCmpXchgInst>(Inst))
6989 return TargetLoweringBase::emitLeadingFence(Builder, Inst, Ord);
6990
6991 // Specialize for cmpxchg
6992 // Emit a fence.sc leading fence for cmpxchg seq_cst which are not emulated
6993 SyncScope::ID SSID = cast<AtomicCmpXchgInst>(Inst)->getSyncScopeID();
6994 if (isReleaseOrStronger(Ord))
6995 return Builder.CreateFence(Ord == AtomicOrdering::SequentiallyConsistent
6996 ? Ord
6997                                   : AtomicOrdering::Release,
6998                               SSID);
6999
7000 return nullptr;
7001}
7002
7003Instruction *NVPTXTargetLowering::emitTrailingFence(IRBuilderBase &Builder,
7004                                                    Instruction *Inst,
7005 AtomicOrdering Ord) const {
7006 // Specialize for cmpxchg
7007 if (!isa<AtomicCmpXchgInst>(Inst))
7008 return TargetLoweringBase::emitTrailingFence(Builder, Inst, Ord);
7009
7010 auto *CI = cast<AtomicCmpXchgInst>(Inst);
7011 auto CASWidth =
7012 cast<IntegerType>(CI->getCompareOperand()->getType())->getBitWidth();
7013 SyncScope::ID SSID = CI->getSyncScopeID();
7014 // Do not emit a trailing fence for cmpxchg seq_cst which are not emulated
7015 if (isAcquireOrStronger(Ord) &&
7016      (Ord != AtomicOrdering::SequentiallyConsistent ||
7017       CASWidth < STI.getMinCmpXchgSizeInBits()))
7018 return Builder.CreateFence(AtomicOrdering::Acquire, SSID);
7019
7020 return nullptr;
7021}
7022
7023// Rather than default to SINT when both UINT and SINT are custom, we only
7024// change the opcode when UINT is not legal and SINT is. UINT is preferred when
7025// both are custom since unsigned CVT instructions can lead to slightly better
7026// SASS code with fewer instructions.
7027unsigned NVPTXTargetLowering::getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
7028                                                        EVT ToVT) const {
7029 if (isOperationLegal(Op, ToVT))
7030 return Op;
7031 switch (Op) {
7032  case ISD::FP_TO_UINT:
7033    if (isOperationLegal(ISD::FP_TO_SINT, ToVT))
7034      return ISD::FP_TO_SINT;
7035    break;
7036  case ISD::STRICT_FP_TO_UINT:
7037    if (isOperationLegal(ISD::STRICT_FP_TO_SINT, ToVT))
7038      return ISD::STRICT_FP_TO_SINT;
7039    break;
7040 case ISD::VP_FP_TO_UINT:
7041 if (isOperationLegal(ISD::VP_FP_TO_SINT, ToVT))
7042 return ISD::VP_FP_TO_SINT;
7043 break;
7044 default:
7045 break;
7046 }
7047 return Op;
7048}
7049
7050// Pin NVPTXTargetObjectFile's vtables to this file.
7052
7057
7058static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
7059                                    const SelectionDAG &DAG, unsigned Depth) {
7060 SDValue A = Op.getOperand(0);
7061 SDValue B = Op.getOperand(1);
7062 ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2));
7063 unsigned Mode = Op.getConstantOperandVal(3);
7064
7065 if (!Selector)
7066 return;
7067
7068 KnownBits AKnown = DAG.computeKnownBits(A, Depth);
7069 KnownBits BKnown = DAG.computeKnownBits(B, Depth);
7070
7071 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
7072 assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 &&
7073 "PRMT must have i32 operands");
7074 assert(Known.getBitWidth() == 32 && "PRMT must have i32 result");
7075 KnownBits BitField = BKnown.concat(AKnown);
7076
7077 APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
7078 for (unsigned I : llvm::seq(4)) {
7079 APInt Sel = SelectorVal.extractBits(4, I * 4);
7080 unsigned Idx = Sel.getLoBits(3).getZExtValue();
7081 unsigned Sign = Sel.getHiBits(1).getZExtValue();
7082 KnownBits Byte = BitField.extractBits(8, Idx * 8);
7083 if (Sign)
7084 Byte = KnownBits::ashr(Byte, 8);
7085 Known.insertBits(Byte, I * 8);
7086 }
7087}
7088
7089static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
7090  const MemSDNode *LD = cast<MemSDNode>(Op.getNode());
7091
7092 // We can't do anything without knowing the sign bit.
7093 auto ExtType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
7094 if (ExtType == ISD::SEXTLOAD)
7095 return;
7096
7097 // ExtLoading to vector types is weird and may not work well with known bits.
7098 auto DestVT = LD->getValueType(0);
7099 if (DestVT.isVector())
7100 return;
7101
7102 assert(Known.getBitWidth() == DestVT.getSizeInBits());
7103 auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(LD);
7104 Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
7105}
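// For instance, for a LoadV2 whose results are i16 but whose in-memory element
// type is i8 and which is not a sign-extending load, the reported element
// width is 8, so the top 8 bits of each result are marked as known zero.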
7106
7108 const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
7109 const SelectionDAG &DAG, unsigned Depth) const {
7110 Known.resetAll();
7111
7112 switch (Op.getOpcode()) {
7113 case NVPTXISD::PRMT:
7114 computeKnownBitsForPRMT(Op, Known, DAG, Depth);
7115 break;
7116 case NVPTXISD::LoadV2:
7117 case NVPTXISD::LoadV4:
7118 case NVPTXISD::LoadV8:
7119    computeKnownBitsForLoadV(Op, Known);
7120    break;
7121 default:
7122 break;
7123 }
7124}
7125
7126static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
7127 const APInt &DemandedBits) {
7128 APInt DemandedLHS = APInt(32, 0);
7129 APInt DemandedRHS = APInt(32, 0);
7130
7131 for (unsigned I : llvm::seq(4)) {
7132 if (DemandedBits.extractBits(8, I * 8).isZero())
7133 continue;
7134
7135 APInt Sel = SelectorVal.extractBits(4, I * 4);
7136 unsigned Idx = Sel.getLoBits(3).getZExtValue();
7137 unsigned Sign = Sel.getHiBits(1).getZExtValue();
7138
7139 APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
7140 unsigned ByteStart = (Idx % 4) * 8;
7141 if (Sign)
7142 Src.setBit(ByteStart + 7);
7143 else
7144 Src.setBits(ByteStart, ByteStart + 8);
7145 }
7146
7147 return {DemandedLHS, DemandedRHS};
7148}
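// Worked example: if only byte 1 of the PRMT result is demanded and the
// selector is the identity 0x3210, nibble 1 selects source byte 1, so
// DemandedLHS becomes 0x0000FF00 while DemandedRHS stays 0.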
7149
7150// Replace undef with 0 as this is easier for other optimizations such as
7151// known bits.
7153 if (!Op)
7154 return SDValue();
7155 if (Op.isUndef())
7156 return DAG.getConstant(0, SDLoc(), MVT::i32);
7157 return Op;
7158}
7159
7161 const APInt &DemandedBits,
7162 SelectionDAG &DAG,
7163 const TargetLowering &TLI,
7164 unsigned Depth) {
7165 assert(PRMT.getOpcode() == NVPTXISD::PRMT);
7166 SDValue Op0 = PRMT.getOperand(0);
7167 SDValue Op1 = PRMT.getOperand(1);
7168 auto *SelectorConst = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
7169 if (!SelectorConst)
7170 return SDValue();
7171
7172 unsigned Mode = PRMT.getConstantOperandVal(3);
7173 const APInt Selector = getPRMTSelector(SelectorConst->getAPIntValue(), Mode);
7174
7175 // Try to simplify the PRMT to one of the inputs if the used bytes are all
7176 // from the same input in the correct order.
7177 const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
7178 const unsigned SelBits = (4 - LeadingBytes) * 4;
7179 if (Selector.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
7180 return Op0;
7181 if (Selector.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
7182 return Op1;
7183
7184 auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(Selector, DemandedBits);
7185
7186 // Attempt to avoid multi-use ops if we don't need anything from them.
7187 SDValue DemandedOp0 =
7188 TLI.SimplifyMultipleUseDemandedBits(Op0, DemandedLHS, DAG, Depth + 1);
7189 SDValue DemandedOp1 =
7190 TLI.SimplifyMultipleUseDemandedBits(Op1, DemandedRHS, DAG, Depth + 1);
7191
7192 DemandedOp0 = canonicalizePRMTInput(DemandedOp0, DAG);
7193 DemandedOp1 = canonicalizePRMTInput(DemandedOp1, DAG);
7194 if ((DemandedOp0 && DemandedOp0 != Op0) ||
7195 (DemandedOp1 && DemandedOp1 != Op1)) {
7196 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
7197 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
7198 return getPRMT(Op0, Op1, Selector.getZExtValue(), SDLoc(PRMT), DAG);
7199 }
7200
7201 return SDValue();
7202}
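// Illustrative example: if only the low 16 bits of the PRMT are demanded and
// the normalized selector's low byte is 0x10, the demanded bytes are bytes 0
// and 1 of the first operand in order, so the whole PRMT is replaced by Op0.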
7203
7205 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
7206 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
7207 Known.resetAll();
7208
7209 switch (Op.getOpcode()) {
7210 case NVPTXISD::PRMT:
7211    if (SDValue Result = simplifyDemandedBitsForPRMT(Op, DemandedBits, TLO.DAG,
7212                                                     *this, Depth)) {
7213 TLO.CombineTo(Op, Result);
7214 return true;
7215 }
7216 break;
7217 default:
7218 break;
7219 }
7220
7221 computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
7222 return false;
7223}
return SDValue()
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
constexpr LLT F32
AMDGPU Register Bank Select
This file declares a class to represent arbitrary precision floating point values and provide a varie...
This file implements a class to represent arbitrary precision integral constant values and operations...
static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformADDCombineWithOperands - Try DAG combinations for an ADD with operands N0 and N1.
static SDValue PerformADDCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
static SDValue PerformVSELECTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformMULCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformFADDCombine(SDNode *N, SelectionDAG &DAG, const ARMSubtarget *Subtarget)
static SDValue PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformBUILD_VECTORCombine - Target-specific dag combine xforms for ISD::BUILD_VECTOR.
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Function Alias Analysis Results
Atomic ordering constants.
This file contains the simple types necessary to represent the attributes associated with functions a...
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file contains the declarations of entities that describe floating point environment and related ...
shuff Hexagon Optimize Shuffle Vector
Module.h This file contains the declarations for the Module class.
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
static DebugLoc getDebugLoc(MachineBasicBlock::instr_iterator FirstMI, MachineBasicBlock::instr_iterator LastMI)
Return the first DebugLoc that has line number information, given a range of instructions.
Register Reg
Register const TargetRegisterInfo * TRI
#define T
NVPTX address space definition.
static bool shouldConvertToIndirectCall(const CallBase *CB, const GlobalAddressSDNode *Func)
static SDValue combineADDRSPACECAST(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
static cl::opt< bool > sched4reg("nvptx-sched4reg", cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false))
static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG)
static SDValue PerformEXTRACTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
static cl::opt< NVPTX::DivPrecisionLevel > UsePrecDivF32("nvptx-prec-divf32", cl::Hidden, cl::desc("NVPTX Specific: Override the precision of the lowering for f32 fdiv"), cl::values(clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"), clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"), clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2", "Use IEEE Compliant F32 div.rnd if available (default)"), clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3", "Use IEEE Compliant F32 div.rnd if available, no FTZ")), cl::init(NVPTX::DivPrecisionLevel::IEEE754))
static bool isConstOne(const SDValue &Operand)
static cl::opt< unsigned > FMAContractLevelOpt("nvptx-fma-level", cl::Hidden, cl::desc("NVPTX Specific: FMA contraction (0: don't do it" " 1: do it 2: do it aggressively"), cl::init(2))
static bool IsPTXVectorType(MVT VT)
static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG)
static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG)
static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG, const DataLayout &DL, const TargetLowering &TL)
static SDValue lowerROT(SDValue Op, SelectionDAG &DAG)
static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL, LLVMContext &Ctx, CallingConv::ID CallConv, Type *Ty, SmallVectorImpl< EVT > &ValueVTs, SmallVectorImpl< uint64_t > &Offsets, uint64_t StartingOffset=0)
ComputePTXValueVTs - For the given Type Ty, returns the set of primitive legal-ish MVTs that compose ...
static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI, SmallVectorImpl< SDValue > &Results)
static unsigned getMinMax3Opcode(unsigned MinMax2Opcode)
Get 3-input version of a 2-input min/max opcode.
static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG, const NVPTXSubtarget &STI)
static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI)
static void replaceProxyReg(SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, SmallVectorImpl< SDValue > &Results)
static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG)
static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL, TargetLowering::DAGCombinerInfo &DCI)
static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static SDValue combinePackingMovIntoStore(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, unsigned Front, unsigned Back)
Fold packing movs into a store.
static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl, SelectionDAG &DAG, T GetElement)
static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT, const SDLoc &dl, SelectionDAG &DAG)
static unsigned canMergeParamLoadStoresStartingAt(unsigned Idx, uint32_t AccessSize, const SmallVectorImpl< EVT > &ValueVTs, const SmallVectorImpl< T > &Offsets, Align ParamAlignment)
static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C)
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG)
static SDValue PerformFMinMaxCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, unsigned PTXVersion, unsigned SmVersion)
PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into (fmaxnum3 a, b, c).
static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static SDValue PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static std::optional< unsigned > getScalar3OpcodeForReduction(unsigned ReductionOpcode)
Get 3-input scalar reduction opcode.
static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG)
static bool isConstZero(const SDValue &Operand)
static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG)
static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG)
static bool IsMulWideOperandDemotable(SDValue Op, unsigned OptSize, OperandSignedness &S)
IsMulWideOperandDemotable - Checks if the provided DAG node is an operand that can be demoted to OptS...
static unsigned getTcgen05MMADisableOutputLane(unsigned IID)
static std::pair< APInt, APInt > getPRMTDemandedBits(const APInt &SelectorVal, const APInt &DemandedBits)
static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode)
static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode)
static SDValue PerformREMCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG)
static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG)
static SDValue PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI)
static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known, const SelectionDAG &DAG, unsigned Depth)
static SDValue combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
Fold unpacking movs into a load by increasing the number of return values.
static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op, SelectionDAG &DAG)
static std::optional< std::pair< SDValue, SDValue > > lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset=false)
static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG)
static std::optional< std::pair< SDValue, SDValue > > replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI)
replaceLoadVector - Convert vector loads into multi-output scalar loads.
static SDValue expandFSH64(SDValue A, SDValue B, SDValue ShiftAmount, SDLoc DL, unsigned Opcode, SelectionDAG &DAG)
static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS, unsigned OptSize, bool &IsSigned)
AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can be demoted to OptSize bits...
static std::pair< MemSDNode *, uint32_t > convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI)
static SDValue TryMULWIDECombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply of M/2 bits that produces...
static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG)
static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT, SDLoc DL, TargetLowering::DAGCombinerInfo &DCI)
static SDValue buildTreeReduction(const SmallVector< SDValue > &Elements, EVT EltTy, ArrayRef< std::pair< unsigned, unsigned > > Ops, const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG)
Reduces the elements using the scalar operations provided.
static SDValue combineProxyReg(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
static SmallVector< unsigned, 16 > VectorizePTXValueVTs(const SmallVectorImpl< EVT > &ValueVTs, const SmallVectorImpl< T > &Offsets, Align ParamAlignment, bool IsVAArg=false)
static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL, SelectionDAG &DAG, unsigned Mode=NVPTX::PTXPrmtMode::NONE)
static SDValue matchMADConstOnePattern(SDValue Add)
static SDValue correctParamType(SDValue V, EVT ExpectedVT, ISD::ArgFlagsTy Flags, SelectionDAG &DAG, SDLoc dl)
static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags)
static cl::opt< bool > UsePrecSqrtF32("nvptx-prec-sqrtf32", cl::Hidden, cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."), cl::init(true))
static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known)
static APInt getPRMTSelector(const APInt &Selector, unsigned Mode)
static EVT promoteScalarIntegerPTX(const EVT VT)
PromoteScalarIntegerPTX Used to make sure the arguments/returns are suitable for passing and promote ...
static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT, const APInt &DemandedBits, SelectionDAG &DAG, const TargetLowering &TLI, unsigned Depth)
static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG)
static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG)
static SDValue sinkProxyReg(SDValue R, SDValue Chain, TargetLowering::DAGCombinerInfo &DCI)
static SDValue lowerFSH(SDValue Op, SelectionDAG &DAG)
static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG)
static SDValue PerformSETCCCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, unsigned int SmVersion)
static std::optional< std::pair< unsigned int, MVT > > getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI, unsigned AddressSpace)
static cl::opt< bool > ForceMinByValParamAlign("nvptx-force-min-byval-param-align", cl::Hidden, cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval" " params of device functions."), cl::init(false))
static cl::opt< bool > UseApproxLog2F32("nvptx-approx-log2f32", cl::desc("NVPTX Specific: whether to use lg2.approx for log2"), cl::init(false))
Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it does NOT use lg2....
static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG)
static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const NVPTXSubtarget &STI)
static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const NVPTXSubtarget &STI)
static SDValue PerformSHLCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
MachineInstr unsigned OpIdx
uint64_t High
#define P(N)
const SmallVectorImpl< MachineOperand > & Cond
static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))
This file contains some templates that are useful if you are working with the STL at all.
This file defines the SmallVector class.
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
This file describes how to lower LLVM code to machine code.
Value * RHS
Value * LHS
BinaryOperator * Mul
static const fltSemantics & IEEEsingle()
Definition APFloat.h:296
static APFloat getInf(const fltSemantics &Sem, bool Negative=false)
Factory for Positive and Negative Infinity.
Definition APFloat.h:1080
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt getLoBits(unsigned numBits) const
Compute an APInt containing numBits lowbits from this APInt.
Definition APInt.cpp:644
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1541
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1392
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
void setBit(unsigned BitPosition)
Set the given bit to 1 whose position is given as "bitPosition".
Definition APInt.h:1331
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1489
bool isSignedIntN(unsigned N) const
Check if this APInt has an N-bits signed integer value.
Definition APInt.h:436
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1131
LLVM_ABI APInt extractBits(unsigned numBits, unsigned bitPosition) const
Return an APInt with the extracted bits [bitPosition,bitPosition+numBits).
Definition APInt.cpp:482
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
Definition APInt.h:1238
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
const T & back() const
back - Get the last element.
Definition ArrayRef.h:151
ArrayRef< T > drop_back(size_t N=1) const
Drop the last N elements of the array.
Definition ArrayRef.h:201
bool empty() const
empty - Check if the array is empty.
Definition ArrayRef.h:137
ArrayRef< T > slice(size_t N, size_t M) const
slice(n, m) - Chop off the first N elements of the array, and keep M elements in the array.
Definition ArrayRef.h:186
an instruction that atomically reads a memory location, combines it with another value,...
@ Add
*p = old + v
@ FAdd
*p = old + v
@ Min
*p = old <signed v ? old : v
@ Sub
*p = old - v
@ And
*p = old & v
@ Xor
*p = old ^ v
@ UIncWrap
Increment one up to a maximum value.
@ Max
*p = old >signed v ? old : v
@ UMin
*p = old <unsigned v ? old : v
@ UMax
*p = old >unsigned v ? old : v
@ UDecWrap
Decrement one until a minimum value or zero.
bool isFloatingPointOperation() const
BinOp getOperation() const
This is an SDNode representing atomic operations.
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
FunctionType * getFunctionType() const
const APInt & getAPIntValue() const
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:63
LLVM_ABI TypeSize getTypeAllocSize(Type *Ty) const
Returns the offset in bytes between successive objects of the specified type, including alignment pad...
LLVM_ABI Align getPrefTypeAlign(Type *Ty) const
Returns the preferred stack/global alignment for the specified type.
Diagnostic information for unsupported feature in backend.
void addFnAttr(Attribute::AttrKind Kind)
Add function attributes to this function.
Definition Function.cpp:640
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
LLVM_ABI void diagnose(const DiagnosticInfo &DI)
Report a message to the currently installed diagnostic handler.
This class is used to represent ISD::LOAD nodes.
MCSection * getDataSection() const
static constexpr unsigned NoRegister
Definition MCRegister.h:60
Instances of this class represent a uniqued identifier for a section in the current translation unit.
Definition MCSection.h:517
StringRef getName() const
getName - Get the symbol name.
Definition MCSymbol.h:188
Machine Value Type.
static auto integer_fixedlen_vector_valuetypes()
SimpleValueType SimpleTy
unsigned getVectorNumElements() const
bool isVector() const
Return true if this is a vector value type.
bool isScalableVector() const
Return true if this is a vector value type where the runtime length is machine dependent.
static auto integer_valuetypes()
TypeSize getSizeInBits() const
Returns the size of the specified MVT in bits.
static auto fixedlen_vector_valuetypes()
TypeSize getStoreSize() const
Return the number of bytes overwritten by a store of the specified value type.
static MVT getVectorVT(MVT VT, unsigned NumElements)
MVT getVectorElementType() const
static MVT getIntegerVT(unsigned BitWidth)
static auto fp_valuetypes()
MVT getScalarType() const
If this is a vector, return the element type, otherwise return this.
static auto fp_fixedlen_vector_valuetypes()
DenormalMode getDenormalMode(const fltSemantics &FPType) const
Returns the denormal handling type for the default rounding mode of the function.
Function & getFunction()
Return the LLVM function that this machine code represents.
const MachineJumpTableInfo * getJumpTableInfo() const
getJumpTableInfo - Return the jump table info object for the current function.
const TargetMachine & getTarget() const
getTarget - Return the target machine this machine code is compiled with
@ EK_Inline
EK_Inline - Jump table entries are emitted inline at their point of use.
const std::vector< MachineJumpTableEntry > & getJumpTables() const
@ MODereferenceable
The memory access is dereferenceable (i.e., doesn't trap).
@ MOLoad
The memory access reads data.
@ MOInvariant
The memory access always returns the same value (or traps).
@ MOStore
The memory access writes data.
This SDNode is used for target intrinsics that touch memory and need an associated MachineMemOperand.
This is an abstract virtual class for memory operations.
Align getAlign() const
MachineMemOperand * getMemOperand() const
Return a MachineMemOperand object describing the memory reference performed by operation.
EVT getMemoryVT() const
Return the type of the in-memory value.
static unsigned getFromTypeWidthForLoad(const MemSDNode *Mem)
bool hasUsedBytesMaskPragma() const
bool hasAtomSwap128() const
bool hasF32x2Instructions() const
bool has256BitVectorLoadStore(unsigned AS) const
AtomicOrdering atomicOperationOrderAfterFenceSplit(const Instruction *I) const override
ConstraintType getConstraintType(StringRef Constraint) const override
getConstraintType - Given a constraint letter, return the type of constraint it is for this target.
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override
This callback is invoked for operations that are unsupported by the target, which are registered to u...
const NVPTXTargetMachine * nvTM
bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallBase &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
bool SimplifyDemandedBitsForTargetNode(SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth=0) const override
Attempt to simplify any target nodes based on the demanded bits/elts, returning true on success.
NVPTXTargetLowering(const NVPTXTargetMachine &TM, const NVPTXSubtarget &STI)
std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &, const SmallVectorImpl< ISD::OutputArg > &, std::optional< unsigned > FirstVAArg, const CallBase &CB, unsigned UniqueCallSite) const
unsigned getPreferredFPToIntOpcode(unsigned Op, EVT FromVT, EVT ToVT) const override
bool useF32FTZ(const MachineFunction &MF) const
SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const
Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const
SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps, bool &UseOneConst, bool Reciprocal) const override
Hooks for building estimates in place of slower divisions and square roots.
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl< ISD::OutputArg > &Outs, const SmallVectorImpl< SDValue > &OutVals, const SDLoc &dl, SelectionDAG &DAG) const override
This hook must be implemented to lower outgoing return values, described by the Outs array,...
SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl< ISD::InputArg > &Ins, const SDLoc &dl, SelectionDAG &DAG, SmallVectorImpl< SDValue > &InVals) const override
This hook must be implemented to lower the incoming (formal) arguments, described by the Ins array,...
void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint, std::vector< SDValue > &Ops, SelectionDAG &DAG) const override
Lower the specified operand into the Ops vector.
SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const
Instruction * emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const override
std::string getParamName(const Function *F, int Idx) const
TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const override
Return the preferred vector type legalization action.
NVPTX::DivPrecisionLevel getDivF32Level(const MachineFunction &MF, const SDNode &N) const
bool shouldInsertFencesForAtomic(const Instruction *) const override
Whether AtomicExpandPass should automatically insert fences and reduce ordering for this atomic.
Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy, const DataLayout &DL) const
getFunctionParamOptimizedAlign - since function arguments are passed via .param space,...
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const
EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx, EVT VT) const override
Return the ValueType of the result of SETCC operations.
std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override
Given a physical register constraint (e.g.
bool isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty, unsigned AS, Instruction *I=nullptr) const override
isLegalAddressingMode - Return true if the addressing mode represented by AM is legal for this target...
Instruction * emitLeadingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const override
Inserts in the IR a target-specific intrinsic specifying a fence.
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const override
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
Align getFunctionByValParamAlign(const Function *F, Type *ArgTy, Align InitialAlign, const DataLayout &DL) const
Helper for computing alignment of a device function byval parameter.
bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const
bool usePrecSqrtF32(const SDNode *N=nullptr) const
unsigned getJumpTableEncoding() const override
Return the entry encoding for a jump table in the current function.
SDValue LowerCall(CallLoweringInfo &CLI, SmallVectorImpl< SDValue > &InVals) const override
This hook must be implemented to lower calls into the specified DAG.
void computeKnownBitsForTargetNode(const SDValue Op, KnownBits &Known, const APInt &DemandedElts, const SelectionDAG &DAG, unsigned Depth=0) const override
Determine which of the bits specified in Mask are known to be either zero or one and return them in t...
MCSection * SelectSectionForGlobal(const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const override
static LLVM_ABI PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Wrapper class for IR location info (IR ordering and DebugLoc) to be passed into SDNode creation funct...
const DebugLoc & getDebugLoc() const
Represents one node in the SelectionDAG.
ArrayRef< SDUse > ops() const
const APInt & getAsAPIntVal() const
Helper method returns the APInt value of a ConstantSDNode.
unsigned getOpcode() const
Return the SelectionDAG opcode value for this node.
bool hasOneUse() const
Return true if there is exactly one use of this node.
unsigned getIROrder() const
Return the node ordering.
SDNodeFlags getFlags() const
uint64_t getAsZExtVal() const
Helper method returns the zero-extended integer value of a ConstantSDNode.
unsigned getNumValues() const
Return the number of values defined/returned by this operator.
SDVTList getVTList() const
const SDValue & getOperand(unsigned Num) const
bool isUndef() const
Returns true if the node type is UNDEF or POISON.
iterator_range< user_iterator > users()
void setFlags(SDNodeFlags NewFlags)
Represents a use of a SDNode.
Unlike LLVM values, Selection DAG nodes may return multiple values as the result of a computation.
SDNode * getNode() const
get the SDNode which holds the desired result
bool hasOneUse() const
Return true if there is exactly one node using value ResNo of Node.
SDValue getValue(unsigned R) const
EVT getValueType() const
Return the ValueType of the referenced return value.
const SDValue & getOperand(unsigned i) const
uint64_t getScalarValueSizeInBits() const
uint64_t getConstantOperandVal(unsigned i) const
unsigned getOpcode() const
SectionKind - This is a simple POD value that classifies the properties of a section.
Definition SectionKind.h:22
This is used to represent a portion of an LLVM function in a low-level Data Dependence DAG representa...
LLVM_ABI SDValue getExtLoad(ISD::LoadExtType ExtType, const SDLoc &dl, EVT VT, SDValue Chain, SDValue Ptr, MachinePointerInfo PtrInfo, EVT MemVT, MaybeAlign Alignment=MaybeAlign(), MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
const SDValue & getRoot() const
Return the root tag of the SelectionDAG.
LLVM_ABI SDValue getAddrSpaceCast(const SDLoc &dl, EVT VT, SDValue Ptr, unsigned SrcAS, unsigned DestAS)
Return an AddrSpaceCastSDNode.
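A minimal sketch of how getAddrSpaceCast is typically used during lowering; DAG, DL, and Ptr are assumed to be in scope, and the address-space numbers shown are stated as an assumption rather than taken from this file.
  // Minimal sketch, assuming DAG, DL and Ptr are in scope. The address-space
  // numbers (0 = generic, 1 = global on NVPTX) are an assumption here.
  SDValue GlobalPtr = DAG.getAddrSpaceCast(DL, Ptr.getValueType(), Ptr,
                                           /*SrcAS=*/0, /*DestAS=*/1);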
const TargetSubtargetInfo & getSubtarget() const
LLVM_ABI SDValue getMergeValues(ArrayRef< SDValue > Ops, const SDLoc &dl)
Create a MERGE_VALUES node from the given operands.
LLVM_ABI SDVTList getVTList(EVT VT)
Return an SDVTList that represents the list of values specified.
LLVM_ABI void ExtractVectorElements(SDValue Op, SmallVectorImpl< SDValue > &Args, unsigned Start=0, unsigned Count=0, EVT EltVT=EVT())
Append the extracted elements from Start to Count out of the vector Op in Args.
LLVM_ABI SDValue getFreeze(SDValue V)
Return a freeze using the SDLoc of the value operand.
SDValue getSetCC(const SDLoc &DL, EVT VT, SDValue LHS, SDValue RHS, ISD::CondCode Cond, SDValue Chain=SDValue(), bool IsSignaling=false)
Helper function to make it easier to build SetCC's if you just have an ISD::CondCode instead of an SD...
LLVM_ABI SDValue getSymbolFunctionGlobalAddress(SDValue Op, Function **TargetFunction=nullptr)
Return a GlobalAddress of the function from the current module with name matching the given ExternalS...
LLVM_ABI SDValue getConstantFP(double Val, const SDLoc &DL, EVT VT, bool isTarget=false)
Create a ConstantFPSDNode wrapping a constant value.
LLVM_ABI SDValue getRegister(Register Reg, EVT VT)
LLVM_ABI SDValue getLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, MachinePointerInfo PtrInfo, MaybeAlign Alignment=MaybeAlign(), MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes(), const MDNode *Ranges=nullptr)
Loads are not normal binary operators: their result type is not determined by their operands,...
LLVM_ABI SDValue getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, SDVTList VTList, ArrayRef< SDValue > Ops, EVT MemVT, MachinePointerInfo PtrInfo, Align Alignment, MachineMemOperand::Flags Flags=MachineMemOperand::MOLoad|MachineMemOperand::MOStore, LocationSize Size=LocationSize::precise(0), const AAMDNodes &AAInfo=AAMDNodes())
Creates a MemIntrinsicNode that may produce a result and takes a list of operands.
LLVM_ABI Align getEVTAlign(EVT MemoryVT) const
Compute the default alignment value for the given type.
LLVM_ABI SDValue getNOT(const SDLoc &DL, SDValue Val, EVT VT)
Create a bitwise NOT operation as (XOR Val, -1).
const TargetLowering & getTargetLoweringInfo() const
LLVM_ABI SDNode * MorphNodeTo(SDNode *N, unsigned Opc, SDVTList VTs, ArrayRef< SDValue > Ops)
This mutates the specified node to have the specified return type, opcode, and operands.
SDValue getUNDEF(EVT VT)
Return an UNDEF node. UNDEF does not have a useful SDLoc.
SDValue getCALLSEQ_END(SDValue Chain, SDValue Op1, SDValue Op2, SDValue InGlue, const SDLoc &DL)
Return a new CALLSEQ_END node, which always must have a glue result (to ensure it's not CSE'd).
SDValue getBuildVector(EVT VT, const SDLoc &DL, ArrayRef< SDValue > Ops)
Return an ISD::BUILD_VECTOR node.
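A minimal sketch of how ExtractVectorElements and getBuildVector pair up when scalarizing and rebuilding a vector value; Vec, DAG, and DL are assumed names, not code from this file.
  SmallVector<SDValue, 8> Elts;
  DAG.ExtractVectorElements(Vec, Elts);                // scalarize Vec into Elts
  SDValue Rebuilt = DAG.getBuildVector(Vec.getValueType(), DL, Elts);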
LLVM_ABI SDValue getBitcast(EVT VT, SDValue V)
Return a bitcast using the SDLoc of the value operand, and casting to the provided type.
SDValue getSelect(const SDLoc &DL, EVT VT, SDValue Cond, SDValue LHS, SDValue RHS, SDNodeFlags Flags=SDNodeFlags())
Helper function to make it easier to build Select's if you just have operands and don't want to check...
const DataLayout & getDataLayout() const
LLVM_ABI SDValue getTokenFactor(const SDLoc &DL, SmallVectorImpl< SDValue > &Vals)
Creates a new TokenFactor containing Vals.
LLVM_ABI SDValue getConstant(uint64_t Val, const SDLoc &DL, EVT VT, bool isTarget=false, bool isOpaque=false)
Create a ConstantSDNode wrapping a constant value.
LLVM_ABI SDValue getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, MachinePointerInfo PtrInfo, EVT SVT, Align Alignment, MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
LLVM_ABI SDValue getStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, MachinePointerInfo PtrInfo, Align Alignment, MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
Helper function to build ISD::STORE nodes.
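A minimal sketch of the getLoad/getStore helpers, assuming Chain, Ptr, DL, and DAG are in scope; the MachinePointerInfo and alignment are placeholders rather than values from this file.
  SDValue Ld = DAG.getLoad(MVT::i32, DL, Chain, Ptr, MachinePointerInfo());
  SDValue St = DAG.getStore(Ld.getValue(1), DL, Ld, Ptr, MachinePointerInfo(),
                            Align(4));   // chain through the load's output chain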
LLVM_ABI SDValue getSignedConstant(int64_t Val, const SDLoc &DL, EVT VT, bool isTarget=false, bool isOpaque=false)
SDValue getCALLSEQ_START(SDValue Chain, uint64_t InSize, uint64_t OutSize, const SDLoc &DL)
Return a new CALLSEQ_START node, which starts a new call frame in which InSize bytes are set up inside ...

LLVM_ABI SDValue getBasicBlock(MachineBasicBlock *MBB)
SDValue getSelectCC(const SDLoc &DL, SDValue LHS, SDValue RHS, SDValue True, SDValue False, ISD::CondCode Cond, SDNodeFlags Flags=SDNodeFlags())
Helper function to make it easier to build SelectCC's if you just have an ISD::CondCode instead of an...
LLVM_ABI SDValue getExternalSymbol(const char *Sym, EVT VT)
LLVM_ABI SDValue getAnyExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of integer type, to the integer type VT, by either any-extending or truncat...
LLVM_ABI SDValue getIntPtrConstant(uint64_t Val, const SDLoc &DL, bool isTarget=false)
LLVM_ABI SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, ArrayRef< SDUse > Ops)
Gets or creates the specified node.
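A minimal sketch of composing these node builders inside a custom lowering hook; Op and DAG are assumed names, and using MVT::i1 as the compare result type is an assumption (the proper type comes from getSetCCResultType).
  SDLoc DL(Op);
  EVT VT = Op.getValueType();
  SDValue A = Op.getOperand(0), B = Op.getOperand(1);
  SDValue Sum   = DAG.getNode(ISD::ADD, DL, VT, A, B);
  SDValue Zero  = DAG.getConstant(0, DL, VT);
  SDValue IsNeg = DAG.getSetCC(DL, MVT::i1, Sum, Zero, ISD::SETLT);
  SDValue Result = DAG.getSelect(DL, VT, IsNeg, Zero, Sum); // clamp negative sums to zero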
LLVM_ABI SDValue getFPExtendOrRound(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of float type, to the float type VT, by either extending or rounding (by tr...
SDValue getTargetConstant(uint64_t Val, const SDLoc &DL, EVT VT, bool isOpaque=false)
LLVM_ABI SDValue getVectorIdxConstant(uint64_t Val, const SDLoc &DL, bool isTarget=false)
MachineFunction & getMachineFunction() const
LLVM_ABI KnownBits computeKnownBits(SDValue Op, unsigned Depth=0) const
Determine which bits of Op are known to be either zero or one and return them in Known.
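A minimal sketch of querying known bits before narrowing an operation; Op and DAG are assumed names.
  KnownBits Known = DAG.computeKnownBits(Op.getOperand(0));
  if (Known.countMinLeadingZeros() >= 32) {
    // The high 32 bits are known zero, so a 32-bit operation would suffice.
  }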
LLVM_ABI SDValue getZExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of integer type, to the integer type VT, by either zero-extending or trunca...
SDValue getObjectPtrOffset(const SDLoc &SL, SDValue Ptr, TypeSize Offset)
Create an add instruction with appropriate flags when used for addressing some offset of an object.
LLVMContext * getContext() const
const SDValue & setRoot(SDValue N)
Set the current root tag of the SelectionDAG.
LLVM_ABI SDValue getTargetExternalSymbol(const char *Sym, EVT VT, unsigned TargetFlags=0)
ArrayRef< int > getMask() const
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
This class is used to represent ISD::STORE nodes.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
constexpr size_t size() const
size - Get the string size.
Definition StringRef.h:146
constexpr const char * data() const
data - Get a pointer to the start of the string (which may not be null terminated).
Definition StringRef.h:140
Align getStackAlign() const
getStackAlignment - This method returns the number of bytes to which the stack pointer must be aligne...
void setBooleanVectorContents(BooleanContent Ty)
Specify how the target extends the result of a vector boolean value from a vector of i1 to a wider ty...
void setOperationAction(unsigned Op, MVT VT, LegalizeAction Action)
Indicate that the specified operation does not work with the specified type and indicate what to do a...
void setMaxDivRemBitWidthSupported(unsigned SizeInBits)
Set the size in bits of the maximum div/rem the backend supports.
EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const
Return the EVT corresponding to this LLVM type.
unsigned MaxStoresPerMemcpyOptSize
Likewise for functions with the OptSize attribute.
const TargetMachine & getTargetMachine() const
virtual unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const
Certain targets require unusual breakdowns of certain types.
virtual MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const
Certain combinations of ABIs, Targets and features require that types are legal for some operations a...
void setOperationPromotedToType(unsigned Opc, MVT OrigVT, MVT DestVT)
Convenience method to set an operation to Promote and specify the type in a single call.
LegalizeTypeAction
This enum indicates whether types are legal for a target, and if not, what action should be used to...
void addBypassSlowDiv(unsigned int SlowBitWidth, unsigned int FastBitWidth)
Tells the code generator which bitwidths to bypass.
virtual unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const
Return the number of registers that this ValueType will eventually require.
void setMaxAtomicSizeInBitsSupported(unsigned SizeInBits)
Set the maximum atomic operation size supported by the backend.
virtual TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const
Return the preferred vector type legalization action.
unsigned MaxStoresPerMemsetOptSize
Likewise for functions with the OptSize attribute.
void setBooleanContents(BooleanContent Ty)
Specify how the target extends the result of integer and floating point boolean values from i1 to a w...
unsigned MaxStoresPerMemmove
Specify maximum number of store instructions per memmove call.
void computeRegisterProperties(const TargetRegisterInfo *TRI)
Once all of the register classes are added, this allows us to compute derived properties we expose.
unsigned MaxStoresPerMemmoveOptSize
Likewise for functions with the OptSize attribute.
void addRegisterClass(MVT VT, const TargetRegisterClass *RC)
Add the specified register class as an available regclass for the specified value type.
bool isTypeLegal(EVT VT) const
Return true if the target has native support for the specified value type.
virtual MVT getPointerTy(const DataLayout &DL, uint32_t AS=0) const
Return the pointer type for the given address space, defaults to the pointer type from the data layou...
bool isOperationLegal(unsigned Op, EVT VT) const
Return true if the specified operation is legal on this target.
unsigned MaxStoresPerMemset
Specify maximum number of store instructions per memset call.
void setTruncStoreAction(MVT ValVT, MVT MemVT, LegalizeAction Action)
Indicate that the specified truncating store does not work with the specified type and indicate what ...
void setMinCmpXchgSizeInBits(unsigned SizeInBits)
Sets the minimum cmpxchg or ll/sc size supported by the backend.
void AddPromotedToType(unsigned Opc, MVT OrigVT, MVT DestVT)
If Opc/OrigVT is specified as being promoted, the promotion code defaults to trying a larger integer/...
AtomicExpansionKind
Enum that specifies what an atomic load/AtomicRMWInst is expanded to, if at all.
void setCondCodeAction(ArrayRef< ISD::CondCode > CCs, MVT VT, LegalizeAction Action)
Indicate that the specified condition code is or isn't supported on the target and indicate what to d...
void setTargetDAGCombine(ArrayRef< ISD::NodeType > NTs)
Targets should invoke this method for each target independent node that they want to provide a custom...
Align getMinStackArgumentAlignment() const
Return the minimum stack alignment of an argument.
void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT, LegalizeAction Action)
Indicate that the specified load with extension does not work with the specified type and indicate wh...
std::vector< ArgListEntry > ArgListTy
virtual Instruction * emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const
virtual Instruction * emitLeadingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const
Inserts in the IR a target-specific intrinsic specifying a fence.
unsigned MaxStoresPerMemcpy
Specify maximum number of store instructions per memcpy call.
void setSchedulingPreference(Sched::Preference Pref)
Specify the target scheduling preference.
MVT getRegisterType(MVT VT) const
Return the type of registers that this ValueType will eventually require.
void setJumpIsExpensive(bool isExpensive=true)
Tells the code generator not to expand logic operations on comparison predicates into separate sequen...
LegalizeAction getOperationAction(unsigned Op, EVT VT) const
Return how this operation should be treated: either it is legal, needs to be promoted to a larger siz...
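A minimal sketch of how a TargetLowering constructor typically drives these configuration hooks; the register class, types, and actions below are illustrative assumptions, not NVPTX's actual configuration.
  addRegisterClass(MVT::i32, &SomeTarget::GPR32RegClass);   // hypothetical register class
  setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
  setOperationAction(ISD::SMUL_LOHI, MVT::i64, Expand);
  setTruncStoreAction(MVT::f32, MVT::f16, Expand);
  setLoadExtAction(ISD::SEXTLOAD, MVT::i32, MVT::i16, Legal);
  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT});
  computeRegisterProperties(TRI);   // TRI: the target's TargetRegisterInfo (assumed in scope)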
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, SelectionDAG &DAG, unsigned Depth=0) const
More limited version of SimplifyDemandedBits that can be used to "lookthrough" ops that don't contrib...
virtual ConstraintType getConstraintType(StringRef Constraint) const
Given a constraint, return the type of constraint it is for this target.
virtual std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const
Given a physical register constraint (e.g. {edx}), return the register number and the register class for the register.
TargetLowering(const TargetLowering &)=delete
SDValue expandRoundInexactToOdd(EVT ResultVT, SDValue Op, const SDLoc &DL, SelectionDAG &DAG) const
Truncate Op to ResultVT.
SDValue expandFP_ROUND(SDNode *Node, SelectionDAG &DAG) const
Expand round(fp) to fp conversion.
virtual void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint, std::vector< SDValue > &Ops, SelectionDAG &DAG) const
Lower the specified operand into the Ops vector.
Primary interface to the complete machine description for the target machine.
CodeGenOptLevel getOptLevel() const
Returns the optimization level: None, Less, Default, or Aggressive.
TargetOptions Options
MCSymbol * getSymbol(const GlobalValue *GV) const
FPOpFusion::FPOpFusionMode AllowFPOpFusion
AllowFPOpFusion - This flag is set by the -fp-contract=xxx option.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
virtual const TargetFrameLowering * getFrameLowering() const
static constexpr TypeSize getFixed(ScalarTy ExactSize)
Definition TypeSize.h:343
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
Definition Type.h:184
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
bool isVoidTy() const
Return true if this is 'void'.
Definition Type.h:139
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
A raw_ostream that writes to an std::string.
CallInst * Call
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
LLVM_ABI APInt pow(const APInt &X, int64_t N)
Compute X^N for N>=0.
Definition APInt.cpp:3155
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
NodeType
ISD::NodeType enum - This enum defines the target-independent operators for a SelectionDAG.
Definition ISDOpcodes.h:41
@ SETCC
SetCC operator - This evaluates to a true value iff the condition is true.
Definition ISDOpcodes.h:807
@ CTLZ_ZERO_UNDEF
Definition ISDOpcodes.h:780
@ POISON
POISON - A poison node.
Definition ISDOpcodes.h:231
@ SMUL_LOHI
SMUL_LOHI/UMUL_LOHI - Multiply two integers of type iN, producing a signed/unsigned value of type i[2...
Definition ISDOpcodes.h:270
@ BSWAP
Byte Swap and Counting operators.
Definition ISDOpcodes.h:771
@ ADDC
Carry-setting nodes for multiple precision addition and subtraction.
Definition ISDOpcodes.h:289
@ ADD
Simple integer binary arithmetic operators.
Definition ISDOpcodes.h:259
@ ANY_EXTEND
ANY_EXTEND - Used for integer types. The high bits are undefined.
Definition ISDOpcodes.h:841
@ FMA
FMA - Perform a * b + c with no intermediate rounding step.
Definition ISDOpcodes.h:511
@ INTRINSIC_VOID
OUTCHAIN = INTRINSIC_VOID(INCHAIN, INTRINSICID, arg1, arg2, ...) This node represents a target intrin...
Definition ISDOpcodes.h:215
@ SINT_TO_FP
[SU]INT_TO_FP - These operators convert integers (whose interpreted sign depends on the first letter)...
Definition ISDOpcodes.h:868
@ CONCAT_VECTORS
CONCAT_VECTORS(VECTOR0, VECTOR1, ...) - Given a number of values of vector type with the same length ...
Definition ISDOpcodes.h:577
@ FADD
Simple binary floating point operators.
Definition ISDOpcodes.h:410
@ ABS
ABS - Determine the unsigned absolute value of a signed integer value of the same bitwidth.
Definition ISDOpcodes.h:744
@ SDIVREM
SDIVREM/UDIVREM - Divide two integers and produce both a quotient and remainder result.
Definition ISDOpcodes.h:275
@ BUILD_PAIR
BUILD_PAIR - This is the opposite of EXTRACT_ELEMENT in some ways.
Definition ISDOpcodes.h:249
@ SIGN_EXTEND
Conversion operators.
Definition ISDOpcodes.h:832
@ SSUBO
Same for subtraction.
Definition ISDOpcodes.h:347
@ SSUBSAT
RESULT = [US]SUBSAT(LHS, RHS) - Perform saturation subtraction on 2 integers with the same bit width ...
Definition ISDOpcodes.h:369
@ SELECT
Select(COND, TRUEVAL, FALSEVAL).
Definition ISDOpcodes.h:784
@ UNDEF
UNDEF - An undefined node.
Definition ISDOpcodes.h:228
@ EXTRACT_ELEMENT
EXTRACT_ELEMENT - This is used to get the lower or upper (determined by a Constant,...
Definition ISDOpcodes.h:242
@ CopyFromReg
CopyFromReg - This node indicates that the input value is a virtual or physical register that is defi...
Definition ISDOpcodes.h:225
@ SADDO
RESULT, BOOL = [SU]ADDO(LHS, RHS) - Overflow-aware nodes for addition.
Definition ISDOpcodes.h:343
@ MULHU
MULHU/MULHS - Multiply high - Multiply two integers of type iN, producing an unsigned/signed value of...
Definition ISDOpcodes.h:701
@ SHL
Shift and rotation operations.
Definition ISDOpcodes.h:762
@ VECTOR_SHUFFLE
VECTOR_SHUFFLE(VEC1, VEC2) - Returns a vector, of the same type as VEC1/VEC2.
Definition ISDOpcodes.h:642
@ EXTRACT_SUBVECTOR
EXTRACT_SUBVECTOR(VECTOR, IDX) - Returns a subvector from VECTOR.
Definition ISDOpcodes.h:607
@ EXTRACT_VECTOR_ELT
EXTRACT_VECTOR_ELT(VECTOR, IDX) - Returns a single element from VECTOR identified by the (potentially...
Definition ISDOpcodes.h:569
@ CopyToReg
CopyToReg - This node has three operands: a chain, a register number to set to this value,...
Definition ISDOpcodes.h:219
@ ZERO_EXTEND
ZERO_EXTEND - Used for integer types, zeroing the new bits.
Definition ISDOpcodes.h:838
@ SELECT_CC
Select with condition operator - This selects between a true value and a false value (ops #2 and #3) ...
Definition ISDOpcodes.h:799
@ SSHLSAT
RESULT = [US]SHLSAT(LHS, RHS) - Perform saturation left shift.
Definition ISDOpcodes.h:379
@ SMULO
Same for multiplication.
Definition ISDOpcodes.h:351
@ SIGN_EXTEND_INREG
SIGN_EXTEND_INREG - This operator atomically performs a SHL/SRA pair to sign extend a small value in ...
Definition ISDOpcodes.h:876
@ SMIN
[US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned integers.
Definition ISDOpcodes.h:724
@ VSELECT
Select with a vector condition (op #0) and two vector operands (ops #1 and #2), returning a vector re...
Definition ISDOpcodes.h:793
@ UADDO_CARRY
Carry-using nodes for multiple precision addition and subtraction.
Definition ISDOpcodes.h:323
@ FRAMEADDR
FRAMEADDR, RETURNADDR - These nodes represent llvm.frameaddress and llvm.returnaddress on the DAG.
Definition ISDOpcodes.h:110
@ STRICT_FP_TO_UINT
Definition ISDOpcodes.h:471
@ STRICT_FP_TO_SINT
STRICT_FP_TO_[US]INT - Convert a floating point value to a signed or unsigned integer.
Definition ISDOpcodes.h:470
@ FP_TO_SINT
FP_TO_[US]INT - Convert a floating point value to a signed or unsigned integer.
Definition ISDOpcodes.h:914
@ AND
Bitwise operators - logical and, logical or, logical xor.
Definition ISDOpcodes.h:736
@ INTRINSIC_WO_CHAIN
RESULT = INTRINSIC_WO_CHAIN(INTRINSICID, arg1, arg2, ...) This node represents a target intrinsic fun...
Definition ISDOpcodes.h:200
@ ADDE
Carry-using nodes for multiple precision addition and subtraction.
Definition ISDOpcodes.h:299
@ FREEZE
FREEZE - FREEZE(VAL) returns an arbitrary value if VAL is UNDEF (or is evaluated to UNDEF),...
Definition ISDOpcodes.h:236
@ INSERT_VECTOR_ELT
INSERT_VECTOR_ELT(VECTOR, VAL, IDX) - Returns VECTOR with the element at IDX replaced with VAL.
Definition ISDOpcodes.h:558
@ FP_ROUND
X = FP_ROUND(Y, TRUNC) - Rounding 'Y' from a larger floating point type down to the precision of the ...
Definition ISDOpcodes.h:947
@ TRUNCATE
TRUNCATE - Completely drop the high bits.
Definition ISDOpcodes.h:844
@ SHL_PARTS
SHL_PARTS/SRA_PARTS/SRL_PARTS - These operators are used for expanded integer shift operations.
Definition ISDOpcodes.h:821
@ FCOPYSIGN
FCOPYSIGN(X, Y) - Return the value of X with the sign of Y.
Definition ISDOpcodes.h:527
@ SADDSAT
RESULT = [US]ADDSAT(LHS, RHS) - Perform saturation addition on 2 integers with the same bit width (W)...
Definition ISDOpcodes.h:360
@ SADDO_CARRY
Carry-using overflow-aware nodes for multiple precision addition and subtraction.
Definition ISDOpcodes.h:333
@ INTRINSIC_W_CHAIN
RESULT,OUTCHAIN = INTRINSIC_W_CHAIN(INCHAIN, INTRINSICID, arg1, ...) This node represents a target in...
Definition ISDOpcodes.h:208
@ BUILD_VECTOR
BUILD_VECTOR(ELT0, ELT1, ELT2, ELT3,...) - Return a fixed-width vector with the specified,...
Definition ISDOpcodes.h:549
LLVM_ABI bool allOperandsUndef(const SDNode *N)
Return true if the node has at least one operand and all operands of the specified node are ISD::UNDE...
This namespace contains an enum with a value for every intrinsic/builtin function known by LLVM.
@ Bitcast
Perform the operation on a different, but equivalently sized type.
@ ATOMIC_CMP_SWAP_B128
These nodes are used to lower atomic instructions with i128 type.
bool isPackedVectorTy(EVT VT)
DivPrecisionLevel
Definition NVPTX.h:257
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
This is an optimization pass for GlobalISel generic memory operations.
@ Low
Lower the current thread's priority such that it does not affect foreground tasks significantly.
Definition Threading.h:280
@ Offset
Definition DWP.cpp:532
detail::zippy< detail::zip_shortest, T, U, Args... > zip(T &&t, U &&u, Args &&...args)
zip iterator for two or more iteratable types.
Definition STLExtras.h:829
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1737
MaybeAlign getAlign(const CallInst &I, unsigned Index)
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1667
void ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty, SmallVectorImpl< EVT > &ValueVTs, SmallVectorImpl< EVT > *MemVTs=nullptr, SmallVectorImpl< TypeSize > *Offsets=nullptr, TypeSize StartingOffset=TypeSize::getZero())
ComputeValueVTs - Given an LLVM IR type, compute a sequence of EVTs that represent all the individual...
Definition Analysis.cpp:119
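A minimal sketch, assuming TLI, DL, and RetTy are in scope: flatten an IR type into its component EVTs and byte offsets, as call and return lowering typically do.
  SmallVector<EVT, 4> ValueVTs;
  SmallVector<TypeSize, 4> Offsets;
  ComputeValueVTs(TLI, DL, RetTy, ValueVTs, /*MemVTs=*/nullptr, &Offsets);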
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition STLExtras.h:2484
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
uint64_t PowerOf2Ceil(uint64_t A)
Returns the power of two which is greater than or equal to the given value.
Definition MathExtras.h:385
bool isReleaseOrStronger(AtomicOrdering AO)
OutputIt transform(R &&Range, OutputIt d_first, UnaryFunction F)
Wrapper function around std::transform to apply a function to a range and store the result elsewhere.
Definition STLExtras.h:1980
auto reverse(ContainerTy &&C)
Definition STLExtras.h:406
unsigned promoteScalarArgumentSize(unsigned size)
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:167
bool shouldPassAsArray(Type *Ty)
CodeGenOptLevel
Code generation optimization level.
Definition CodeGen.h:82
@ Default
-O2, -Os, -Oz
Definition CodeGen.h:85
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
AtomicOrdering
Atomic ordering for LLVM's memory model.
@ Sub
Subtraction of integers.
@ Add
Sum of integers.
uint64_t alignTo(uint64_t Size, Align A)
Returns a multiple of A needed to store Size bytes.
Definition Alignment.h:144
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
bool isAcquireOrStronger(AtomicOrdering AO)
constexpr unsigned BitWidth
bool isKernelFunction(const Function &F)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
Function * getMaybeBitcastedCallee(const CallBase *CB)
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition Alignment.h:201
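Worked examples (with assumed input values) for the alignment helpers listed above.
  uint64_t Padded = alignTo(10, Align(8));         // 16: next multiple of 8
  uint64_t P2     = PowerOf2Ceil(10);              // 16: next power of two
  Align    Common = commonAlignment(Align(16), 4); // Align(4): 16-aligned base plus offset 4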
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
constexpr uint64_t value() const
This is a hole in the type system and should not be abused.
Definition Alignment.h:77
@ PreserveSign
The sign of a flushed-to-zero number is preserved in the sign of 0.
DenormalModeKind Output
Denormal flushing mode for floating point instruction results in the default floating point environme...
Extended Value Type.
Definition ValueTypes.h:35
TypeSize getStoreSize() const
Return the number of bytes overwritten by a store of the specified value type.
Definition ValueTypes.h:395
bool isSimple() const
Test if the given EVT is simple (as opposed to being extended).
Definition ValueTypes.h:137
static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements, bool IsScalable=false)
Returns the EVT that represents a vector NumElements in length, where each element is of type VT.
Definition ValueTypes.h:74
EVT changeTypeToInteger() const
Return the type converted to an equivalently sized integer or vector with integer element type.
Definition ValueTypes.h:121
bool bitsGT(EVT VT) const
Return true if this has more bits than VT.
Definition ValueTypes.h:284
bool bitsLT(EVT VT) const
Return true if this has less bits than VT.
Definition ValueTypes.h:300
ElementCount getVectorElementCount() const
Definition ValueTypes.h:350
bool is32BitVector() const
Return true if this is a 32-bit vector type.
Definition ValueTypes.h:197
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
Definition ValueTypes.h:373
uint64_t getScalarSizeInBits() const
Definition ValueTypes.h:385
MVT getSimpleVT() const
Return the SimpleValueType held in the specified simple EVT.
Definition ValueTypes.h:316
uint64_t getFixedSizeInBits() const
Return the size of the specified fixed width value type in bits.
Definition ValueTypes.h:381
bool isVector() const
Return true if this is a vector value type.
Definition ValueTypes.h:168
EVT getScalarType() const
If this is a vector type, return the element type, otherwise return this.
Definition ValueTypes.h:323
bool bitsEq(EVT VT) const
Return true if this has the same number of bits as VT.
Definition ValueTypes.h:256
LLVM_ABI Type * getTypeForEVT(LLVMContext &Context) const
This method returns an LLVM type corresponding to the specified EVT.
EVT getVectorElementType() const
Given a vector type, return the type of each element.
Definition ValueTypes.h:328
bool isScalarInteger() const
Return true if this is an integer, but not a vector.
Definition ValueTypes.h:157
EVT changeVectorElementType(EVT EltVT) const
Return a VT for a vector type whose attributes match ourselves with the exception of the element type...
Definition ValueTypes.h:102
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
Definition ValueTypes.h:336
bool isInteger() const
Return true if this is an integer or a vector integer type.
Definition ValueTypes.h:152
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
KnownBits concat(const KnownBits &Lo) const
Concatenate the bits from Lo onto the bottom of *this.
Definition KnownBits.h:233
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
void resetAll()
Resets the known state of all bits.
Definition KnownBits.h:74
void insertBits(const KnownBits &SubBits, unsigned BitPosition)
Insert the bits from a smaller known bits starting at bitPosition.
Definition KnownBits.h:219
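A minimal sketch of the KnownBits helpers listed above: building the known bits of a 32-bit value from two 16-bit halves (the constants are assumptions).
  KnownBits Lo = KnownBits::makeConstant(APInt(16, 0x00FF));
  KnownBits Hi(16);                     // nothing known about the high half yet
  KnownBits Wide = Hi.concat(Lo);       // 32 bits: Hi above Lo
  Wide.insertBits(KnownBits::makeConstant(APInt(16, 0)), 16); // high half now known zero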
This class contains a discriminated union of information about pointers in memory operands,...
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition Alignment.h:106
These are IR-level optimization flags that may be propagated to SDNodes.
bool hasAllowContract() const
This represents a list of ValueType's that has been intern'd by a SelectionDAG.
This represents an addressing mode of: BaseGV + BaseOffs + BaseReg + Scale*ScaleReg + ScalableOffset*...
This structure contains all information that is necessary for lowering calls.
SmallVector< ISD::InputArg, 32 > Ins
SmallVector< ISD::OutputArg, 32 > Outs
Type * RetTy
Same as OrigRetTy, or partially legalized for soft float libcalls.
A convenience struct that encapsulates a DAG, and two SDValues for returning information from TargetL...