LLVM 20.0.0git
AArch64TargetTransformInfo.cpp
1//===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "AArch64TargetTransformInfo.h"
10#include "AArch64ExpandImm.h"
14#include "llvm/ADT/DenseMap.h"
22#include "llvm/IR/Intrinsics.h"
23#include "llvm/IR/IntrinsicsAArch64.h"
25#include "llvm/Support/Debug.h"
28#include <algorithm>
29#include <optional>
30using namespace llvm;
31using namespace llvm::PatternMatch;
32
33#define DEBUG_TYPE "aarch64tti"
34
35static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
36 cl::init(true), cl::Hidden);
37
38static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
39 cl::Hidden);
40
41static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
42 cl::init(10), cl::Hidden);
43
44static cl::opt<unsigned> SVETailFoldInsnThreshold("sve-tail-folding-insn-threshold",
45 cl::init(15), cl::Hidden);
46
47static cl::opt<unsigned>
48 NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
49 cl::Hidden);
50
52 "call-penalty-sm-change", cl::init(5), cl::Hidden,
54 "Penalty of calling a function that requires a change to PSTATE.SM"));
55
57 "inline-call-penalty-sm-change", cl::init(10), cl::Hidden,
58 cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM"));
59
60static cl::opt<bool> EnableOrLikeSelectOpt("enable-aarch64-or-like-select",
61 cl::init(true), cl::Hidden);
62
63static cl::opt<bool> EnableLSRCostOpt("enable-aarch64-lsr-cost-opt",
64 cl::init(true), cl::Hidden);
65
66// A complete guess as to a reasonable cost.
67static cl::opt<unsigned>
68 BaseHistCntCost("aarch64-base-histcnt-cost", cl::init(8), cl::Hidden,
69 cl::desc("The cost of a histcnt instruction"));
70
72 "dmb-lookahead-threshold", cl::init(10), cl::Hidden,
73 cl::desc("The number of instructions to search for a redundant dmb"));
74
75namespace {
76class TailFoldingOption {
77 // These bitfields will only ever be set to something non-zero in operator=,
78 // when setting the -sve-tail-folding option. This option should always be of
79 // the form (default|simple|all|disable)[+(Flag1|Flag2|etc)], where here
80 // InitialBits is one of (disabled|all|simple). EnableBits represents
81 // additional flags we're enabling, and DisableBits for those flags we're
82 // disabling. The default flag is tracked in the variable NeedsDefault, since
83 // at the time of setting the option we may not know what the default value
84 // for the CPU is.
85 TailFoldingOpts InitialBits = TailFoldingOpts::Disabled;
86 TailFoldingOpts EnableBits = TailFoldingOpts::Disabled;
87 TailFoldingOpts DisableBits = TailFoldingOpts::Disabled;
88
89 // This value needs to be initialised to true in case the user does not
90 // explicitly set the -sve-tail-folding option.
91 bool NeedsDefault = true;
92
93 void setInitialBits(TailFoldingOpts Bits) { InitialBits = Bits; }
94
95 void setNeedsDefault(bool V) { NeedsDefault = V; }
96
97 void setEnableBit(TailFoldingOpts Bit) {
98 EnableBits |= Bit;
99 DisableBits &= ~Bit;
100 }
101
102 void setDisableBit(TailFoldingOpts Bit) {
103 EnableBits &= ~Bit;
104 DisableBits |= Bit;
105 }
106
107 TailFoldingOpts getBits(TailFoldingOpts DefaultBits) const {
108 TailFoldingOpts Bits = TailFoldingOpts::Disabled;
109
110 assert((InitialBits == TailFoldingOpts::Disabled || !NeedsDefault) &&
111 "Initial bits should only include one of "
112 "(disabled|all|simple|default)");
113 Bits = NeedsDefault ? DefaultBits : InitialBits;
114 Bits |= EnableBits;
115 Bits &= ~DisableBits;
116
117 return Bits;
118 }
119
120 void reportError(std::string Opt) {
121 errs() << "invalid argument '" << Opt
122 << "' to -sve-tail-folding=; the option should be of the form\n"
123 " (disabled|all|default|simple)[+(reductions|recurrences"
124 "|reverse|noreductions|norecurrences|noreverse)]\n";
125 report_fatal_error("Unrecognised tail-folding option");
126 }
127
128public:
129
130 void operator=(const std::string &Val) {
131 // If the user explicitly sets -sve-tail-folding= then treat as an error.
132 if (Val.empty()) {
133 reportError("");
134 return;
135 }
136
137 // Since the user is explicitly setting the option we don't automatically
138 // need the default unless they require it.
139 setNeedsDefault(false);
140
141 SmallVector<StringRef, 4> TailFoldTypes;
142 StringRef(Val).split(TailFoldTypes, '+', -1, false);
143
144 unsigned StartIdx = 1;
145 if (TailFoldTypes[0] == "disabled")
146 setInitialBits(TailFoldingOpts::Disabled);
147 else if (TailFoldTypes[0] == "all")
148 setInitialBits(TailFoldingOpts::All);
149 else if (TailFoldTypes[0] == "default")
150 setNeedsDefault(true);
151 else if (TailFoldTypes[0] == "simple")
152 setInitialBits(TailFoldingOpts::Simple);
153 else {
154 StartIdx = 0;
155 setInitialBits(TailFoldingOpts::Disabled);
156 }
157
158 for (unsigned I = StartIdx; I < TailFoldTypes.size(); I++) {
159 if (TailFoldTypes[I] == "reductions")
160 setEnableBit(TailFoldingOpts::Reductions);
161 else if (TailFoldTypes[I] == "recurrences")
162 setEnableBit(TailFoldingOpts::Recurrences);
163 else if (TailFoldTypes[I] == "reverse")
164 setEnableBit(TailFoldingOpts::Reverse);
165 else if (TailFoldTypes[I] == "noreductions")
166 setDisableBit(TailFoldingOpts::Reductions);
167 else if (TailFoldTypes[I] == "norecurrences")
168 setDisableBit(TailFoldingOpts::Recurrences);
169 else if (TailFoldTypes[I] == "noreverse")
170 setDisableBit(TailFoldingOpts::Reverse);
171 else
172 reportError(Val);
173 }
174 }
175
176 bool satisfies(TailFoldingOpts DefaultBits, TailFoldingOpts Required) const {
177 return (getBits(DefaultBits) & Required) == Required;
178 }
179};
180} // namespace
181
182TailFoldingOption TailFoldingOptionLoc;
183
185 "sve-tail-folding",
186 cl::desc(
187 "Control the use of vectorisation using tail-folding for SVE where the"
188 " option is specified in the form (Initial)[+(Flag1|Flag2|...)]:"
189 "\ndisabled (Initial) No loop types will vectorize using "
190 "tail-folding"
191 "\ndefault (Initial) Uses the default tail-folding settings for "
192 "the target CPU"
193 "\nall (Initial) All legal loop types will vectorize using "
194 "tail-folding"
195 "\nsimple (Initial) Use tail-folding for simple loops (not "
196 "reductions or recurrences)"
197 "\nreductions Use tail-folding for loops containing reductions"
198 "\nnoreductions Inverse of above"
199 "\nrecurrences Use tail-folding for loops containing fixed order "
200 "recurrences"
201 "\nnorecurrences Inverse of above"
202 "\nreverse Use tail-folding for loops requiring reversed "
203 "predicates"
204 "\nnoreverse Inverse of above"),
206
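// Editorial note (illustrative example, not part of the original source): the
// strings accepted here are parsed by TailFoldingOption::operator= above. For
// instance, -sve-tail-folding=all+noreverse sets InitialBits to
// TailFoldingOpts::All and records Reverse in DisableBits, while
// -sve-tail-folding=default+reductions keeps NeedsDefault set so the CPU's
// default bits are combined with an explicit Reductions enable bit. A leading
// token that is not one of (disabled|all|default|simple) is parsed as a flag
// on top of "disabled".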
207// Experimental option that will only be fully functional when the
208// code-generator is changed to use SVE instead of NEON for all fixed-width
209// operations.
211 "enable-fixedwidth-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
212
213// Experimental option that will only be fully functional when the cost-model
214// and code-generator have been changed to avoid using scalable vector
215// instructions that are not legal in streaming SVE mode.
217 "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
218
219static bool isSMEABIRoutineCall(const CallInst &CI) {
220 const auto *F = CI.getCalledFunction();
221 return F && StringSwitch<bool>(F->getName())
222 .Case("__arm_sme_state", true)
223 .Case("__arm_tpidr2_save", true)
224 .Case("__arm_tpidr2_restore", true)
225 .Case("__arm_za_disable", true)
226 .Default(false);
227}
228
229/// Returns true if the function has explicit operations that can only be
230/// lowered using incompatible instructions for the selected mode. This also
231/// returns true if the function F may use or modify ZA state.
232static bool hasPossibleIncompatibleOps(const Function *F) {
233 for (const BasicBlock &BB : *F) {
234 for (const Instruction &I : BB) {
235 // Be conservative for now and assume that any call to inline asm or to
236 // intrinsics could result in non-streaming ops (e.g. calls to
237 // @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
238 // all native LLVM instructions can be lowered to compatible instructions.
239 if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
240 (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
241 isSMEABIRoutineCall(cast<CallInst>(I))))
242 return true;
243 }
244 }
245 return false;
246}
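// Editorial note (illustrative, not from the original source): under the
// conservative rule above, a callee whose body contains e.g.
//   %r = call <vscale x 16 x i8> @llvm.aarch64.sve.dup.x.nxv16i8(i8 %x)
// or any inline-asm call is treated as possibly incompatible (debug/pseudo
// intrinsics are ignored), while a callee made up only of ordinary IR
// instructions (add, load, store, br, ...) is not.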
247
248bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
249 const Function *Callee) const {
250 SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
251
252 // When inlining, we should consider the body of the function, not the
253 // interface.
254 if (CalleeAttrs.hasStreamingBody()) {
255 CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
256 CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
257 }
258
259 if (CalleeAttrs.isNewZA())
260 return false;
261
262 if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
263 CallerAttrs.requiresSMChange(CalleeAttrs) ||
264 CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
265 CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
266 if (hasPossibleIncompatibleOps(Callee))
267 return false;
268 }
269
270 return BaseT::areInlineCompatible(Caller, Callee);
271}
272
273bool AArch64TTIImpl::areTypesABICompatible(
274 const Function *Caller, const Function *Callee,
275 const ArrayRef<Type *> &Types) const {
276 if (!BaseT::areTypesABICompatible(Caller, Callee, Types))
277 return false;
278
279 // We need to ensure that argument promotion does not attempt to promote
280 // pointers to fixed-length vector types larger than 128 bits like
281 // <8 x float> (and pointers to aggregate types which have such fixed-length
282 // vector type members) into the values of the pointees. Such vector types
283 // are used for SVE VLS but there is no ABI for SVE VLS arguments and the
284 // backend cannot lower such value arguments. The 128-bit fixed-length SVE
285 // types can be safely treated as 128-bit NEON types and they cannot be
286 // distinguished in IR.
287 if (ST->useSVEForFixedLengthVectors() && llvm::any_of(Types, [](Type *Ty) {
288 auto FVTy = dyn_cast<FixedVectorType>(Ty);
289 return FVTy &&
290 FVTy->getScalarSizeInBits() * FVTy->getNumElements() > 128;
291 }))
292 return false;
293
294 return true;
295}
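// Editorial note (illustrative, not part of the original source): with SVE
// vector-length-specific code, a callee argument that argument promotion
// would turn into a by-value <16 x float> (512 bits) fails the check above,
// while a 128-bit type such as <4 x float> stays ABI compatible because it
// can be treated as a NEON value.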
296
297unsigned
298AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
299 unsigned DefaultCallPenalty) const {
300 // This function calculates a penalty for executing Call in F.
301 //
302 // There are two ways this function can be called:
303 // (1) F:
304 // call from F -> G (the call here is Call)
305 //
306 // For (1), Call.getCaller() == F, so it will always return a high cost if
307 // a streaming-mode change is required (thus promoting the need to inline the
308 // function)
309 //
310 // (2) F:
311 // call from F -> G (the call here is not Call)
312 // G:
313 // call from G -> H (the call here is Call)
314 //
315 // For (2), if after inlining the body of G into F the call to H requires a
316 // streaming-mode change, and the call to G from F would also require a
317 // streaming-mode change, then there is benefit to do the streaming-mode
318 // change only once and avoid inlining of G into F.
319 SMEAttrs FAttrs(*F);
320 SMEAttrs CalleeAttrs(Call);
321 if (FAttrs.requiresSMChange(CalleeAttrs)) {
322 if (F == Call.getCaller()) // (1)
323 return CallPenaltyChangeSM * DefaultCallPenalty;
324 if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
325 return InlineCallPenaltyChangeSM * DefaultCallPenalty;
326 }
327
328 return DefaultCallPenalty;
329}
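// Editorial note (worked example, using the default option values above): with
// CallPenaltyChangeSM = 5 and InlineCallPenaltyChangeSM = 10, a call that
// needs a PSTATE.SM change and whose caller is F itself (case (1)) costs
// 5 * DefaultCallPenalty, and a call reached through an intermediate function
// G (case (2), where F -> G also needs an SM change) costs
// 10 * DefaultCallPenalty; otherwise the unmodified DefaultCallPenalty is
// returned.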
330
331bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
332 TargetTransformInfo::RegisterKind K) const {
333 assert(K != TargetTransformInfo::RGK_Scalar);
334 return (K == TargetTransformInfo::RGK_FixedWidthVector &&
335 ST->isNeonAvailable());
336}
337
338/// Calculate the cost of materializing a 64-bit value. This helper
339/// method might only calculate a fraction of a larger immediate. Therefore it
340/// is valid to return a cost of ZERO.
341static InstructionCost getIntImmCost(int64_t Val) {
342 // Check if the immediate can be encoded within an instruction.
343 if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
344 return 0;
345
346 if (Val < 0)
347 Val = ~Val;
348
349 // Calculate how many moves we will need to materialize this constant.
350 SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
351 AArch64_IMM::expandMOVImm(Val, 64, Insn);
352 return Insn.size();
353}
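// Editorial note (illustrative, assuming the usual MOVZ/MOVK expansion): a
// value such as 0x12345678 needs MOVZ plus one MOVK, so the expansion yields
// two entries and the cost is 2, whereas a logical immediate such as
// 0x00FF00FF00FF00FF is encodable directly and costs 0 here.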
354
355/// Calculate the cost of materializing the given constant.
356InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
357 TTI::TargetCostKind CostKind) {
358 assert(Ty->isIntegerTy());
359
360 unsigned BitSize = Ty->getPrimitiveSizeInBits();
361 if (BitSize == 0)
362 return ~0U;
363
364 // Sign-extend all constants to a multiple of 64-bit.
365 APInt ImmVal = Imm;
366 if (BitSize & 0x3f)
367 ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
368
369 // Split the constant into 64-bit chunks and calculate the cost for each
370 // chunk.
371 InstructionCost Cost = 0;
372 for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
373 APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
374 int64_t Val = Tmp.getSExtValue();
375 Cost += getIntImmCost(Val);
376 }
377 // We need at least one instruction to materialize the constant.
378 return std::max<InstructionCost>(1, Cost);
379}
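// Editorial note (illustrative): an i128 constant is sign-extended to 128 bits
// and costed as two 64-bit chunks; if the low chunk needs two moves and the
// high chunk is zero, the total is 2, and the std::max above keeps the result
// at least 1 even when every chunk is free.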
380
381InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
382 const APInt &Imm, Type *Ty,
383 TTI::TargetCostKind CostKind,
384 Instruction *Inst) {
385 assert(Ty->isIntegerTy());
386
387 unsigned BitSize = Ty->getPrimitiveSizeInBits();
388 // There is no cost model for constants with a bit size of 0. Return TCC_Free
389 // here, so that constant hoisting will ignore this constant.
390 if (BitSize == 0)
391 return TTI::TCC_Free;
392
393 unsigned ImmIdx = ~0U;
394 switch (Opcode) {
395 default:
396 return TTI::TCC_Free;
397 case Instruction::GetElementPtr:
398 // Always hoist the base address of a GetElementPtr.
399 if (Idx == 0)
400 return 2 * TTI::TCC_Basic;
401 return TTI::TCC_Free;
402 case Instruction::Store:
403 ImmIdx = 0;
404 break;
405 case Instruction::Add:
406 case Instruction::Sub:
407 case Instruction::Mul:
408 case Instruction::UDiv:
409 case Instruction::SDiv:
410 case Instruction::URem:
411 case Instruction::SRem:
412 case Instruction::And:
413 case Instruction::Or:
414 case Instruction::Xor:
415 case Instruction::ICmp:
416 ImmIdx = 1;
417 break;
418 // Always return TCC_Free for the shift value of a shift instruction.
419 case Instruction::Shl:
420 case Instruction::LShr:
421 case Instruction::AShr:
422 if (Idx == 1)
423 return TTI::TCC_Free;
424 break;
425 case Instruction::Trunc:
426 case Instruction::ZExt:
427 case Instruction::SExt:
428 case Instruction::IntToPtr:
429 case Instruction::PtrToInt:
430 case Instruction::BitCast:
431 case Instruction::PHI:
432 case Instruction::Call:
433 case Instruction::Select:
434 case Instruction::Ret:
435 case Instruction::Load:
436 break;
437 }
438
439 if (Idx == ImmIdx) {
440 int NumConstants = (BitSize + 63) / 64;
441 InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
442 return (Cost <= NumConstants * TTI::TCC_Basic)
443 ? static_cast<int>(TTI::TCC_Free)
444 : Cost;
445 }
446 return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
447}
448
449InstructionCost
450AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
451 const APInt &Imm, Type *Ty,
452 TTI::TargetCostKind CostKind) {
453 assert(Ty->isIntegerTy());
454
455 unsigned BitSize = Ty->getPrimitiveSizeInBits();
456 // There is no cost model for constants with a bit size of 0. Return TCC_Free
457 // here, so that constant hoisting will ignore this constant.
458 if (BitSize == 0)
459 return TTI::TCC_Free;
460
461 // Most (all?) AArch64 intrinsics do not support folding immediates into the
462 // selected instruction, so we compute the materialization cost for the
463 // immediate directly.
464 if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
465 return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
466
467 switch (IID) {
468 default:
469 return TTI::TCC_Free;
470 case Intrinsic::sadd_with_overflow:
471 case Intrinsic::uadd_with_overflow:
472 case Intrinsic::ssub_with_overflow:
473 case Intrinsic::usub_with_overflow:
474 case Intrinsic::smul_with_overflow:
475 case Intrinsic::umul_with_overflow:
476 if (Idx == 1) {
477 int NumConstants = (BitSize + 63) / 64;
478 InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
479 return (Cost <= NumConstants * TTI::TCC_Basic)
480 ? static_cast<int>(TTI::TCC_Free)
481 : Cost;
482 }
483 break;
484 case Intrinsic::experimental_stackmap:
485 if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
486 return TTI::TCC_Free;
487 break;
488 case Intrinsic::experimental_patchpoint_void:
489 case Intrinsic::experimental_patchpoint:
490 if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
491 return TTI::TCC_Free;
492 break;
493 case Intrinsic::experimental_gc_statepoint:
494 if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
495 return TTI::TCC_Free;
496 break;
497 }
498 return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
499}
500
501TargetTransformInfo::PopcntSupportKind
502AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
503 assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
504 if (TyWidth == 32 || TyWidth == 64)
505 return TTI::PSK_FastHardware;
506 // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
507 return TTI::PSK_Software;
508}
509
510static bool isUnpackedVectorVT(EVT VecVT) {
511 return VecVT.isScalableVector() &&
512 VecVT.getSizeInBits().getKnownMinValue() < AArch64::SVEBitsPerBlock;
513}
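// Editorial note (illustrative): with a 128-bit SVE granule, nxv4i32 has a
// known minimum size of 128 bits and is therefore packed, while nxv2f32
// (64 bits per granule) is unpacked and is skipped by some of the cost paths
// below.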
514
515static InstructionCost getHistogramCost(const IntrinsicCostAttributes &ICA) {
516 Type *BucketPtrsTy = ICA.getArgTypes()[0]; // Type of vector of pointers
517 Type *EltTy = ICA.getArgTypes()[1]; // Type of bucket elements
518 unsigned TotalHistCnts = 1;
519
520 unsigned EltSize = EltTy->getScalarSizeInBits();
521 // Only allow (up to 64b) integers or pointers
522 if ((!EltTy->isIntegerTy() && !EltTy->isPointerTy()) || EltSize > 64)
523 return InstructionCost::getInvalid();
524
525 // FIXME: We should be able to generate histcnt for fixed-length vectors
526 // using ptrue with a specific VL.
527 if (VectorType *VTy = dyn_cast<VectorType>(BucketPtrsTy)) {
528 unsigned EC = VTy->getElementCount().getKnownMinValue();
529 if (!isPowerOf2_64(EC) || !VTy->isScalableTy())
530 return InstructionCost::getInvalid();
531
532 // HistCnt only supports 32b and 64b element types
533 unsigned LegalEltSize = EltSize <= 32 ? 32 : 64;
534
535 if (EC == 2 || (LegalEltSize == 32 && EC == 4))
536 return InstructionCost(BaseHistCntCost);
537
538 unsigned NaturalVectorWidth = AArch64::SVEBitsPerBlock / LegalEltSize;
539 TotalHistCnts = EC / NaturalVectorWidth;
540 }
541
542 return InstructionCost(BaseHistCntCost * TotalHistCnts);
543}
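// Editorial note (worked example, assuming the default aarch64-base-histcnt-cost
// of 8): for a <vscale x 8 x ptr> bucket-pointer vector with i32 bucket
// elements, LegalEltSize is 32, NaturalVectorWidth is 128 / 32 = 4, so
// TotalHistCnts is 8 / 4 = 2 and the returned cost is 2 * 8 = 16.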
544
545InstructionCost
546AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
547 TTI::TargetCostKind CostKind) {
548 // The code-generator is currently not able to handle scalable vectors
549 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
550 // it. This change will be removed when code-generation for these types is
551 // sufficiently reliable.
552 auto *RetTy = ICA.getReturnType();
553 if (auto *VTy = dyn_cast<ScalableVectorType>(RetTy))
554 if (VTy->getElementCount() == ElementCount::getScalable(1))
555 return InstructionCost::getInvalid();
556
557 switch (ICA.getID()) {
558 case Intrinsic::experimental_vector_histogram_add:
559 if (!ST->hasSVE2())
560 return InstructionCost::getInvalid();
561 return getHistogramCost(ICA);
562 case Intrinsic::umin:
563 case Intrinsic::umax:
564 case Intrinsic::smin:
565 case Intrinsic::smax: {
566 static const auto ValidMinMaxTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
567 MVT::v8i16, MVT::v2i32, MVT::v4i32,
568 MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32,
569 MVT::nxv2i64};
570 auto LT = getTypeLegalizationCost(RetTy);
571 // v2i64 types get converted to cmp+bif hence the cost of 2
572 if (LT.second == MVT::v2i64)
573 return LT.first * 2;
574 if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
575 return LT.first;
576 break;
577 }
578 case Intrinsic::sadd_sat:
579 case Intrinsic::ssub_sat:
580 case Intrinsic::uadd_sat:
581 case Intrinsic::usub_sat: {
582 static const auto ValidSatTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
583 MVT::v8i16, MVT::v2i32, MVT::v4i32,
584 MVT::v2i64};
585 auto LT = getTypeLegalizationCost(RetTy);
586 // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
587 // need to extend the type, as it uses shr(qadd(shl, shl)).
588 unsigned Instrs =
589 LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
590 if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
591 return LT.first * Instrs;
592 break;
593 }
594 case Intrinsic::abs: {
595 static const auto ValidAbsTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
596 MVT::v8i16, MVT::v2i32, MVT::v4i32,
597 MVT::v2i64};
598 auto LT = getTypeLegalizationCost(RetTy);
599 if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
600 return LT.first;
601 break;
602 }
603 case Intrinsic::bswap: {
604 static const auto ValidAbsTys = {MVT::v4i16, MVT::v8i16, MVT::v2i32,
605 MVT::v4i32, MVT::v2i64};
606 auto LT = getTypeLegalizationCost(RetTy);
607 if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }) &&
608 LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits())
609 return LT.first;
610 break;
611 }
612 case Intrinsic::stepvector: {
613 InstructionCost Cost = 1; // Cost of the `index' instruction
614 auto LT = getTypeLegalizationCost(RetTy);
615 // Legalisation of illegal vectors involves an `index' instruction plus
616 // (LT.first - 1) vector adds.
617 if (LT.first > 1) {
618 Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
619 InstructionCost AddCost =
620 getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
621 Cost += AddCost * (LT.first - 1);
622 }
623 return Cost;
624 }
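// Editorial note (illustrative): if the stepvector result type legalises to
// several registers, e.g. a <vscale x 8 x i64> result split into four
// nxv2i64 parts (LT.first == 4), the cost is one `index' instruction plus
// three vector adds at the cost returned by getArithmeticInstrCost.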
625 case Intrinsic::vector_extract:
626 case Intrinsic::vector_insert: {
627 // If both the vector and subvector types are legal types and the index
628 // is 0, then this should be a no-op or simple operation; return a
629 // relatively low cost.
630
631 // If arguments aren't actually supplied, then we cannot determine the
632 // value of the index. We also want to skip predicate types.
633 if (ICA.getArgs().size() != ICA.getArgTypes().size() ||
634 ICA.getReturnType()->getScalarType()->isIntegerTy(1))
635 break;
636
637 LLVMContext &C = RetTy->getContext();
638 EVT VecVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
639 bool IsExtract = ICA.getID() == Intrinsic::vector_extract;
640 EVT SubVecVT = IsExtract ? getTLI()->getValueType(DL, RetTy)
641 : getTLI()->getValueType(DL, ICA.getArgTypes()[1]);
642 // Skip this if either the vector or subvector types are unpacked
643 // SVE types; they may get lowered to stack stores and loads.
644 if (isUnpackedVectorVT(VecVT) || isUnpackedVectorVT(SubVecVT))
645 break;
646
647 TargetLoweringBase::LegalizeKind SubVecLK =
648 getTLI()->getTypeConversion(C, SubVecVT);
649 TargetLoweringBase::LegalizeKind VecLK =
650 getTLI()->getTypeConversion(C, VecVT);
651 const Value *Idx = IsExtract ? ICA.getArgs()[1] : ICA.getArgs()[2];
652 const ConstantInt *CIdx = cast<ConstantInt>(Idx);
653 if (SubVecLK.first == TargetLoweringBase::TypeLegal &&
654 VecLK.first == TargetLoweringBase::TypeLegal && CIdx->isZero())
655 return TTI::TCC_Free;
656 break;
657 }
658 case Intrinsic::bitreverse: {
659 static const CostTblEntry BitreverseTbl[] = {
660 {Intrinsic::bitreverse, MVT::i32, 1},
661 {Intrinsic::bitreverse, MVT::i64, 1},
662 {Intrinsic::bitreverse, MVT::v8i8, 1},
663 {Intrinsic::bitreverse, MVT::v16i8, 1},
664 {Intrinsic::bitreverse, MVT::v4i16, 2},
665 {Intrinsic::bitreverse, MVT::v8i16, 2},
666 {Intrinsic::bitreverse, MVT::v2i32, 2},
667 {Intrinsic::bitreverse, MVT::v4i32, 2},
668 {Intrinsic::bitreverse, MVT::v1i64, 2},
669 {Intrinsic::bitreverse, MVT::v2i64, 2},
670 };
671 const auto LegalisationCost = getTypeLegalizationCost(RetTy);
672 const auto *Entry =
673 CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
674 if (Entry) {
675 // Cost Model is using the legal type(i32) that i8 and i16 will be
676 // converted to +1 so that we match the actual lowering cost
677 if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
678 TLI->getValueType(DL, RetTy, true) == MVT::i16)
679 return LegalisationCost.first * Entry->Cost + 1;
680
681 return LegalisationCost.first * Entry->Cost;
682 }
683 break;
684 }
685 case Intrinsic::ctpop: {
686 if (!ST->hasNEON()) {
687 // 32-bit or 64-bit ctpop without NEON is 12 instructions.
688 return getTypeLegalizationCost(RetTy).first * 12;
689 }
690 static const CostTblEntry CtpopCostTbl[] = {
691 {ISD::CTPOP, MVT::v2i64, 4},
692 {ISD::CTPOP, MVT::v4i32, 3},
693 {ISD::CTPOP, MVT::v8i16, 2},
694 {ISD::CTPOP, MVT::v16i8, 1},
695 {ISD::CTPOP, MVT::i64, 4},
696 {ISD::CTPOP, MVT::v2i32, 3},
697 {ISD::CTPOP, MVT::v4i16, 2},
698 {ISD::CTPOP, MVT::v8i8, 1},
699 {ISD::CTPOP, MVT::i32, 5},
700 };
701 auto LT = getTypeLegalizationCost(RetTy);
702 MVT MTy = LT.second;
703 if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
704 // Extra cost of +1 when illegal vector types are legalized by promoting
705 // the integer type.
706 int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
707 RetTy->getScalarSizeInBits()
708 ? 1
709 : 0;
710 return LT.first * Entry->Cost + ExtraCost;
711 }
712 break;
713 }
714 case Intrinsic::sadd_with_overflow:
715 case Intrinsic::uadd_with_overflow:
716 case Intrinsic::ssub_with_overflow:
717 case Intrinsic::usub_with_overflow:
718 case Intrinsic::smul_with_overflow:
719 case Intrinsic::umul_with_overflow: {
720 static const CostTblEntry WithOverflowCostTbl[] = {
721 {Intrinsic::sadd_with_overflow, MVT::i8, 3},
722 {Intrinsic::uadd_with_overflow, MVT::i8, 3},
723 {Intrinsic::sadd_with_overflow, MVT::i16, 3},
724 {Intrinsic::uadd_with_overflow, MVT::i16, 3},
725 {Intrinsic::sadd_with_overflow, MVT::i32, 1},
726 {Intrinsic::uadd_with_overflow, MVT::i32, 1},
727 {Intrinsic::sadd_with_overflow, MVT::i64, 1},
728 {Intrinsic::uadd_with_overflow, MVT::i64, 1},
729 {Intrinsic::ssub_with_overflow, MVT::i8, 3},
730 {Intrinsic::usub_with_overflow, MVT::i8, 3},
731 {Intrinsic::ssub_with_overflow, MVT::i16, 3},
732 {Intrinsic::usub_with_overflow, MVT::i16, 3},
733 {Intrinsic::ssub_with_overflow, MVT::i32, 1},
734 {Intrinsic::usub_with_overflow, MVT::i32, 1},
735 {Intrinsic::ssub_with_overflow, MVT::i64, 1},
736 {Intrinsic::usub_with_overflow, MVT::i64, 1},
737 {Intrinsic::smul_with_overflow, MVT::i8, 5},
738 {Intrinsic::umul_with_overflow, MVT::i8, 4},
739 {Intrinsic::smul_with_overflow, MVT::i16, 5},
740 {Intrinsic::umul_with_overflow, MVT::i16, 4},
741 {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
742 {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
743 {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
744 {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
745 };
746 EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
747 if (MTy.isSimple())
748 if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
749 MTy.getSimpleVT()))
750 return Entry->Cost;
751 break;
752 }
753 case Intrinsic::fptosi_sat:
754 case Intrinsic::fptoui_sat: {
755 if (ICA.getArgTypes().empty())
756 break;
757 bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat;
758 auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]);
759 EVT MTy = TLI->getValueType(DL, RetTy);
760 // Check for the legal types, which are where the size of the input and the
761 // output are the same, or we are using cvt f64->i32 or f32->i64.
762 if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
763 LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
764 LT.second == MVT::v2f64)) {
765 if ((LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
766 (LT.second == MVT::f64 && MTy == MVT::i32) ||
767 (LT.second == MVT::f32 && MTy == MVT::i64)))
768 return LT.first;
769 // Extending vector types v2f32->v2i64, fcvtl*2 + fcvt*2
770 if (LT.second.getScalarType() == MVT::f32 && MTy.isFixedLengthVector() &&
771 MTy.getScalarSizeInBits() == 64)
772 return LT.first * (MTy.getVectorNumElements() > 2 ? 4 : 2);
773 }
774 // Similarly for fp16 sizes. Without FullFP16 we generally need to fcvt to
775 // f32.
776 if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
777 return LT.first + getIntrinsicInstrCost(
778 {ICA.getID(),
779 RetTy,
780 {ICA.getArgTypes()[0]->getWithNewType(
781 Type::getFloatTy(RetTy->getContext()))}},
782 CostKind);
783 if ((LT.second == MVT::f16 && MTy == MVT::i32) ||
784 (LT.second == MVT::f16 && MTy == MVT::i64) ||
785 ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
786 (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits())))
787 return LT.first;
788 // Extending vector types v8f16->v8i32, fcvtl*2 + fcvt*2
789 if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
790 MTy.getScalarSizeInBits() == 32)
791 return LT.first * (MTy.getVectorNumElements() > 4 ? 4 : 2);
792 // Extending vector types v8f16->v8i64. These currently scalarize but the
793 // codegen could be better.
794 if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
795 MTy.getScalarSizeInBits() == 64)
796 return MTy.getVectorNumElements() * 3;
797
798 // If we can we use a legal convert followed by a min+max
799 if ((LT.second.getScalarType() == MVT::f32 ||
800 LT.second.getScalarType() == MVT::f64 ||
801 LT.second.getScalarType() == MVT::f16) &&
802 LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
803 Type *LegalTy =
804 Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
805 if (LT.second.isVector())
806 LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount());
807 InstructionCost Cost = 1;
808 IntrinsicCostAttributes Attrs1(IsSigned ? Intrinsic::smin : Intrinsic::umin,
809 LegalTy, {LegalTy, LegalTy});
810 Cost += getIntrinsicInstrCost(Attrs1, CostKind);
811 IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
812 LegalTy, {LegalTy, LegalTy});
813 Cost += getIntrinsicInstrCost(Attrs2, CostKind);
814 return LT.first * Cost +
815 ((LT.second.getScalarType() != MVT::f16 || ST->hasFullFP16()) ? 0
816 : 1);
817 }
818 // Otherwise we need to follow the default expansion that clamps the value
819 // using a float min/max with a fcmp+sel for nan handling when signed.
820 Type *FPTy = ICA.getArgTypes()[0]->getScalarType();
821 RetTy = RetTy->getScalarType();
822 if (LT.second.isVector()) {
823 FPTy = VectorType::get(FPTy, LT.second.getVectorElementCount());
824 RetTy = VectorType::get(RetTy, LT.second.getVectorElementCount());
825 }
826 IntrinsicCostAttributes Attrs1(Intrinsic::minnum, FPTy, {FPTy, FPTy});
827 InstructionCost Cost = getIntrinsicInstrCost(Attrs1, CostKind);
828 IntrinsicCostAttributes Attrs2(Intrinsic::maxnum, FPTy, {FPTy, FPTy});
829 Cost += getIntrinsicInstrCost(Attrs2, CostKind);
830 Cost +=
831 getCastInstrCost(IsSigned ? Instruction::FPToSI : Instruction::FPToUI,
832 RetTy, FPTy, TTI::CastContextHint::None, CostKind);
833 if (IsSigned) {
834 Type *CondTy = RetTy->getWithNewBitWidth(1);
835 Cost += getCmpSelInstrCost(BinaryOperator::FCmp, FPTy, CondTy,
836 CmpInst::FCMP_UNO, CostKind);
837 Cost += getCmpSelInstrCost(BinaryOperator::Select, RetTy, CondTy,
838 CmpInst::FCMP_UNO, CostKind);
839 }
840 return LT.first * Cost;
841 }
842 case Intrinsic::fshl:
843 case Intrinsic::fshr: {
844 if (ICA.getArgs().empty())
845 break;
846
847 // TODO: Add handling for fshl where third argument is not a constant.
848 const TTI::OperandValueInfo OpInfoZ = TTI::getOperandInfo(ICA.getArgs()[2]);
849 if (!OpInfoZ.isConstant())
850 break;
851
852 const auto LegalisationCost = getTypeLegalizationCost(RetTy);
853 if (OpInfoZ.isUniform()) {
854 // FIXME: The costs could be lower if the codegen is better.
855 static const CostTblEntry FshlTbl[] = {
856 {Intrinsic::fshl, MVT::v4i32, 3}, // ushr + shl + orr
857 {Intrinsic::fshl, MVT::v2i64, 3}, {Intrinsic::fshl, MVT::v16i8, 4},
858 {Intrinsic::fshl, MVT::v8i16, 4}, {Intrinsic::fshl, MVT::v2i32, 3},
859 {Intrinsic::fshl, MVT::v8i8, 4}, {Intrinsic::fshl, MVT::v4i16, 4}};
860 // Costs for both fshl & fshr are the same, so just pass Intrinsic::fshl
861 // to avoid having to duplicate the costs.
862 const auto *Entry =
863 CostTableLookup(FshlTbl, Intrinsic::fshl, LegalisationCost.second);
864 if (Entry)
865 return LegalisationCost.first * Entry->Cost;
866 }
867
868 auto TyL = getTypeLegalizationCost(RetTy);
869 if (!RetTy->isIntegerTy())
870 break;
871
872 // Estimate cost manually, as types like i8 and i16 will get promoted to
873 // i32 and CostTableLookup will ignore the extra conversion cost.
874 bool HigherCost = (RetTy->getScalarSizeInBits() != 32 &&
875 RetTy->getScalarSizeInBits() < 64) ||
876 (RetTy->getScalarSizeInBits() % 64 != 0);
877 unsigned ExtraCost = HigherCost ? 1 : 0;
878 if (RetTy->getScalarSizeInBits() == 32 ||
879 RetTy->getScalarSizeInBits() == 64)
880 ExtraCost = 0; // fshl/fshr for i32 and i64 can be lowered to a single
881 // extr instruction.
882 else if (HigherCost)
883 ExtraCost = 1;
884 else
885 break;
886 return TyL.first + ExtraCost;
887 }
888 case Intrinsic::get_active_lane_mask: {
889 auto *RetTy = dyn_cast<FixedVectorType>(ICA.getReturnType());
890 if (RetTy) {
891 EVT RetVT = getTLI()->getValueType(DL, RetTy);
892 EVT OpVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
893 if (!getTLI()->shouldExpandGetActiveLaneMask(RetVT, OpVT) &&
894 !getTLI()->isTypeLegal(RetVT)) {
895 // We don't have enough context at this point to determine if the mask
896 // is going to be kept live after the block, which will force the vXi1
897 // type to be expanded to legal vectors of integers, e.g. v4i1->v4i32.
898 // For now, we just assume the vectorizer created this intrinsic and
899 // the result will be the input for a PHI. In this case the cost will
900 // be extremely high for fixed-width vectors.
901 // NOTE: getScalarizationOverhead returns a cost that's far too
902 // pessimistic for the actual generated codegen. In reality there are
903 // two instructions generated per lane.
904 return RetTy->getNumElements() * 2;
905 }
906 }
907 break;
908 }
909 case Intrinsic::experimental_vector_match: {
910 auto *NeedleTy = cast<FixedVectorType>(ICA.getArgTypes()[1]);
911 EVT SearchVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
912 unsigned SearchSize = NeedleTy->getNumElements();
913 if (!getTLI()->shouldExpandVectorMatch(SearchVT, SearchSize)) {
914 // Base cost for MATCH instructions. At least on the Neoverse V2 and
915 // Neoverse V3, these are cheap operations with the same latency as a
916 // vector ADD. In most cases, however, we also need to do an extra DUP.
917 // For fixed-length vectors we currently need an extra five to six
918 // instructions besides the MATCH.
919 InstructionCost Cost = 4;
920 if (isa<FixedVectorType>(RetTy))
921 Cost += 10;
922 return Cost;
923 }
924 break;
925 }
926 default:
927 break;
928 }
929 return BaseT::getIntrinsicInstrCost(ICA, CostKind);
930}
931
932/// This function removes redundant reinterpret casts in the presence of
933/// control flow.
934static std::optional<Instruction *> processPhiNode(InstCombiner &IC,
935 IntrinsicInst &II) {
936 SmallVector<Instruction *, 32> Worklist;
937 auto RequiredType = II.getType();
938
939 auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
940 assert(PN && "Expected Phi Node!");
941
942 // Don't create a new Phi unless we can remove the old one.
943 if (!PN->hasOneUse())
944 return std::nullopt;
945
946 for (Value *IncValPhi : PN->incoming_values()) {
947 auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
948 if (!Reinterpret ||
949 Reinterpret->getIntrinsicID() !=
950 Intrinsic::aarch64_sve_convert_to_svbool ||
951 RequiredType != Reinterpret->getArgOperand(0)->getType())
952 return std::nullopt;
953 }
954
955 // Create the new Phi
956 IC.Builder.SetInsertPoint(PN);
957 PHINode *NPN = IC.Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
958 Worklist.push_back(PN);
959
960 for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
961 auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
962 NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
963 Worklist.push_back(Reinterpret);
964 }
965
966 // Cleanup Phi Node and reinterprets
967 return IC.replaceInstUsesWith(II, NPN);
968}
969
970// (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _))))
971// => (binop (pred) (from_svbool _) (from_svbool _))
972//
973// The above transformation eliminates a `to_svbool` in the predicate
974// operand of bitwise operation `binop` by narrowing the vector width of
975// the operation. For example, it would convert a `<vscale x 16 x i1>
976// and` into a `<vscale x 4 x i1> and`. This is profitable because
977// to_svbool must zero the new lanes during widening, whereas
978// from_svbool is free.
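// Editorial note (illustrative IR, assumed example): the combine below turns
//   %p16 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %p)
//   %a16 = call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> %p16, <vscale x 16 x i1> %x, <vscale x 16 x i1> %y)
//   %r   = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %a16)
// into an and.z performed directly on <vscale x 4 x i1> values, with the two
// operands narrowed via convert.from.svbool instead.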
979static std::optional<Instruction *>
980tryCombineFromSVBoolBinOp(InstCombiner &IC, IntrinsicInst &II) {
981 auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
982 if (!BinOp)
983 return std::nullopt;
984
985 auto IntrinsicID = BinOp->getIntrinsicID();
986 switch (IntrinsicID) {
987 case Intrinsic::aarch64_sve_and_z:
988 case Intrinsic::aarch64_sve_bic_z:
989 case Intrinsic::aarch64_sve_eor_z:
990 case Intrinsic::aarch64_sve_nand_z:
991 case Intrinsic::aarch64_sve_nor_z:
992 case Intrinsic::aarch64_sve_orn_z:
993 case Intrinsic::aarch64_sve_orr_z:
994 break;
995 default:
996 return std::nullopt;
997 }
998
999 auto BinOpPred = BinOp->getOperand(0);
1000 auto BinOpOp1 = BinOp->getOperand(1);
1001 auto BinOpOp2 = BinOp->getOperand(2);
1002
1003 auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
1004 if (!PredIntr ||
1005 PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
1006 return std::nullopt;
1007
1008 auto PredOp = PredIntr->getOperand(0);
1009 auto PredOpTy = cast<VectorType>(PredOp->getType());
1010 if (PredOpTy != II.getType())
1011 return std::nullopt;
1012
1013 SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
1014 auto NarrowBinOpOp1 = IC.Builder.CreateIntrinsic(
1015 Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
1016 NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
1017 if (BinOpOp1 == BinOpOp2)
1018 NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
1019 else
1020 NarrowedBinOpArgs.push_back(IC.Builder.CreateIntrinsic(
1021 Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));
1022
1023 auto NarrowedBinOp =
1024 IC.Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
1025 return IC.replaceInstUsesWith(II, NarrowedBinOp);
1026}
1027
1028static std::optional<Instruction *>
1029instCombineConvertFromSVBool(InstCombiner &IC, IntrinsicInst &II) {
1030 // If the reinterpret instruction operand is a PHI Node
1031 if (isa<PHINode>(II.getArgOperand(0)))
1032 return processPhiNode(IC, II);
1033
1034 if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
1035 return BinOpCombine;
1036
1037 // Ignore converts to/from svcount_t.
1038 if (isa<TargetExtType>(II.getArgOperand(0)->getType()) ||
1039 isa<TargetExtType>(II.getType()))
1040 return std::nullopt;
1041
1042 SmallVector<Instruction *, 32> CandidatesForRemoval;
1043 Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
1044
1045 const auto *IVTy = cast<VectorType>(II.getType());
1046
1047 // Walk the chain of conversions.
1048 while (Cursor) {
1049 // If the type of the cursor has fewer lanes than the final result, zeroing
1050 // must take place, which breaks the equivalence chain.
1051 const auto *CursorVTy = cast<VectorType>(Cursor->getType());
1052 if (CursorVTy->getElementCount().getKnownMinValue() <
1053 IVTy->getElementCount().getKnownMinValue())
1054 break;
1055
1056 // If the cursor has the same type as I, it is a viable replacement.
1057 if (Cursor->getType() == IVTy)
1058 EarliestReplacement = Cursor;
1059
1060 auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
1061
1062 // If this is not an SVE conversion intrinsic, this is the end of the chain.
1063 if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
1064 Intrinsic::aarch64_sve_convert_to_svbool ||
1065 IntrinsicCursor->getIntrinsicID() ==
1066 Intrinsic::aarch64_sve_convert_from_svbool))
1067 break;
1068
1069 CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
1070 Cursor = IntrinsicCursor->getOperand(0);
1071 }
1072
1073 // If no viable replacement in the conversion chain was found, there is
1074 // nothing to do.
1075 if (!EarliestReplacement)
1076 return std::nullopt;
1077
1078 return IC.replaceInstUsesWith(II, EarliestReplacement);
1079}
1080
1081static bool isAllActivePredicate(Value *Pred) {
1082 // Look through convert.from.svbool(convert.to.svbool(...) chain.
1083 Value *UncastedPred;
1084 if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
1085 m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
1086 m_Value(UncastedPred)))))
1087 // If the predicate has the same or fewer lanes than the uncasted
1088 // predicate then we know the casting has no effect.
1089 if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
1090 cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
1091 Pred = UncastedPred;
1092
1093 return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1094 m_ConstantInt<AArch64SVEPredPattern::all>()));
1095}
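// Editorial note (illustrative): a predicate such as
//   %pg = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
// (pattern 31 == AArch64SVEPredPattern::all) is recognised as all-active, and
// a convert.to/from.svbool round trip is looked through when the final
// predicate has no more lanes than the original one.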
1096
1097// Simplify unary operation where predicate has all inactive lanes by replacing
1098// instruction with its operand
1099static std::optional<Instruction *>
1100instCombineSVENoActiveReplace(InstCombiner &IC, IntrinsicInst &II,
1101 bool hasInactiveVector) {
1102 int PredOperand = hasInactiveVector ? 1 : 0;
1103 int ReplaceOperand = hasInactiveVector ? 0 : 1;
1104 if (match(II.getOperand(PredOperand), m_ZeroInt())) {
1105 IC.replaceInstUsesWith(II, II.getOperand(ReplaceOperand));
1106 return IC.eraseInstFromFunction(II);
1107 }
1108 return std::nullopt;
1109}
1110
1111// Simplify unary operation where predicate has all inactive lanes or
1112// replace unused first operand with undef when all lanes are active
1113static std::optional<Instruction *>
1114instCombineSVEAllOrNoActiveUnary(InstCombiner &IC, IntrinsicInst &II) {
1115 if (isAllActivePredicate(II.getOperand(1)) &&
1116 !isa<llvm::UndefValue>(II.getOperand(0)) &&
1117 !isa<llvm::PoisonValue>(II.getOperand(0))) {
1118 Value *Undef = llvm::UndefValue::get(II.getType());
1119 return IC.replaceOperand(II, 0, Undef);
1120 }
1121 return instCombineSVENoActiveReplace(IC, II, true);
1122}
1123
1124// Erase unary operation where predicate has all inactive lanes
1125static std::optional<Instruction *>
1126instCombineSVENoActiveUnaryErase(InstCombiner &IC, IntrinsicInst &II,
1127 int PredPos) {
1128 if (match(II.getOperand(PredPos), m_ZeroInt())) {
1129 return IC.eraseInstFromFunction(II);
1130 }
1131 return std::nullopt;
1132}
1133
1134// Simplify operation where predicate has all inactive lanes by replacing
1135// instruction with zeroed object
1136static std::optional<Instruction *>
1137instCombineSVENoActiveZero(InstCombiner &IC, IntrinsicInst &II) {
1138 if (match(II.getOperand(0), m_ZeroInt())) {
1139 Constant *Node;
1140 Type *RetTy = II.getType();
1141 if (RetTy->isStructTy()) {
1142 auto StructT = cast<StructType>(RetTy);
1143 auto VecT = StructT->getElementType(0);
1144 SmallVector<Constant *, 4> ZerVec;
1145 for (unsigned i = 0; i < StructT->getNumElements(); i++) {
1146 ZerVec.push_back(VecT->isFPOrFPVectorTy() ? ConstantFP::get(VecT, 0.0)
1147 : ConstantInt::get(VecT, 0));
1148 }
1149 Node = ConstantStruct::get(StructT, ZerVec);
1150 } else
1151 Node = RetTy->isFPOrFPVectorTy() ? ConstantFP::get(RetTy, 0.0)
1152 : ConstantInt::get(II.getType(), 0);
1153
1153
1154 IC.replaceInstUsesWith(II, Node);
1155 return IC.eraseInstFromFunction(II);
1156 }
1157 return std::nullopt;
1158}
1159
1160static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC,
1161 IntrinsicInst &II) {
1162 // svsel(ptrue, x, y) => x
1163 auto *OpPredicate = II.getOperand(0);
1164 if (isAllActivePredicate(OpPredicate))
1165 return IC.replaceInstUsesWith(II, II.getOperand(1));
1166
1167 auto Select =
1168 IC.Builder.CreateSelect(OpPredicate, II.getOperand(1), II.getOperand(2));
1169 return IC.replaceInstUsesWith(II, Select);
1170}
1171
1172static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
1173 IntrinsicInst &II) {
1174 IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
1175 if (!Pg)
1176 return std::nullopt;
1177
1178 if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
1179 return std::nullopt;
1180
1181 const auto PTruePattern =
1182 cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
1183 if (PTruePattern != AArch64SVEPredPattern::vl1)
1184 return std::nullopt;
1185
1186 // The intrinsic is inserting into lane zero so use an insert instead.
1187 auto *IdxTy = Type::getInt64Ty(II.getContext());
1188 auto *Insert = InsertElementInst::Create(
1189 II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
1190 Insert->insertBefore(&II);
1191 Insert->takeName(&II);
1192
1193 return IC.replaceInstUsesWith(II, Insert);
1194}
1195
1196static std::optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
1197 IntrinsicInst &II) {
1198 // Replace DupX with a regular IR splat.
1199 auto *RetTy = cast<ScalableVectorType>(II.getType());
1200 Value *Splat = IC.Builder.CreateVectorSplat(RetTy->getElementCount(),
1201 II.getArgOperand(0));
1202 Splat->takeName(&II);
1203 return IC.replaceInstUsesWith(II, Splat);
1204}
1205
1206static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
1207 IntrinsicInst &II) {
1208 LLVMContext &Ctx = II.getContext();
1209
1210 // Replace by zero constant when all lanes are inactive
1211 if (auto II_NA = instCombineSVENoActiveZero(IC, II))
1212 return II_NA;
1213
1214 // Check that the predicate is all active
1215 auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
1216 if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
1217 return std::nullopt;
1218
1219 const auto PTruePattern =
1220 cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
1221 if (PTruePattern != AArch64SVEPredPattern::all)
1222 return std::nullopt;
1223
1224 // Check that we have a compare of zero..
1225 auto *SplatValue =
1226 dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
1227 if (!SplatValue || !SplatValue->isZero())
1228 return std::nullopt;
1229
1230 // ..against a dupq
1231 auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
1232 if (!DupQLane ||
1233 DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
1234 return std::nullopt;
1235
1236 // Where the dupq is a lane 0 replicate of a vector insert
1237 auto *DupQLaneIdx = dyn_cast<ConstantInt>(DupQLane->getArgOperand(1));
1238 if (!DupQLaneIdx || !DupQLaneIdx->isZero())
1239 return std::nullopt;
1240
1241 auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
1242 if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert)
1243 return std::nullopt;
1244
1245 // Where the vector insert is a fixed constant vector insert into undef at
1246 // index zero
1247 if (!isa<UndefValue>(VecIns->getArgOperand(0)))
1248 return std::nullopt;
1249
1250 if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
1251 return std::nullopt;
1252
1253 auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
1254 if (!ConstVec)
1255 return std::nullopt;
1256
1257 auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
1258 auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
1259 if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
1260 return std::nullopt;
1261
1262 unsigned NumElts = VecTy->getNumElements();
1263 unsigned PredicateBits = 0;
1264
1265 // Expand intrinsic operands to a 16-bit byte level predicate
1266 for (unsigned I = 0; I < NumElts; ++I) {
1267 auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
1268 if (!Arg)
1269 return std::nullopt;
1270 if (!Arg->isZero())
1271 PredicateBits |= 1 << (I * (16 / NumElts));
1272 }
1273
1274 // If all bits are zero bail early with an empty predicate
1275 if (PredicateBits == 0) {
1276 auto *PFalse = Constant::getNullValue(II.getType());
1277 PFalse->takeName(&II);
1278 return IC.replaceInstUsesWith(II, PFalse);
1279 }
1280
1281 // Calculate largest predicate type used (where byte predicate is largest)
1282 unsigned Mask = 8;
1283 for (unsigned I = 0; I < 16; ++I)
1284 if ((PredicateBits & (1 << I)) != 0)
1285 Mask |= (I % 8);
1286
1287 unsigned PredSize = Mask & -Mask;
1288 auto *PredType = ScalableVectorType::get(
1289 Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
1290
1291 // Ensure all relevant bits are set
1292 for (unsigned I = 0; I < 16; I += PredSize)
1293 if ((PredicateBits & (1 << I)) == 0)
1294 return std::nullopt;
1295
1296 auto *PTruePat =
1297 ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
1298 auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
1299 {PredType}, {PTruePat});
1300 auto *ConvertToSVBool = IC.Builder.CreateIntrinsic(
1301 Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
1302 auto *ConvertFromSVBool =
1303 IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
1304 {II.getType()}, {ConvertToSVBool});
1305
1306 ConvertFromSVBool->takeName(&II);
1307 return IC.replaceInstUsesWith(II, ConvertFromSVBool);
1308}
1309
1310static std::optional<Instruction *> instCombineSVELast(InstCombiner &IC,
1311 IntrinsicInst &II) {
1312 Value *Pg = II.getArgOperand(0);
1313 Value *Vec = II.getArgOperand(1);
1314 auto IntrinsicID = II.getIntrinsicID();
1315 bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
1316
1317 // lastX(splat(X)) --> X
1318 if (auto *SplatVal = getSplatValue(Vec))
1319 return IC.replaceInstUsesWith(II, SplatVal);
1320
1321 // If x and/or y is a splat value then:
1322 // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
1323 Value *LHS, *RHS;
1324 if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
1325 if (isSplatValue(LHS) || isSplatValue(RHS)) {
1326 auto *OldBinOp = cast<BinaryOperator>(Vec);
1327 auto OpC = OldBinOp->getOpcode();
1328 auto *NewLHS =
1329 IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
1330 auto *NewRHS =
1331 IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
1332 auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
1333 OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), II.getIterator());
1334 return IC.replaceInstUsesWith(II, NewBinOp);
1335 }
1336 }
1337
1338 auto *C = dyn_cast<Constant>(Pg);
1339 if (IsAfter && C && C->isNullValue()) {
1340 // The intrinsic is extracting lane 0 so use an extract instead.
1341 auto *IdxTy = Type::getInt64Ty(II.getContext());
1342 auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
1343 Extract->insertBefore(&II);
1344 Extract->takeName(&II);
1345 return IC.replaceInstUsesWith(II, Extract);
1346 }
1347
1348 auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
1349 if (!IntrPG)
1350 return std::nullopt;
1351
1352 if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
1353 return std::nullopt;
1354
1355 const auto PTruePattern =
1356 cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
1357
1358 // Can the intrinsic's predicate be converted to a known constant index?
1359 unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
1360 if (!MinNumElts)
1361 return std::nullopt;
1362
1363 unsigned Idx = MinNumElts - 1;
1364 // Increment the index if extracting the element after the last active
1365 // predicate element.
1366 if (IsAfter)
1367 ++Idx;
1368
1369 // Ignore extracts whose index is larger than the known minimum vector
1370 // length. NOTE: This is an artificial constraint where we prefer to
1371 // maintain what the user asked for until an alternative is proven faster.
1372 auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
1373 if (Idx >= PgVTy->getMinNumElements())
1374 return std::nullopt;
1375
1376 // The intrinsic is extracting a fixed lane so use an extract instead.
1377 auto *IdxTy = Type::getInt64Ty(II.getContext());
1378 auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
1379 Extract->insertBefore(&II);
1380 Extract->takeName(&II);
1381 return IC.replaceInstUsesWith(II, Extract);
1382}
1383
1384static std::optional<Instruction *> instCombineSVECondLast(InstCombiner &IC,
1385 IntrinsicInst &II) {
1386 // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar
1387 // integer variant across a variety of micro-architectures. Replace scalar
1388 // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple
1389 // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more
1390 // depending on the micro-architecture, but has been observed as generally
1391 // being faster, particularly when the CLAST[AB] op is a loop-carried
1392 // dependency.
1393 Value *Pg = II.getArgOperand(0);
1394 Value *Fallback = II.getArgOperand(1);
1395 Value *Vec = II.getArgOperand(2);
1396 Type *Ty = II.getType();
1397
1398 if (!Ty->isIntegerTy())
1399 return std::nullopt;
1400
1401 Type *FPTy;
1402 switch (cast<IntegerType>(Ty)->getBitWidth()) {
1403 default:
1404 return std::nullopt;
1405 case 16:
1406 FPTy = IC.Builder.getHalfTy();
1407 break;
1408 case 32:
1409 FPTy = IC.Builder.getFloatTy();
1410 break;
1411 case 64:
1412 FPTy = IC.Builder.getDoubleTy();
1413 break;
1414 }
1415
1416 Value *FPFallBack = IC.Builder.CreateBitCast(Fallback, FPTy);
1417 auto *FPVTy = VectorType::get(
1418 FPTy, cast<VectorType>(Vec->getType())->getElementCount());
1419 Value *FPVec = IC.Builder.CreateBitCast(Vec, FPVTy);
1420 auto *FPII = IC.Builder.CreateIntrinsic(
1421 II.getIntrinsicID(), {FPVec->getType()}, {Pg, FPFallBack, FPVec});
1422 Value *FPIItoInt = IC.Builder.CreateBitCast(FPII, II.getType());
1423 return IC.replaceInstUsesWith(II, FPIItoInt);
1424}
1425
1426static std::optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
1427 IntrinsicInst &II) {
1428 LLVMContext &Ctx = II.getContext();
1429 // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
1430 // can work with RDFFR_PP for ptest elimination.
1431 auto *AllPat =
1432 ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
1433 auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
1434 {II.getType()}, {AllPat});
1435 auto *RDFFR =
1436 IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
1437 RDFFR->takeName(&II);
1438 return IC.replaceInstUsesWith(II, RDFFR);
1439}
1440
1441static std::optional<Instruction *>
1442instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
1443 const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
1444
1445 if (Pattern == AArch64SVEPredPattern::all) {
1446 Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
1447 auto *VScale = IC.Builder.CreateVScale(StepVal);
1448 VScale->takeName(&II);
1449 return IC.replaceInstUsesWith(II, VScale);
1450 }
1451
1452 unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
1453
1454 return MinNumElts && NumElts >= MinNumElts
1455 ? std::optional<Instruction *>(IC.replaceInstUsesWith(
1456 II, ConstantInt::get(II.getType(), MinNumElts)))
1457 : std::nullopt;
1458}
1459
1460static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
1461 IntrinsicInst &II) {
1462 Value *PgVal = II.getArgOperand(0);
1463 Value *OpVal = II.getArgOperand(1);
1464
1465 // PTEST_<FIRST|LAST>(X, X) is equivalent to PTEST_ANY(X, X).
1466 // Later optimizations prefer this form.
1467 if (PgVal == OpVal &&
1468 (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_first ||
1469 II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_last)) {
1470 Value *Ops[] = {PgVal, OpVal};
1471 Type *Tys[] = {PgVal->getType()};
1472
1473 auto *PTest =
1474 IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptest_any, Tys, Ops);
1475 PTest->takeName(&II);
1476
1477 return IC.replaceInstUsesWith(II, PTest);
1478 }
1479
1480 IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(PgVal);
1481 IntrinsicInst *Op = dyn_cast<IntrinsicInst>(OpVal);
1482
1483 if (!Pg || !Op)
1484 return std::nullopt;
1485
1486 Intrinsic::ID OpIID = Op->getIntrinsicID();
1487
1488 if (Pg->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
1489 OpIID == Intrinsic::aarch64_sve_convert_to_svbool &&
1490 Pg->getArgOperand(0)->getType() == Op->getArgOperand(0)->getType()) {
1491 Value *Ops[] = {Pg->getArgOperand(0), Op->getArgOperand(0)};
1492 Type *Tys[] = {Pg->getArgOperand(0)->getType()};
1493
1494 auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
1495
1496 PTest->takeName(&II);
1497 return IC.replaceInstUsesWith(II, PTest);
1498 }
1499
1500 // Transform PTEST_ANY(X=OP(PG,...), X) -> PTEST_ANY(PG, X)).
1501 // Later optimizations may rewrite sequence to use the flag-setting variant
1502 // of instruction X to remove PTEST.
1503 if ((Pg == Op) && (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_any) &&
1504 ((OpIID == Intrinsic::aarch64_sve_brka_z) ||
1505 (OpIID == Intrinsic::aarch64_sve_brkb_z) ||
1506 (OpIID == Intrinsic::aarch64_sve_brkpa_z) ||
1507 (OpIID == Intrinsic::aarch64_sve_brkpb_z) ||
1508 (OpIID == Intrinsic::aarch64_sve_rdffr_z) ||
1509 (OpIID == Intrinsic::aarch64_sve_and_z) ||
1510 (OpIID == Intrinsic::aarch64_sve_bic_z) ||
1511 (OpIID == Intrinsic::aarch64_sve_eor_z) ||
1512 (OpIID == Intrinsic::aarch64_sve_nand_z) ||
1513 (OpIID == Intrinsic::aarch64_sve_nor_z) ||
1514 (OpIID == Intrinsic::aarch64_sve_orn_z) ||
1515 (OpIID == Intrinsic::aarch64_sve_orr_z))) {
1516 Value *Ops[] = {Pg->getArgOperand(0), Pg};
1517 Type *Tys[] = {Pg->getType()};
1518
1519 auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
1520 PTest->takeName(&II);
1521
1522 return IC.replaceInstUsesWith(II, PTest);
1523 }
1524
1525 return std::nullopt;
1526}
1527
1528template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc>
1529static std::optional<Instruction *>
1530instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
1531 bool MergeIntoAddendOp) {
1532 Value *P = II.getOperand(0);
1533 Value *MulOp0, *MulOp1, *AddendOp, *Mul;
1534 if (MergeIntoAddendOp) {
1535 AddendOp = II.getOperand(1);
1536 Mul = II.getOperand(2);
1537 } else {
1538 AddendOp = II.getOperand(2);
1539 Mul = II.getOperand(1);
1540 }
1541
1542 if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
1543 m_Value(MulOp1))))
1544 return std::nullopt;
1545
1546 if (!Mul->hasOneUse())
1547 return std::nullopt;
1548
1549 Instruction *FMFSource = nullptr;
1550 if (II.getType()->isFPOrFPVectorTy()) {
1551 llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
1552 // Stop the combine when the flags on the inputs differ in case dropping
1553 // flags would lead to us missing out on more beneficial optimizations.
1554 if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
1555 return std::nullopt;
1556 if (!FAddFlags.allowContract())
1557 return std::nullopt;
1558 FMFSource = &II;
1559 }
1560
1561 CallInst *Res;
1562 if (MergeIntoAddendOp)
1563 Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1564 {P, AddendOp, MulOp0, MulOp1}, FMFSource);
1565 else
1566 Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1567 {P, MulOp0, MulOp1, AddendOp}, FMFSource);
1568
1569 return IC.replaceInstUsesWith(II, Res);
1570}
1571
1572static std::optional<Instruction *>
1573instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1574 Value *Pred = II.getOperand(0);
1575 Value *PtrOp = II.getOperand(1);
1576 Type *VecTy = II.getType();
1577
1578 // Replace by zero constant when all lanes are inactive
1579 if (auto II_NA = instCombineSVENoActiveZero(IC, II))
1580 return II_NA;
1581
1582 if (isAllActivePredicate(Pred)) {
1583 LoadInst *Load = IC.Builder.CreateLoad(VecTy, PtrOp);
1584 Load->copyMetadata(II);
1585 return IC.replaceInstUsesWith(II, Load);
1586 }
1587
1588 CallInst *MaskedLoad =
1589 IC.Builder.CreateMaskedLoad(VecTy, PtrOp, PtrOp->getPointerAlignment(DL),
1590 Pred, ConstantAggregateZero::get(VecTy));
1591 MaskedLoad->copyMetadata(II);
1592 return IC.replaceInstUsesWith(II, MaskedLoad);
1593}
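// A sketch of the all-active case above (types illustrative):
//   %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
//   %v = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> %p, ptr %addr)
// simplifies to a plain load:
//   %v = load <vscale x 4 x i32>, ptr %addr
// while any other (non-all-inactive) predicate is lowered to llvm.masked.load
// with a zeroinitializer passthru.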
1594
1595static std::optional<Instruction *>
1596instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1597 Value *VecOp = II.getOperand(0);
1598 Value *Pred = II.getOperand(1);
1599 Value *PtrOp = II.getOperand(2);
1600
1601 if (isAllActivePredicate(Pred)) {
1602 StoreInst *Store = IC.Builder.CreateStore(VecOp, PtrOp);
1603 Store->copyMetadata(II);
1604 return IC.eraseInstFromFunction(II);
1605 }
1606
1607 CallInst *MaskedStore = IC.Builder.CreateMaskedStore(
1608 VecOp, PtrOp, PtrOp->getPointerAlignment(DL), Pred);
1609 MaskedStore->copyMetadata(II);
1610 return IC.eraseInstFromFunction(II);
1611}
1612
1613static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
1614 switch (Intrinsic) {
1615 case Intrinsic::aarch64_sve_fmul_u:
1616 return Instruction::BinaryOps::FMul;
1617 case Intrinsic::aarch64_sve_fadd_u:
1618 return Instruction::BinaryOps::FAdd;
1619 case Intrinsic::aarch64_sve_fsub_u:
1620 return Instruction::BinaryOps::FSub;
1621 default:
1622 return Instruction::BinaryOpsEnd;
1623 }
1624}
1625
1626static std::optional<Instruction *>
1627instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
1628 // Bail due to missing support for ISD::STRICT_ scalable vector operations.
1629 if (II.isStrictFP())
1630 return std::nullopt;
1631
1632 auto *OpPredicate = II.getOperand(0);
1633 auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
1634 if (BinOpCode == Instruction::BinaryOpsEnd ||
1635 !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1636 m_ConstantInt<AArch64SVEPredPattern::all>())))
1637 return std::nullopt;
1638 IRBuilderBase::FastMathFlagGuard FMFGuard(IC.Builder);
1639 IC.Builder.setFastMathFlags(II.getFastMathFlags());
1640 auto BinOp =
1641 IC.Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
1642 return IC.replaceInstUsesWith(II, BinOp);
1643}
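// For example (a sketch, types illustrative): with %pg produced by ptrue(all),
//   %r = call fast <vscale x 4 x float> @llvm.aarch64.sve.fadd.u.nxv4f32(<vscale x 4 x i1> %pg, <vscale x 4 x float> %a, <vscale x 4 x float> %b)
// is rewritten to the plain IR instruction, preserving the fast-math flags:
//   %r = fadd fast <vscale x 4 x float> %a, %b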
1644
1645// Canonicalise operations that take an all active predicate (e.g. sve.add ->
1646// sve.add_u).
1647static std::optional<Instruction *> instCombineSVEAllActive(IntrinsicInst &II,
1648 Intrinsic::ID IID) {
1649 auto *OpPredicate = II.getOperand(0);
1650 if (!match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1651 m_ConstantInt<AArch64SVEPredPattern::all>())))
1652 return std::nullopt;
1653
1654 auto *Mod = II.getModule();
1655 auto *NewDecl = Intrinsic::getOrInsertDeclaration(Mod, IID, {II.getType()});
1656 II.setCalledFunction(NewDecl);
1657
1658 return &II;
1659}
1660
1661// Simplify operations where the predicate has all lanes inactive, or try to
1662// replace them with the _u form when all lanes are active.
1663static std::optional<Instruction *>
1664instCombineSVEAllOrNoActive(InstCombiner &IC, IntrinsicInst &II,
1665 Intrinsic::ID IID) {
1666 if (match(II.getOperand(0), m_ZeroInt())) {
1667 // For the merging forms sv[func]_m, i.e. llvm_ir(pred(0), op1, op2), the
1668 // spec says to return op1 when all lanes are inactive.
1669 return IC.replaceInstUsesWith(II, II.getOperand(1));
1670 }
1671 return instCombineSVEAllActive(II, IID);
1672}
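// A sketch of the two cases above, using sve.add as an example:
//  * all-inactive predicate: sve.add(zeroinitializer, %a, %b) is replaced by
//    %a, matching the merging (_m) semantics;
//  * all-active predicate: the call is retargeted to the unpredicated
//    @llvm.aarch64.sve.add.u form so later combines can simplify it further.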
1673
1674static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
1675 IntrinsicInst &II) {
1676 if (auto II_U =
1677 instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_add_u))
1678 return II_U;
1679 if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1680 Intrinsic::aarch64_sve_mla>(
1681 IC, II, true))
1682 return MLA;
1683 if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1684 Intrinsic::aarch64_sve_mad>(
1685 IC, II, false))
1686 return MAD;
1687 return std::nullopt;
1688}
1689
1690static std::optional<Instruction *>
1691instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
1692 if (auto II_U =
1693 instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fadd_u))
1694 return II_U;
1695 if (auto FMLA =
1696 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1697 Intrinsic::aarch64_sve_fmla>(IC, II,
1698 true))
1699 return FMLA;
1700 if (auto FMAD =
1701 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1702 Intrinsic::aarch64_sve_fmad>(IC, II,
1703 false))
1704 return FMAD;
1705 if (auto FMLA =
1706 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1707 Intrinsic::aarch64_sve_fmla>(IC, II,
1708 true))
1709 return FMLA;
1710 return std::nullopt;
1711}
1712
1713static std::optional<Instruction *>
1714instCombineSVEVectorFAddU(InstCombiner &IC, IntrinsicInst &II) {
1715 if (auto FMLA =
1716 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1717 Intrinsic::aarch64_sve_fmla>(IC, II,
1718 true))
1719 return FMLA;
1720 if (auto FMAD =
1721 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1722 Intrinsic::aarch64_sve_fmad>(IC, II,
1723 false))
1724 return FMAD;
1725 if (auto FMLA_U =
1726 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1727 Intrinsic::aarch64_sve_fmla_u>(
1728 IC, II, true))
1729 return FMLA_U;
1730 return instCombineSVEVectorBinOp(IC, II);
1731}
1732
1733static std::optional<Instruction *>
1734instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) {
1735 if (auto II_U =
1736 instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fsub_u))
1737 return II_U;
1738 if (auto FMLS =
1739 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1740 Intrinsic::aarch64_sve_fmls>(IC, II,
1741 true))
1742 return FMLS;
1743 if (auto FMSB =
1744 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1745 Intrinsic::aarch64_sve_fnmsb>(
1746 IC, II, false))
1747 return FMSB;
1748 if (auto FMLS =
1749 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1750 Intrinsic::aarch64_sve_fmls>(IC, II,
1751 true))
1752 return FMLS;
1753 return std::nullopt;
1754}
1755
1756static std::optional<Instruction *>
1757instCombineSVEVectorFSubU(InstCombiner &IC, IntrinsicInst &II) {
1758 if (auto FMLS =
1759 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1760 Intrinsic::aarch64_sve_fmls>(IC, II,
1761 true))
1762 return FMLS;
1763 if (auto FMSB =
1764 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1765 Intrinsic::aarch64_sve_fnmsb>(
1766 IC, II, false))
1767 return FMSB;
1768 if (auto FMLS_U =
1769 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1770 Intrinsic::aarch64_sve_fmls_u>(
1771 IC, II, true))
1772 return FMLS_U;
1773 return instCombineSVEVectorBinOp(IC, II);
1774}
1775
1776static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
1777 IntrinsicInst &II) {
1778 if (auto II_U =
1779 instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sub_u))
1780 return II_U;
1781 if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1782 Intrinsic::aarch64_sve_mls>(
1783 IC, II, true))
1784 return MLS;
1785 return std::nullopt;
1786}
1787
1788static std::optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
1789 IntrinsicInst &II,
1790 Intrinsic::ID IID) {
1791 auto *OpPredicate = II.getOperand(0);
1792 auto *OpMultiplicand = II.getOperand(1);
1793 auto *OpMultiplier = II.getOperand(2);
1794
1795 // Return true if a given instruction is a unit splat value, false otherwise.
1796 auto IsUnitSplat = [](auto *I) {
1797 auto *SplatValue = getSplatValue(I);
1798 if (!SplatValue)
1799 return false;
1800 return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1801 };
1802
1803 // Return true if a given instruction is an aarch64_sve_dup intrinsic call
1804 // with a unit splat value, false otherwise.
1805 auto IsUnitDup = [](auto *I) {
1806 auto *IntrI = dyn_cast<IntrinsicInst>(I);
1807 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
1808 return false;
1809
1810 auto *SplatValue = IntrI->getOperand(2);
1811 return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1812 };
1813
1814 if (IsUnitSplat(OpMultiplier)) {
1815 // [f]mul pg %n, (dupx 1) => %n
1816 OpMultiplicand->takeName(&II);
1817 return IC.replaceInstUsesWith(II, OpMultiplicand);
1818 } else if (IsUnitDup(OpMultiplier)) {
1819 // [f]mul pg %n, (dup pg 1) => %n
1820 auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
1821 auto *DupPg = DupInst->getOperand(1);
1822 // TODO: this is naive. The optimization is still valid if DupPg
1823 // 'encompasses' OpPredicate, not only if they're the same predicate.
1824 if (OpPredicate == DupPg) {
1825 OpMultiplicand->takeName(&II);
1826 return IC.replaceInstUsesWith(II, OpMultiplicand);
1827 }
1828 }
1829
1830 return instCombineSVEVectorBinOp(IC, II);
1831}
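// A sketch of the unit-splat fold above (types illustrative):
//   %one = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 1)
//   %r   = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.u.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %n, <vscale x 4 x i32> %one)
// folds to just %n; the predicated sve.dup form is handled the same way when
// its governing predicate matches %pg.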
1832
1833static std::optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
1834 IntrinsicInst &II) {
1835 Value *UnpackArg = II.getArgOperand(0);
1836 auto *RetTy = cast<ScalableVectorType>(II.getType());
1837 bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
1838 II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
1839
1840 // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
1841 // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
1842 if (auto *ScalarArg = getSplatValue(UnpackArg)) {
1843 ScalarArg =
1844 IC.Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
1845 Value *NewVal =
1846 IC.Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
1847 NewVal->takeName(&II);
1848 return IC.replaceInstUsesWith(II, NewVal);
1849 }
1850
1851 return std::nullopt;
1852}
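// For example (a sketch, types illustrative): unpacking a splat such as
//   %lo = call <vscale x 2 x i64> @llvm.aarch64.sve.uunpklo.nxv2i64(<vscale x 4 x i32> %splat_of_x)
// is rewritten as a zext of the scalar %x to i64 followed by a vector splat to
// <vscale x 2 x i64> (sext for the signed sunpkhi/sunpklo variants).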
1853static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
1854 IntrinsicInst &II) {
1855 auto *OpVal = II.getOperand(0);
1856 auto *OpIndices = II.getOperand(1);
1857 VectorType *VTy = cast<VectorType>(II.getType());
1858
1859 // Check whether OpIndices is a constant splat value smaller than the
1860 // minimum element count of the result.
1861 auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
1862 if (!SplatValue ||
1863 SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
1864 return std::nullopt;
1865
1866 // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
1867 // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
1868 auto *Extract = IC.Builder.CreateExtractElement(OpVal, SplatValue);
1869 auto *VectorSplat =
1870 IC.Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
1871
1872 VectorSplat->takeName(&II);
1873 return IC.replaceInstUsesWith(II, VectorSplat);
1874}
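// A sketch of the fold above (types illustrative): a table lookup whose index
// vector is a constant splat smaller than the minimum element count, e.g.
//   %r = call <vscale x 4 x i32> @llvm.aarch64.sve.tbl.nxv4i32(<vscale x 4 x i32> %v, <vscale x 4 x i32> splat (i32 2))
// becomes an extractelement of lane 2 of %v followed by a vector splat.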
1875
1876static std::optional<Instruction *> instCombineSVEUzp1(InstCombiner &IC,
1877 IntrinsicInst &II) {
1878 Value *A, *B;
1879 Type *RetTy = II.getType();
1880 constexpr Intrinsic::ID FromSVB = Intrinsic::aarch64_sve_convert_from_svbool;
1881 constexpr Intrinsic::ID ToSVB = Intrinsic::aarch64_sve_convert_to_svbool;
1882
1883 // uzp1(to_svbool(A), to_svbool(B)) --> <A, B>
1884 // uzp1(from_svbool(to_svbool(A)), from_svbool(to_svbool(B))) --> <A, B>
1885 if ((match(II.getArgOperand(0),
1886 m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(A)))) &&
1887 match(II.getArgOperand(1),
1888 m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(B))))) ||
1889 (match(II.getArgOperand(0), m_Intrinsic<ToSVB>(m_Value(A))) &&
1890 match(II.getArgOperand(1), m_Intrinsic<ToSVB>(m_Value(B))))) {
1891 auto *TyA = cast<ScalableVectorType>(A->getType());
1892 if (TyA == B->getType() &&
1893 RetTy == ScalableVectorType::getDoubleElementsVectorType(TyA)) {
1894 auto *SubVec = IC.Builder.CreateInsertVector(
1895 RetTy, PoisonValue::get(RetTy), A, IC.Builder.getInt64(0));
1896 auto *ConcatVec = IC.Builder.CreateInsertVector(
1897 RetTy, SubVec, B, IC.Builder.getInt64(TyA->getMinNumElements()));
1898 ConcatVec->takeName(&II);
1899 return IC.replaceInstUsesWith(II, ConcatVec);
1900 }
1901 }
1902
1903 return std::nullopt;
1904}
1905
1906static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
1907 IntrinsicInst &II) {
1908 // zip1(uzp1(A, B), uzp2(A, B)) --> A
1909 // zip2(uzp1(A, B), uzp2(A, B)) --> B
1910 Value *A, *B;
1911 if (match(II.getArgOperand(0),
1912 m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
1913 match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
1914 m_Specific(A), m_Specific(B))))
1915 return IC.replaceInstUsesWith(
1916 II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
1917
1918 return std::nullopt;
1919}
1920
1921static std::optional<Instruction *>
1922instCombineLD1GatherIndex(InstCombiner &IC, IntrinsicInst &II) {
1923 Value *Mask = II.getOperand(0);
1924 Value *BasePtr = II.getOperand(1);
1925 Value *Index = II.getOperand(2);
1926 Type *Ty = II.getType();
1927 Value *PassThru = ConstantAggregateZero::get(Ty);
1928
1929 // Replace with a zero constant when all lanes are inactive.
1930 if (auto II_NA = instCombineSVENoActiveZero(IC, II))
1931 return II_NA;
1932
1933 // Contiguous gather => masked load.
1934 // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
1935 // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
1936 Value *IndexBase;
1937 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1938 m_Value(IndexBase), m_SpecificInt(1)))) {
1939 Align Alignment =
1940 BasePtr->getPointerAlignment(II.getDataLayout());
1941
1942 Type *VecPtrTy = PointerType::getUnqual(Ty);
1943 Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1944 BasePtr, IndexBase);
1945 Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy);
1946 CallInst *MaskedLoad =
1947 IC.Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
1948 MaskedLoad->takeName(&II);
1949 return IC.replaceInstUsesWith(II, MaskedLoad);
1950 }
1951
1952 return std::nullopt;
1953}
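// A sketch of the contiguous-gather fold above (types illustrative):
//   %idx = call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 %base, i64 1)
//   %v   = call <vscale x 2 x i64> @llvm.aarch64.sve.ld1.gather.index.nxv2i64(<vscale x 2 x i1> %pg, ptr %p, <vscale x 2 x i64> %idx)
// becomes a GEP of %p by %base followed by an llvm.masked.load with a
// zeroinitializer passthru.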
1954
1955static std::optional<Instruction *>
1956instCombineST1ScatterIndex(InstCombiner &IC, IntrinsicInst &II) {
1957 Value *Val = II.getOperand(0);
1958 Value *Mask = II.getOperand(1);
1959 Value *BasePtr = II.getOperand(2);
1960 Value *Index = II.getOperand(3);
1961 Type *Ty = Val->getType();
1962
1963 // Contiguous scatter => masked store.
1964 // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
1965 // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
1966 Value *IndexBase;
1967 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1968 m_Value(IndexBase), m_SpecificInt(1)))) {
1969 Align Alignment =
1970 BasePtr->getPointerAlignment(II.getDataLayout());
1971
1972 Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1973 BasePtr, IndexBase);
1974 Type *VecPtrTy = PointerType::getUnqual(Ty);
1975 Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy);
1976
1977 (void)IC.Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
1978
1979 return IC.eraseInstFromFunction(II);
1980 }
1981
1982 return std::nullopt;
1983}
1984
1985static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
1986 IntrinsicInst &II) {
1987 Type *Int32Ty = IC.Builder.getInt32Ty();
1988 Value *Pred = II.getOperand(0);
1989 Value *Vec = II.getOperand(1);
1990 Value *DivVec = II.getOperand(2);
1991
1992 Value *SplatValue = getSplatValue(DivVec);
1993 ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
1994 if (!SplatConstantInt)
1995 return std::nullopt;
1996
1997 APInt Divisor = SplatConstantInt->getValue();
1998 const int64_t DivisorValue = Divisor.getSExtValue();
1999 if (DivisorValue == -1)
2000 return std::nullopt;
2001 if (DivisorValue == 1)
2002 IC.replaceInstUsesWith(II, Vec);
2003
2004 if (Divisor.isPowerOf2()) {
2005 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
2006 auto ASRD = IC.Builder.CreateIntrinsic(
2007 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
2008 return IC.replaceInstUsesWith(II, ASRD);
2009 }
2010 if (Divisor.isNegatedPowerOf2()) {
2011 Divisor.negate();
2012 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
2013 auto ASRD = IC.Builder.CreateIntrinsic(
2014 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
2015 auto NEG = IC.Builder.CreateIntrinsic(
2016 Intrinsic::aarch64_sve_neg, {ASRD->getType()}, {ASRD, Pred, ASRD});
2017 return IC.replaceInstUsesWith(II, NEG);
2018 }
2019
2020 return std::nullopt;
2021}
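// A sketch of the power-of-two folds above (types illustrative): with a splat
// divisor of 8,
//   %r = call <vscale x 4 x i32> @llvm.aarch64.sve.sdiv.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %n, <vscale x 4 x i32> splat (i32 8))
// becomes a rounding arithmetic shift right:
//   %r = call <vscale x 4 x i32> @llvm.aarch64.sve.asrd.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %n, i32 3)
// while a divisor of -8 additionally negates the result via sve.neg.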
2022
2023bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) {
2024 size_t VecSize = Vec.size();
2025 if (VecSize == 1)
2026 return true;
2027 if (!isPowerOf2_64(VecSize))
2028 return false;
2029 size_t HalfVecSize = VecSize / 2;
2030
2031 for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
2032 RHS != Vec.end(); LHS++, RHS++) {
2033 if (*LHS != nullptr && *RHS != nullptr) {
2034 if (*LHS == *RHS)
2035 continue;
2036 else
2037 return false;
2038 }
2039 if (!AllowPoison)
2040 return false;
2041 if (*LHS == nullptr && *RHS != nullptr)
2042 *LHS = *RHS;
2043 }
2044
2045 Vec.resize(HalfVecSize);
2046 SimplifyValuePattern(Vec, AllowPoison);
2047 return true;
2048}
2049
2050// Try to simplify dupqlane patterns like dupqlane(f32 A, f32 B, f32 A, f32 B)
2051// to dupqlane(f64(C)) where C is A concatenated with B
2052static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
2053 IntrinsicInst &II) {
2054 Value *CurrentInsertElt = nullptr, *Default = nullptr;
2055 if (!match(II.getOperand(0),
2056 m_Intrinsic<Intrinsic::vector_insert>(
2057 m_Value(Default), m_Value(CurrentInsertElt), m_Value())) ||
2058 !isa<FixedVectorType>(CurrentInsertElt->getType()))
2059 return std::nullopt;
2060 auto IIScalableTy = cast<ScalableVectorType>(II.getType());
2061
2062 // Insert the scalars into a container ordered by InsertElement index
2063 SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr);
2064 while (auto InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) {
2065 auto Idx = cast<ConstantInt>(InsertElt->getOperand(2));
2066 Elts[Idx->getValue().getZExtValue()] = InsertElt->getOperand(1);
2067 CurrentInsertElt = InsertElt->getOperand(0);
2068 }
2069
2070 bool AllowPoison =
2071 isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default);
2072 if (!SimplifyValuePattern(Elts, AllowPoison))
2073 return std::nullopt;
2074
2075 // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b)
2076 Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
2077 for (size_t I = 0; I < Elts.size(); I++) {
2078 if (Elts[I] == nullptr)
2079 continue;
2080 InsertEltChain = IC.Builder.CreateInsertElement(InsertEltChain, Elts[I],
2081 IC.Builder.getInt64(I));
2082 }
2083 if (InsertEltChain == nullptr)
2084 return std::nullopt;
2085
2086 // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
2087 // value or (f16 a, f16 b) as one i32 value. This requires the InsertSubvector
2088 // to be bitcast to a type wide enough to fit the sequence, splatted, and then
2089 // narrowed back to the original type.
2090 unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size();
2091 unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() *
2092 IIScalableTy->getMinNumElements() /
2093 PatternWidth;
2094
2095 IntegerType *WideTy = IC.Builder.getIntNTy(PatternWidth);
2096 auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount);
2097 auto *WideShuffleMaskTy =
2098 ScalableVectorType::get(IC.Builder.getInt32Ty(), PatternElementCount);
2099
2100 auto ZeroIdx = ConstantInt::get(IC.Builder.getInt64Ty(), APInt(64, 0));
2101 auto InsertSubvector = IC.Builder.CreateInsertVector(
2102 II.getType(), PoisonValue::get(II.getType()), InsertEltChain, ZeroIdx);
2103 auto WideBitcast =
2104 IC.Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy);
2105 auto WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy);
2106 auto WideShuffle = IC.Builder.CreateShuffleVector(
2107 WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask);
2108 auto NarrowBitcast =
2109 IC.Builder.CreateBitOrPointerCast(WideShuffle, II.getType());
2110
2111 return IC.replaceInstUsesWith(II, NarrowBitcast);
2112}
2113
2114static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC,
2115 IntrinsicInst &II) {
2116 Value *A = II.getArgOperand(0);
2117 Value *B = II.getArgOperand(1);
2118 if (A == B)
2119 return IC.replaceInstUsesWith(II, A);
2120
2121 return std::nullopt;
2122}
2123
2124static std::optional<Instruction *> instCombineSVESrshl(InstCombiner &IC,
2125 IntrinsicInst &II) {
2126 Value *Pred = II.getOperand(0);
2127 Value *Vec = II.getOperand(1);
2128 Value *Shift = II.getOperand(2);
2129
2130 // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic.
2131 Value *AbsPred, *MergedValue;
2132 if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>(
2133 m_Value(MergedValue), m_Value(AbsPred), m_Value())) &&
2134 !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>(
2135 m_Value(MergedValue), m_Value(AbsPred), m_Value())))
2136
2137 return std::nullopt;
2138
2139 // Transform is valid if any of the following are true:
2140 // * The ABS merge value is an undef or non-negative
2141 // * The ABS predicate is all active
2142 // * The ABS predicate and the SRSHL predicates are the same
2143 if (!isa<UndefValue>(MergedValue) && !match(MergedValue, m_NonNegative()) &&
2144 AbsPred != Pred && !isAllActivePredicate(AbsPred))
2145 return std::nullopt;
2146
2147 // Only valid when the shift amount is non-negative, otherwise the rounding
2148 // behaviour of SRSHL cannot be ignored.
2149 if (!match(Shift, m_NonNegative()))
2150 return std::nullopt;
2151
2152 auto LSL = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl,
2153 {II.getType()}, {Pred, Vec, Shift});
2154
2155 return IC.replaceInstUsesWith(II, LSL);
2156}
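// A sketch of the fold above (types illustrative): the value being shifted is
// known non-negative because it comes from a predicated abs, so the rounding
// performed by SRSHL is irrelevant and
//   %abs = call <vscale x 4 x i32> @llvm.aarch64.sve.abs.nxv4i32(<vscale x 4 x i32> undef, <vscale x 4 x i1> %pg, <vscale x 4 x i32> %x)
//   %r   = call <vscale x 4 x i32> @llvm.aarch64.sve.srshl.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %abs, <vscale x 4 x i32> splat (i32 2))
// is replaced by the equivalent @llvm.aarch64.sve.lsl.nxv4i32 call on %abs.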
2157
2158static std::optional<Instruction *> instCombineSVEInsr(InstCombiner &IC,
2159 IntrinsicInst &II) {
2160 Value *Vec = II.getOperand(0);
2161
2162 if (getSplatValue(Vec) == II.getOperand(1))
2163 return IC.replaceInstUsesWith(II, Vec);
2164
2165 return std::nullopt;
2166}
2167
2168static std::optional<Instruction *> instCombineDMB(InstCombiner &IC,
2169 IntrinsicInst &II) {
2170 // If this barrier is post-dominated by an identical one, we can remove it.
2171 auto *NI = II.getNextNonDebugInstruction();
2172 unsigned LookaheadThreshold = DMBLookaheadThreshold;
2173 auto CanSkipOver = [](Instruction *I) {
2174 return !I->mayReadOrWriteMemory() && !I->mayHaveSideEffects();
2175 };
2176 while (LookaheadThreshold-- && CanSkipOver(NI)) {
2177 auto *NIBB = NI->getParent();
2178 NI = NI->getNextNonDebugInstruction();
2179 if (!NI) {
2180 if (auto *SuccBB = NIBB->getUniqueSuccessor())
2181 NI = SuccBB->getFirstNonPHIOrDbgOrLifetime();
2182 else
2183 break;
2184 }
2185 }
2186 auto *NextII = dyn_cast_or_null<IntrinsicInst>(NI);
2187 if (NextII && II.isIdenticalTo(NextII))
2188 return IC.eraseInstFromFunction(II);
2189
2190 return std::nullopt;
2191}
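// For example (a sketch): with nothing that reads or writes memory in between,
//   call void @llvm.aarch64.dmb(i32 11)   ; dmb ish
//   %x = add i32 %a, %b                   ; no memory access, can be skipped
//   call void @llvm.aarch64.dmb(i32 11)   ; dmb ish
// the first barrier is redundant and is erased.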
2192
2193std::optional<Instruction *>
2194AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
2195 IntrinsicInst &II) const {
2196 Intrinsic::ID IID = II.getIntrinsicID();
2197 switch (IID) {
2198 default:
2199 break;
2200 case Intrinsic::aarch64_dmb:
2201 return instCombineDMB(IC, II);
2202 case Intrinsic::aarch64_sve_fcvt_bf16f32_v2:
2203 case Intrinsic::aarch64_sve_fcvt_f16f32:
2204 case Intrinsic::aarch64_sve_fcvt_f16f64:
2205 case Intrinsic::aarch64_sve_fcvt_f32f16:
2206 case Intrinsic::aarch64_sve_fcvt_f32f64:
2207 case Intrinsic::aarch64_sve_fcvt_f64f16:
2208 case Intrinsic::aarch64_sve_fcvt_f64f32:
2209 case Intrinsic::aarch64_sve_fcvtlt_f32f16:
2210 case Intrinsic::aarch64_sve_fcvtlt_f64f32:
2211 case Intrinsic::aarch64_sve_fcvtx_f32f64:
2212 case Intrinsic::aarch64_sve_fcvtzs:
2213 case Intrinsic::aarch64_sve_fcvtzs_i32f16:
2214 case Intrinsic::aarch64_sve_fcvtzs_i32f64:
2215 case Intrinsic::aarch64_sve_fcvtzs_i64f16:
2216 case Intrinsic::aarch64_sve_fcvtzs_i64f32:
2217 case Intrinsic::aarch64_sve_fcvtzu:
2218 case Intrinsic::aarch64_sve_fcvtzu_i32f16:
2219 case Intrinsic::aarch64_sve_fcvtzu_i32f64:
2220 case Intrinsic::aarch64_sve_fcvtzu_i64f16:
2221 case Intrinsic::aarch64_sve_fcvtzu_i64f32:
2222 case Intrinsic::aarch64_sve_scvtf:
2223 case Intrinsic::aarch64_sve_scvtf_f16i32:
2224 case Intrinsic::aarch64_sve_scvtf_f16i64:
2225 case Intrinsic::aarch64_sve_scvtf_f32i64:
2226 case Intrinsic::aarch64_sve_scvtf_f64i32:
2227 case Intrinsic::aarch64_sve_ucvtf:
2228 case Intrinsic::aarch64_sve_ucvtf_f16i32:
2229 case Intrinsic::aarch64_sve_ucvtf_f16i64:
2230 case Intrinsic::aarch64_sve_ucvtf_f32i64:
2231 case Intrinsic::aarch64_sve_ucvtf_f64i32:
2232 return instCombineSVENoActiveReplace(IC, II, false);
2233 case Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2:
2234 case Intrinsic::aarch64_sve_fcvtnt_f16f32:
2235 case Intrinsic::aarch64_sve_fcvtnt_f32f64:
2236 case Intrinsic::aarch64_sve_fcvtxnt_f32f64:
2237 return instCombineSVENoActiveReplace(IC, II, true);
2238 case Intrinsic::aarch64_sve_st1_scatter:
2239 case Intrinsic::aarch64_sve_st1_scatter_scalar_offset:
2240 case Intrinsic::aarch64_sve_st1_scatter_sxtw:
2241 case Intrinsic::aarch64_sve_st1_scatter_sxtw_index:
2242 case Intrinsic::aarch64_sve_st1_scatter_uxtw:
2243 case Intrinsic::aarch64_sve_st1_scatter_uxtw_index:
2244 case Intrinsic::aarch64_sve_st1dq:
2245 case Intrinsic::aarch64_sve_st1q_scatter_index:
2246 case Intrinsic::aarch64_sve_st1q_scatter_scalar_offset:
2247 case Intrinsic::aarch64_sve_st1q_scatter_vector_offset:
2248 case Intrinsic::aarch64_sve_st1wq:
2249 case Intrinsic::aarch64_sve_stnt1:
2250 case Intrinsic::aarch64_sve_stnt1_scatter:
2251 case Intrinsic::aarch64_sve_stnt1_scatter_index:
2252 case Intrinsic::aarch64_sve_stnt1_scatter_scalar_offset:
2253 case Intrinsic::aarch64_sve_stnt1_scatter_uxtw:
2254 return instCombineSVENoActiveUnaryErase(IC, II, 1);
2255 case Intrinsic::aarch64_sve_st2:
2256 case Intrinsic::aarch64_sve_st2q:
2257 return instCombineSVENoActiveUnaryErase(IC, II, 2);
2258 case Intrinsic::aarch64_sve_st3:
2259 case Intrinsic::aarch64_sve_st3q:
2260 return instCombineSVENoActiveUnaryErase(IC, II, 3);
2261 case Intrinsic::aarch64_sve_st4:
2262 case Intrinsic::aarch64_sve_st4q:
2263 return instCombineSVENoActiveUnaryErase(IC, II, 4);
2264 case Intrinsic::aarch64_sve_addqv:
2265 case Intrinsic::aarch64_sve_and_z:
2266 case Intrinsic::aarch64_sve_bic_z:
2267 case Intrinsic::aarch64_sve_brka_z:
2268 case Intrinsic::aarch64_sve_brkb_z:
2269 case Intrinsic::aarch64_sve_brkn_z:
2270 case Intrinsic::aarch64_sve_brkpa_z:
2271 case Intrinsic::aarch64_sve_brkpb_z:
2272 case Intrinsic::aarch64_sve_cntp:
2273 case Intrinsic::aarch64_sve_compact:
2274 case Intrinsic::aarch64_sve_eor_z:
2275 case Intrinsic::aarch64_sve_eorv:
2276 case Intrinsic::aarch64_sve_eorqv:
2277 case Intrinsic::aarch64_sve_nand_z:
2278 case Intrinsic::aarch64_sve_nor_z:
2279 case Intrinsic::aarch64_sve_orn_z:
2280 case Intrinsic::aarch64_sve_orr_z:
2281 case Intrinsic::aarch64_sve_orv:
2282 case Intrinsic::aarch64_sve_orqv:
2283 case Intrinsic::aarch64_sve_pnext:
2284 case Intrinsic::aarch64_sve_rdffr_z:
2285 case Intrinsic::aarch64_sve_saddv:
2286 case Intrinsic::aarch64_sve_uaddv:
2287 case Intrinsic::aarch64_sve_umaxv:
2288 case Intrinsic::aarch64_sve_umaxqv:
2289 case Intrinsic::aarch64_sve_cmpeq:
2290 case Intrinsic::aarch64_sve_cmpeq_wide:
2291 case Intrinsic::aarch64_sve_cmpge:
2292 case Intrinsic::aarch64_sve_cmpge_wide:
2293 case Intrinsic::aarch64_sve_cmpgt:
2294 case Intrinsic::aarch64_sve_cmpgt_wide:
2295 case Intrinsic::aarch64_sve_cmphi:
2296 case Intrinsic::aarch64_sve_cmphi_wide:
2297 case Intrinsic::aarch64_sve_cmphs:
2298 case Intrinsic::aarch64_sve_cmphs_wide:
2299 case Intrinsic::aarch64_sve_cmple_wide:
2300 case Intrinsic::aarch64_sve_cmplo_wide:
2301 case Intrinsic::aarch64_sve_cmpls_wide:
2302 case Intrinsic::aarch64_sve_cmplt_wide:
2303 case Intrinsic::aarch64_sve_facge:
2304 case Intrinsic::aarch64_sve_facgt:
2305 case Intrinsic::aarch64_sve_fcmpeq:
2306 case Intrinsic::aarch64_sve_fcmpge:
2307 case Intrinsic::aarch64_sve_fcmpgt:
2308 case Intrinsic::aarch64_sve_fcmpne:
2309 case Intrinsic::aarch64_sve_fcmpuo:
2310 case Intrinsic::aarch64_sve_ld1_gather:
2311 case Intrinsic::aarch64_sve_ld1_gather_scalar_offset:
2312 case Intrinsic::aarch64_sve_ld1_gather_sxtw:
2313 case Intrinsic::aarch64_sve_ld1_gather_sxtw_index:
2314 case Intrinsic::aarch64_sve_ld1_gather_uxtw:
2315 case Intrinsic::aarch64_sve_ld1_gather_uxtw_index:
2316 case Intrinsic::aarch64_sve_ld1q_gather_index:
2317 case Intrinsic::aarch64_sve_ld1q_gather_scalar_offset:
2318 case Intrinsic::aarch64_sve_ld1q_gather_vector_offset:
2319 case Intrinsic::aarch64_sve_ld1ro:
2320 case Intrinsic::aarch64_sve_ld1rq:
2321 case Intrinsic::aarch64_sve_ld1udq:
2322 case Intrinsic::aarch64_sve_ld1uwq:
2323 case Intrinsic::aarch64_sve_ld2_sret:
2324 case Intrinsic::aarch64_sve_ld2q_sret:
2325 case Intrinsic::aarch64_sve_ld3_sret:
2326 case Intrinsic::aarch64_sve_ld3q_sret:
2327 case Intrinsic::aarch64_sve_ld4_sret:
2328 case Intrinsic::aarch64_sve_ld4q_sret:
2329 case Intrinsic::aarch64_sve_ldff1:
2330 case Intrinsic::aarch64_sve_ldff1_gather:
2331 case Intrinsic::aarch64_sve_ldff1_gather_index:
2332 case Intrinsic::aarch64_sve_ldff1_gather_scalar_offset:
2333 case Intrinsic::aarch64_sve_ldff1_gather_sxtw:
2334 case Intrinsic::aarch64_sve_ldff1_gather_sxtw_index:
2335 case Intrinsic::aarch64_sve_ldff1_gather_uxtw:
2336 case Intrinsic::aarch64_sve_ldff1_gather_uxtw_index:
2337 case Intrinsic::aarch64_sve_ldnf1:
2338 case Intrinsic::aarch64_sve_ldnt1:
2339 case Intrinsic::aarch64_sve_ldnt1_gather:
2340 case Intrinsic::aarch64_sve_ldnt1_gather_index:
2341 case Intrinsic::aarch64_sve_ldnt1_gather_scalar_offset:
2342 case Intrinsic::aarch64_sve_ldnt1_gather_uxtw:
2343 return instCombineSVENoActiveZero(IC, II);
2344 case Intrinsic::aarch64_sve_prf:
2345 case Intrinsic::aarch64_sve_prfb_gather_index:
2346 case Intrinsic::aarch64_sve_prfb_gather_scalar_offset:
2347 case Intrinsic::aarch64_sve_prfb_gather_sxtw_index:
2348 case Intrinsic::aarch64_sve_prfb_gather_uxtw_index:
2349 case Intrinsic::aarch64_sve_prfd_gather_index:
2350 case Intrinsic::aarch64_sve_prfd_gather_scalar_offset:
2351 case Intrinsic::aarch64_sve_prfd_gather_sxtw_index:
2352 case Intrinsic::aarch64_sve_prfd_gather_uxtw_index:
2353 case Intrinsic::aarch64_sve_prfh_gather_index:
2354 case Intrinsic::aarch64_sve_prfh_gather_scalar_offset:
2355 case Intrinsic::aarch64_sve_prfh_gather_sxtw_index:
2356 case Intrinsic::aarch64_sve_prfh_gather_uxtw_index:
2357 case Intrinsic::aarch64_sve_prfw_gather_index:
2358 case Intrinsic::aarch64_sve_prfw_gather_scalar_offset:
2359 case Intrinsic::aarch64_sve_prfw_gather_sxtw_index:
2360 case Intrinsic::aarch64_sve_prfw_gather_uxtw_index:
2361 return instCombineSVENoActiveUnaryErase(IC, II, 0);
2362 case Intrinsic::aarch64_neon_fmaxnm:
2363 case Intrinsic::aarch64_neon_fminnm:
2364 return instCombineMaxMinNM(IC, II);
2365 case Intrinsic::aarch64_sve_convert_from_svbool:
2366 return instCombineConvertFromSVBool(IC, II);
2367 case Intrinsic::aarch64_sve_dup:
2368 return instCombineSVEDup(IC, II);
2369 case Intrinsic::aarch64_sve_dup_x:
2370 return instCombineSVEDupX(IC, II);
2371 case Intrinsic::aarch64_sve_cmpne:
2372 case Intrinsic::aarch64_sve_cmpne_wide:
2373 return instCombineSVECmpNE(IC, II);
2374 case Intrinsic::aarch64_sve_rdffr:
2375 return instCombineRDFFR(IC, II);
2376 case Intrinsic::aarch64_sve_lasta:
2377 case Intrinsic::aarch64_sve_lastb:
2378 return instCombineSVELast(IC, II);
2379 case Intrinsic::aarch64_sve_clasta_n:
2380 case Intrinsic::aarch64_sve_clastb_n:
2381 return instCombineSVECondLast(IC, II);
2382 case Intrinsic::aarch64_sve_cntd:
2383 return instCombineSVECntElts(IC, II, 2);
2384 case Intrinsic::aarch64_sve_cntw:
2385 return instCombineSVECntElts(IC, II, 4);
2386 case Intrinsic::aarch64_sve_cnth:
2387 return instCombineSVECntElts(IC, II, 8);
2388 case Intrinsic::aarch64_sve_cntb:
2389 return instCombineSVECntElts(IC, II, 16);
2390 case Intrinsic::aarch64_sve_ptest_any:
2391 case Intrinsic::aarch64_sve_ptest_first:
2392 case Intrinsic::aarch64_sve_ptest_last:
2393 return instCombineSVEPTest(IC, II);
2394 case Intrinsic::aarch64_sve_fabd:
2395 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fabd_u);
2396 case Intrinsic::aarch64_sve_fadd:
2397 return instCombineSVEVectorFAdd(IC, II);
2398 case Intrinsic::aarch64_sve_fadd_u:
2399 return instCombineSVEVectorFAddU(IC, II);
2400 case Intrinsic::aarch64_sve_fdiv:
2401 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fdiv_u);
2402 case Intrinsic::aarch64_sve_fmax:
2403 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmax_u);
2404 case Intrinsic::aarch64_sve_fmaxnm:
2405 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmaxnm_u);
2406 case Intrinsic::aarch64_sve_fmin:
2407 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmin_u);
2408 case Intrinsic::aarch64_sve_fminnm:
2409 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fminnm_u);
2410 case Intrinsic::aarch64_sve_fmla:
2411 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmla_u);
2412 case Intrinsic::aarch64_sve_fmls:
2413 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmls_u);
2414 case Intrinsic::aarch64_sve_fmul:
2415 if (auto II_U =
2416 instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmul_u))
2417 return II_U;
2418 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u);
2419 case Intrinsic::aarch64_sve_fmul_u:
2420 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u);
2421 case Intrinsic::aarch64_sve_fmulx:
2422 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmulx_u);
2423 case Intrinsic::aarch64_sve_fnmla:
2424 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fnmla_u);
2425 case Intrinsic::aarch64_sve_fnmls:
2426 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fnmls_u);
2427 case Intrinsic::aarch64_sve_fsub:
2428 return instCombineSVEVectorFSub(IC, II);
2429 case Intrinsic::aarch64_sve_fsub_u:
2430 return instCombineSVEVectorFSubU(IC, II);
2431 case Intrinsic::aarch64_sve_add:
2432 return instCombineSVEVectorAdd(IC, II);
2433 case Intrinsic::aarch64_sve_add_u:
2434 return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
2435 Intrinsic::aarch64_sve_mla_u>(
2436 IC, II, true);
2437 case Intrinsic::aarch64_sve_mla:
2438 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mla_u);
2439 case Intrinsic::aarch64_sve_mls:
2440 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mls_u);
2441 case Intrinsic::aarch64_sve_mul:
2442 if (auto II_U =
2443 instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mul_u))
2444 return II_U;
2445 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u);
2446 case Intrinsic::aarch64_sve_mul_u:
2447 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u);
2448 case Intrinsic::aarch64_sve_sabd:
2449 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sabd_u);
2450 case Intrinsic::aarch64_sve_smax:
2451 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smax_u);
2452 case Intrinsic::aarch64_sve_smin:
2453 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smin_u);
2454 case Intrinsic::aarch64_sve_smulh:
2455 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smulh_u);
2456 case Intrinsic::aarch64_sve_sub:
2457 return instCombineSVEVectorSub(IC, II);
2458 case Intrinsic::aarch64_sve_sub_u:
2459 return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
2460 Intrinsic::aarch64_sve_mls_u>(
2461 IC, II, true);
2462 case Intrinsic::aarch64_sve_uabd:
2463 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_uabd_u);
2464 case Intrinsic::aarch64_sve_umax:
2465 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umax_u);
2466 case Intrinsic::aarch64_sve_umin:
2467 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umin_u);
2468 case Intrinsic::aarch64_sve_umulh:
2469 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umulh_u);
2470 case Intrinsic::aarch64_sve_asr:
2471 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_asr_u);
2472 case Intrinsic::aarch64_sve_lsl:
2473 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_lsl_u);
2474 case Intrinsic::aarch64_sve_lsr:
2475 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_lsr_u);
2476 case Intrinsic::aarch64_sve_and:
2477 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_and_u);
2478 case Intrinsic::aarch64_sve_bic:
2479 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_bic_u);
2480 case Intrinsic::aarch64_sve_eor:
2481 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_eor_u);
2482 case Intrinsic::aarch64_sve_orr:
2483 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_orr_u);
2484 case Intrinsic::aarch64_sve_sqsub:
2485 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sqsub_u);
2486 case Intrinsic::aarch64_sve_uqsub:
2487 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_uqsub_u);
2488 case Intrinsic::aarch64_sve_tbl:
2489 return instCombineSVETBL(IC, II);
2490 case Intrinsic::aarch64_sve_uunpkhi:
2491 case Intrinsic::aarch64_sve_uunpklo:
2492 case Intrinsic::aarch64_sve_sunpkhi:
2493 case Intrinsic::aarch64_sve_sunpklo:
2494 return instCombineSVEUnpack(IC, II);
2495 case Intrinsic::aarch64_sve_uzp1:
2496 return instCombineSVEUzp1(IC, II);
2497 case Intrinsic::aarch64_sve_zip1:
2498 case Intrinsic::aarch64_sve_zip2:
2499 return instCombineSVEZip(IC, II);
2500 case Intrinsic::aarch64_sve_ld1_gather_index:
2501 return instCombineLD1GatherIndex(IC, II);
2502 case Intrinsic::aarch64_sve_st1_scatter_index:
2503 return instCombineST1ScatterIndex(IC, II);
2504 case Intrinsic::aarch64_sve_ld1:
2505 return instCombineSVELD1(IC, II, DL);
2506 case Intrinsic::aarch64_sve_st1:
2507 return instCombineSVEST1(IC, II, DL);
2508 case Intrinsic::aarch64_sve_sdiv:
2509 return instCombineSVESDIV(IC, II);
2510 case Intrinsic::aarch64_sve_sel:
2511 return instCombineSVESel(IC, II);
2512 case Intrinsic::aarch64_sve_srshl:
2513 return instCombineSVESrshl(IC, II);
2514 case Intrinsic::aarch64_sve_dupq_lane:
2515 return instCombineSVEDupqLane(IC, II);
2516 case Intrinsic::aarch64_sve_insr:
2517 return instCombineSVEInsr(IC, II);
2518 }
2519
2520 return std::nullopt;
2521}
2522
2523std::optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic(
2524 InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts,
2525 APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3,
2526 std::function<void(Instruction *, unsigned, APInt, APInt &)>
2527 SimplifyAndSetOp) const {
2528 switch (II.getIntrinsicID()) {
2529 default:
2530 break;
2531 case Intrinsic::aarch64_neon_fcvtxn:
2532 case Intrinsic::aarch64_neon_rshrn:
2533 case Intrinsic::aarch64_neon_sqrshrn:
2534 case Intrinsic::aarch64_neon_sqrshrun:
2535 case Intrinsic::aarch64_neon_sqshrn:
2536 case Intrinsic::aarch64_neon_sqshrun:
2537 case Intrinsic::aarch64_neon_sqxtn:
2538 case Intrinsic::aarch64_neon_sqxtun:
2539 case Intrinsic::aarch64_neon_uqrshrn:
2540 case Intrinsic::aarch64_neon_uqshrn:
2541 case Intrinsic::aarch64_neon_uqxtn:
2542 SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts);
2543 break;
2544 }
2545
2546 return std::nullopt;
2547}
2548
2549bool AArch64TTIImpl::enableScalableVectorization() const {
2550 return ST->isSVEAvailable() || (ST->isSVEorStreamingSVEAvailable() &&
2551 EnableScalableAutovecInStreamingMode);
2552}
2553
2554TypeSize
2555AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
2556 switch (K) {
2557 case TargetTransformInfo::RGK_Scalar:
2558 return TypeSize::getFixed(64);
2559 case TargetTransformInfo::RGK_FixedWidthVector:
2560 if (ST->useSVEForFixedLengthVectors() &&
2561 (ST->isSVEAvailable() || EnableFixedwidthAutovecInStreamingMode))
2562 return TypeSize::getFixed(
2563 std::max(ST->getMinSVEVectorSizeInBits(), 128u));
2564 else if (ST->isNeonAvailable())
2565 return TypeSize::getFixed(128);
2566 else
2567 return TypeSize::getFixed(0);
2568 case TargetTransformInfo::RGK_ScalableVector:
2569 if (ST->isSVEAvailable() || (ST->isSVEorStreamingSVEAvailable() &&
2570 EnableScalableAutovecInStreamingMode))
2571 return TypeSize::getScalable(128);
2572 else
2573 return TypeSize::getScalable(0);
2574 }
2575 llvm_unreachable("Unsupported register kind");
2576}
2577
2578bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
2579 ArrayRef<const Value *> Args,
2580 Type *SrcOverrideTy) {
2581 // A helper that returns a vector type whose element type is taken from the
2582 // given type and whose element count is taken from DstTy.
2583 auto toVectorTy = [&](Type *ArgTy) {
2584 return VectorType::get(ArgTy->getScalarType(),
2585 cast<VectorType>(DstTy)->getElementCount());
2586 };
2587
2588 // Exit early if DstTy is not a vector type whose elements are one of [i16,
2589 // i32, i64]. SVE doesn't generally have the same set of instructions to
2590 // perform an extend with the add/sub/mul. There are SMULLB style
2591 // instructions, but they operate on top/bottom, requiring some sort of lane
2592 // interleaving to be used with zext/sext.
2593 unsigned DstEltSize = DstTy->getScalarSizeInBits();
2594 if (!useNeonVector(DstTy) || Args.size() != 2 ||
2595 (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
2596 return false;
2597
2598 // Determine if the operation has a widening variant. We consider both the
2599 // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
2600 // instructions.
2601 //
2602 // TODO: Add additional widening operations (e.g., shl, etc.) once we
2603 // verify that their extending operands are eliminated during code
2604 // generation.
2605 Type *SrcTy = SrcOverrideTy;
2606 switch (Opcode) {
2607 case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
2608 case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
2609 // The second operand needs to be an extend
2610 if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
2611 if (!SrcTy)
2612 SrcTy =
2613 toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
2614 } else
2615 return false;
2616 break;
2617 case Instruction::Mul: { // SMULL(2), UMULL(2)
2618 // Both operands need to be extends of the same type.
2619 if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
2620 (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
2621 if (!SrcTy)
2622 SrcTy =
2623 toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
2624 } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
2625 // If one of the operands is a Zext and the other has enough zero bits to
2626 // be treated as unsigned, we can still generate a umull, meaning the zext
2627 // is free.
2628 KnownBits Known =
2629 computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
2630 if (Args[0]->getType()->getScalarSizeInBits() -
2631 Known.Zero.countLeadingOnes() >
2632 DstTy->getScalarSizeInBits() / 2)
2633 return false;
2634 if (!SrcTy)
2635 SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
2636 DstTy->getScalarSizeInBits() / 2));
2637 } else
2638 return false;
2639 break;
2640 }
2641 default:
2642 return false;
2643 }
2644
2645 // Legalize the destination type and ensure it can be used in a widening
2646 // operation.
2647 auto DstTyL = getTypeLegalizationCost(DstTy);
2648 if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits())
2649 return false;
2650
2651 // Legalize the source type and ensure it can be used in a widening
2652 // operation.
2653 assert(SrcTy && "Expected some SrcTy");
2654 auto SrcTyL = getTypeLegalizationCost(SrcTy);
2655 unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
2656 if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
2657 return false;
2658
2659 // Get the total number of vector elements in the legalized types.
2660 InstructionCost NumDstEls =
2661 DstTyL.first * DstTyL.second.getVectorMinNumElements();
2662 InstructionCost NumSrcEls =
2663 SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
2664
2665 // Return true if the legalized types have the same number of vector elements
2666 // and the destination element type size is twice that of the source type.
2667 return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
2668}
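// For example (a sketch of what counts as "widening" here): with a v8i16
// result,
//   %be = zext <8 x i8> %b to <8 x i16>
//   %r  = add <8 x i16> %a, %be
// maps onto a single uaddw, so the zext can later be costed as free by
// getCastInstrCost.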
2669
2670// s/urhadd instructions implement the following pattern, making the
2671// extends free:
2672// %x = add ((zext i8 -> i16), 1)
2673// %y = (zext i8 -> i16)
2674// trunc i16 (lshr (add %x, %y), 1) -> i8
2675//
2676bool AArch64TTIImpl::isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
2677 Type *Src) {
2678 // The source should be a legal vector type.
2679 if (!Src->isVectorTy() || !TLI->isTypeLegal(TLI->getValueType(DL, Src)) ||
2680 (Src->isScalableTy() && !ST->hasSVE2()))
2681 return false;
2682
2683 if (ExtUser->getOpcode() != Instruction::Add || !ExtUser->hasOneUse())
2684 return false;
2685
2686 // Look for trunc/shl/add before trying to match the pattern.
2687 const Instruction *Add = ExtUser;
2688 auto *AddUser =
2689 dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
2690 if (AddUser && AddUser->getOpcode() == Instruction::Add)
2691 Add = AddUser;
2692
2693 auto *Shr = dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
2694 if (!Shr || Shr->getOpcode() != Instruction::LShr)
2695 return false;
2696
2697 auto *Trunc = dyn_cast_or_null<Instruction>(Shr->getUniqueUndroppableUser());
2698 if (!Trunc || Trunc->getOpcode() != Instruction::Trunc ||
2699 Src->getScalarSizeInBits() !=
2700 cast<CastInst>(Trunc)->getDestTy()->getScalarSizeInBits())
2701 return false;
2702
2703 // Try to match the whole pattern. Ext could be either the first or second
2704 // m_ZExtOrSExt matched.
2705 Instruction *Ex1, *Ex2;
2706 if (!(match(Add, m_c_Add(m_Instruction(Ex1),
2707 m_c_Add(m_Instruction(Ex2), m_SpecificInt(1))))))
2708 return false;
2709
2710 // Ensure both extends are of the same type
2711 if (match(Ex1, m_ZExtOrSExt(m_Value())) &&
2712 Ex1->getOpcode() == Ex2->getOpcode())
2713 return true;
2714
2715 return false;
2716}
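// A sketch of the averaging pattern recognised above, which maps onto urhadd
// (rounding halving add) and makes both extends free:
//   %xe = zext <8 x i8> %x to <8 x i16>
//   %ye = zext <8 x i8> %y to <8 x i16>
//   %s  = add <8 x i16> %xe, splat (i16 1)   ; the +1 provides the rounding
//   %t  = add <8 x i16> %s, %ye
//   %h  = lshr <8 x i16> %t, splat (i16 1)
//   %r  = trunc <8 x i16> %h to <8 x i8>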
2717
2718InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
2719 Type *Src,
2720 TTI::CastContextHint CCH,
2721 TTI::TargetCostKind CostKind,
2722 const Instruction *I) {
2723 int ISD = TLI->InstructionOpcodeToISD(Opcode);
2724 assert(ISD && "Invalid opcode");
2725 // If the cast is observable, and it is used by a widening instruction (e.g.,
2726 // uaddl, saddw, etc.), it may be free.
2727 if (I && I->hasOneUser()) {
2728 auto *SingleUser = cast<Instruction>(*I->user_begin());
2729 SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
2730 if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
2731 // For adds, only count the second operand as free if both operands are
2732 // extends but not the same operation (i.e. both operands are not free in
2733 // add(sext, zext)).
2734 if (SingleUser->getOpcode() == Instruction::Add) {
2735 if (I == SingleUser->getOperand(1) ||
2736 (isa<CastInst>(SingleUser->getOperand(1)) &&
2737 cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
2738 return 0;
2739 } else // Others are free so long as isWideningInstruction returned true.
2740 return 0;
2741 }
2742
2743 // The cast will be free for the s/urhadd instructions
2744 if ((isa<ZExtInst>(I) || isa<SExtInst>(I)) &&
2745 isExtPartOfAvgExpr(SingleUser, Dst, Src))
2746 return 0;
2747 }
2748
2749 // TODO: Allow non-throughput costs that aren't binary.
2750 auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
2751 if (CostKind != TTI::TCK_RecipThroughput)
2752 return Cost == 0 ? 0 : 1;
2753 return Cost;
2754 };
2755
2756 EVT SrcTy = TLI->getValueType(DL, Src);
2757 EVT DstTy = TLI->getValueType(DL, Dst);
2758
2759 if (!SrcTy.isSimple() || !DstTy.isSimple())
2760 return AdjustCost(
2761 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
2762
2763 static const TypeConversionCostTblEntry ConversionTbl[] = {
2764 {ISD::TRUNCATE, MVT::v2i8, MVT::v2i64, 1}, // xtn
2765 {ISD::TRUNCATE, MVT::v2i16, MVT::v2i64, 1}, // xtn
2766 {ISD::TRUNCATE, MVT::v2i32, MVT::v2i64, 1}, // xtn
2767 {ISD::TRUNCATE, MVT::v4i8, MVT::v4i32, 1}, // xtn
2768 {ISD::TRUNCATE, MVT::v4i8, MVT::v4i64, 3}, // 2 xtn + 1 uzp1
2769 {ISD::TRUNCATE, MVT::v4i16, MVT::v4i32, 1}, // xtn
2770 {ISD::TRUNCATE, MVT::v4i16, MVT::v4i64, 2}, // 1 uzp1 + 1 xtn
2771 {ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 1}, // 1 uzp1
2772 {ISD::TRUNCATE, MVT::v8i8, MVT::v8i16, 1}, // 1 xtn
2773 {ISD::TRUNCATE, MVT::v8i8, MVT::v8i32, 2}, // 1 uzp1 + 1 xtn
2774 {ISD::TRUNCATE, MVT::v8i8, MVT::v8i64, 4}, // 3 x uzp1 + xtn
2775 {ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 1}, // 1 uzp1
2776 {ISD::TRUNCATE, MVT::v8i16, MVT::v8i64, 3}, // 3 x uzp1
2777 {ISD::TRUNCATE, MVT::v8i32, MVT::v8i64, 2}, // 2 x uzp1
2778 {ISD::TRUNCATE, MVT::v16i8, MVT::v16i16, 1}, // uzp1
2779 {ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 3}, // (2 + 1) x uzp1
2780 {ISD::TRUNCATE, MVT::v16i8, MVT::v16i64, 7}, // (4 + 2 + 1) x uzp1
2781 {ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2}, // 2 x uzp1
2782 {ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6}, // (4 + 2) x uzp1
2783 {ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4}, // 4 x uzp1
2784
2785 // Truncations on nxvmiN
2786 {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i8, 2},
2787 {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 2},
2788 {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 2},
2789 {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 2},
2790 {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i8, 2},
2791 {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 2},
2792 {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 2},
2793 {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 5},
2794 {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i8, 2},
2795 {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 2},
2796 {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 5},
2797 {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 11},
2798 {ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 2},
2799 {ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i16, 0},
2800 {ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i32, 0},
2801 {ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i64, 0},
2802 {ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 0},
2803 {ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i64, 0},
2804 {ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 0},
2805 {ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i16, 0},
2806 {ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i32, 0},
2807 {ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i64, 1},
2808 {ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 0},
2809 {ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i64, 1},
2810 {ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 1},
2811 {ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i16, 0},
2812 {ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i32, 1},
2813 {ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i64, 3},
2814 {ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 1},
2815 {ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i64, 3},
2816 {ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i16, 1},
2817 {ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i32, 3},
2818 {ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i64, 7},
2819
2820 // The number of shll instructions for the extension.
2821 {ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3},
2822 {ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3},
2823 {ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 2},
2824 {ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 2},
2825 {ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i8, 3},
2826 {ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i8, 3},
2827 {ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 2},
2828 {ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 2},
2829 {ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i8, 7},
2830 {ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i8, 7},
2831 {ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i16, 6},
2832 {ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i16, 6},
2833 {ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2},
2834 {ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2},
2835 {ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6},
2836 {ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6},
2837
2838 // FP Ext and trunc
2839 {ISD::FP_EXTEND, MVT::f64, MVT::f32, 1}, // fcvt
2840 {ISD::FP_EXTEND, MVT::v2f64, MVT::v2f32, 1}, // fcvtl
2841 {ISD::FP_EXTEND, MVT::v4f64, MVT::v4f32, 2}, // fcvtl+fcvtl2
2842 // FP16
2843 {ISD::FP_EXTEND, MVT::f32, MVT::f16, 1}, // fcvt
2844 {ISD::FP_EXTEND, MVT::f64, MVT::f16, 1}, // fcvt
2845 {ISD::FP_EXTEND, MVT::v4f32, MVT::v4f16, 1}, // fcvtl
2846 {ISD::FP_EXTEND, MVT::v8f32, MVT::v8f16, 2}, // fcvtl+fcvtl2
2847 {ISD::FP_EXTEND, MVT::v2f64, MVT::v2f16, 2}, // fcvtl+fcvtl
2848 {ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, 3}, // fcvtl+fcvtl2+fcvtl
2849 {ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, 6}, // 2 * fcvtl+fcvtl2+fcvtl
2850 // FP Ext and trunc
2851 {ISD::FP_ROUND, MVT::f32, MVT::f64, 1}, // fcvt
2852 {ISD::FP_ROUND, MVT::v2f32, MVT::v2f64, 1}, // fcvtn
2853 {ISD::FP_ROUND, MVT::v4f32, MVT::v4f64, 2}, // fcvtn+fcvtn2
2854 // FP16
2855 {ISD::FP_ROUND, MVT::f16, MVT::f32, 1}, // fcvt
2856 {ISD::FP_ROUND, MVT::f16, MVT::f64, 1}, // fcvt
2857 {ISD::FP_ROUND, MVT::v4f16, MVT::v4f32, 1}, // fcvtn
2858 {ISD::FP_ROUND, MVT::v8f16, MVT::v8f32, 2}, // fcvtn+fcvtn2
2859 {ISD::FP_ROUND, MVT::v2f16, MVT::v2f64, 2}, // fcvtn+fcvtn
2860 {ISD::FP_ROUND, MVT::v4f16, MVT::v4f64, 3}, // fcvtn+fcvtn2+fcvtn
2861 {ISD::FP_ROUND, MVT::v8f16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+fcvtn
2862
2863 // LowerVectorINT_TO_FP:
2864 {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1},
2865 {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1},
2866 {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1},
2867 {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1},
2868 {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1},
2869 {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1},
2870
2871 // Complex: to v2f32
2872 {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8, 3},
2873 {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3},
2874 {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2},
2875 {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8, 3},
2876 {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3},
2877 {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2},
2878
2879 // Complex: to v4f32
2880 {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8, 4},
2881 {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2},
2882 {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8, 3},
2883 {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2},
2884
2885 // Complex: to v8f32
2886 {ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 10},
2887 {ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4},
2888 {ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8, 10},
2889 {ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4},
2890
2891 // Complex: to v16f32
2892 {ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21},
2893 {ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21},
2894
2895 // Complex: to v2f64
2896 {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8, 4},
2897 {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4},
2898 {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2},
2899 {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8, 4},
2900 {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4},
2901 {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2},
2902
2903 // Complex: to v4f64
2904 {ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32, 4},
2905 {ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32, 4},
2906
2907 // LowerVectorFP_TO_INT
2908 {ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1},
2909 {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1},
2910 {ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1},
2911 {ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1},
2912 {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1},
2913 {ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1},
2914
2915 // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
2916 {ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2},
2917 {ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1},
2918 {ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f32, 1},
2919 {ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2},
2920 {ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1},
2921 {ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f32, 1},
2922
2923 // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
2924 {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2},
2925 {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f32, 2},
2926 {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2},
2927 {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f32, 2},
2928
2929 // Complex, from nxv2f32.
2930 {ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1},
2931 {ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1},
2932 {ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1},
2933 {ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f32, 1},
2934 {ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1},
2935 {ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1},
2936 {ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1},
2937 {ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f32, 1},
2938
2939 // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
2940 {ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2},
2941 {ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2},
2942 {ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f64, 2},
2943 {ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2},
2944 {ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2},
2945 {ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f64, 2},
2946
2947 // Complex, from nxv2f64.
2948 {ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1},
2949 {ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1},
2950 {ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1},
2951 {ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f64, 1},
2952 {ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1},
2953 {ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1},
2954 {ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1},
2955 {ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f64, 1},
2956
2957 // Complex, from nxv4f32.
2958 {ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4},
2959 {ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1},
2960 {ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1},
2961 {ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f32, 1},
2962 {ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4},
2963 {ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1},
2964 {ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1},
2965 {ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f32, 1},
2966
2967 // Complex, from nxv8f64. Illegal -> illegal conversions not required.
2968 {ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7},
2969 {ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f64, 7},
2970 {ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7},
2971 {ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f64, 7},
2972
2973 // Complex, from nxv4f64. Illegal -> illegal conversions not required.
2974 {ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3},
2975 {ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3},
2976 {ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f64, 3},
2977 {ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3},
2978 {ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3},
2979 {ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f64, 3},
2980
2981 // Complex, from nxv8f32. Illegal -> illegal conversions not required.
2982 {ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3},
2983 {ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f32, 3},
2984 {ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3},
2985 {ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f32, 3},
2986
2987 // Complex, from nxv8f16.
2988 {ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10},
2989 {ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4},
2990 {ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1},
2991 {ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f16, 1},
2992 {ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10},
2993 {ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4},
2994 {ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1},
2995 {ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f16, 1},
2996
2997 // Complex, from nxv4f16.
2998 {ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4},
2999 {ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1},
3000 {ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1},
3001 {ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f16, 1},
3002 {ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4},
3003 {ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1},
3004 {ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1},
3005 {ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f16, 1},
3006
3007 // Complex, from nxv2f16.
3008 {ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1},
3009 {ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1},
3010 {ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1},
3011 {ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f16, 1},
3012 {ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1},
3013 {ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1},
3014 {ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1},
3015 {ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f16, 1},
3016
3017 // Truncate from nxvmf32 to nxvmf16.
3018 {ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1},
3019 {ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1},
3020 {ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3},
3021
3022 // Truncate from nxvmf64 to nxvmf16.
3023 {ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1},
3024 {ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3},
3025 {ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7},
3026
3027 // Truncate from nxvmf64 to nxvmf32.
3028 {ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1},
3029 {ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3},
3030 {ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6},
3031
3032 // Extend from nxvmf16 to nxvmf32.
3033 {ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
3034 {ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
3035 {ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
3036
3037 // Extend from nxvmf16 to nxvmf64.
3038 {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
3039 {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
3040 {ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
3041
3042 // Extend from nxvmf32 to nxvmf64.
3043 {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
3044 {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
3045 {ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
3046
3047 // Bitcasts from float to integer
3048 {ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0},
3049 {ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0},
3050 {ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0},
3051
3052 // Bitcasts from integer to float
3053 {ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0},
3054 {ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0},
3055 {ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0},
3056
3057 // Add cost for extending to illegal -too wide- scalable vectors.
3058 // zero/sign extend are implemented by multiple unpack operations,
3059 // where each operation has a cost of 1.
3060 {ISD::ZERO_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
3061 {ISD::ZERO_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
3062 {ISD::ZERO_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
3063 {ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
3064 {ISD::ZERO_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
3065 {ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
3066
3067 {ISD::SIGN_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
3068 {ISD::SIGN_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
3069 {ISD::SIGN_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
3070 {ISD::SIGN_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
3071 {ISD::SIGN_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
3072 {ISD::SIGN_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
3073 };
3074
3075  // When a fixed-length operation has to be performed on SVE registers, we
3076  // estimate its cost using the number of SVE registers required to
3077  // represent the fixed-length type.
3078 EVT WiderTy = SrcTy.bitsGT(DstTy) ? SrcTy : DstTy;
3079 if (SrcTy.isFixedLengthVector() && DstTy.isFixedLengthVector() &&
3080 SrcTy.getVectorNumElements() == DstTy.getVectorNumElements() &&
3081 ST->useSVEForFixedLengthVectors(WiderTy)) {
3082 std::pair<InstructionCost, MVT> LT =
3083 getTypeLegalizationCost(WiderTy.getTypeForEVT(Dst->getContext()));
3084 unsigned NumElements =
3085 AArch64::SVEBitsPerBlock / LT.second.getScalarSizeInBits();
3086 return AdjustCost(
3087        LT.first *
3088            getCastInstrCost(
3089 Opcode, ScalableVectorType::get(Dst->getScalarType(), NumElements),
3090 ScalableVectorType::get(Src->getScalarType(), NumElements), CCH,
3091 CostKind, I));
3092 }
3093
3094 if (const auto *Entry = ConvertCostTableLookup(
3095 ConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
3096 return AdjustCost(Entry->Cost);
3097
3098 static const TypeConversionCostTblEntry FP16Tbl[] = {
3099 {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs
3100 {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1},
3101 {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs
3102 {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1},
3103 {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs
3104 {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2},
3105 {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn
3106 {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2},
3107 {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs
3108 {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1},
3109 {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs
3110 {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4},
3111 {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn
3112 {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3},
3113 {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs
3114 {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2},
3115 {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs
3116 {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8},
3117 {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // ushll + ucvtf
3118 {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // sshll + scvtf
3119 {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf
3120 {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf
3121 };
3122
3123 if (ST->hasFullFP16())
3124 if (const auto *Entry = ConvertCostTableLookup(
3125 FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
3126 return AdjustCost(Entry->Cost);
3127
3128  if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
3129      CCH == TTI::CastContextHint::Masked &&
3130      ST->isSVEorStreamingSVEAvailable() &&
3131      TLI->getTypeAction(Src->getContext(), SrcTy) ==
3132          TargetLowering::TypePromoteInteger &&
3133      TLI->getTypeAction(Dst->getContext(), DstTy) ==
3134          TargetLowering::TypeSplitVector) {
3135 // The standard behaviour in the backend for these cases is to split the
3136 // extend up into two parts:
3137 // 1. Perform an extending load or masked load up to the legal type.
3138 // 2. Extend the loaded data to the final type.
3139 std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src);
3140    Type *LegalTy = EVT(SrcLT.second).getTypeForEVT(Src->getContext());
3141    InstructionCost Part1 = AArch64TTIImpl::getCastInstrCost(
3142        Opcode, LegalTy, Src, CCH, CostKind, I);
3143    InstructionCost Part2 = AArch64TTIImpl::getCastInstrCost(
3144 Opcode, Dst, LegalTy, TTI::CastContextHint::None, CostKind, I);
3145 return Part1 + Part2;
3146 }
3147
3148 // The BasicTTIImpl version only deals with CCH==TTI::CastContextHint::Normal,
3149 // but we also want to include the TTI::CastContextHint::Masked case too.
3150  if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
3151      CCH == TTI::CastContextHint::Masked &&
3152      ST->isSVEorStreamingSVEAvailable() && TLI->isTypeLegal(DstTy))
3153    CCH = TTI::CastContextHint::Normal;
3154
3155 return AdjustCost(
3156 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
3157}
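// Illustrative example of how the conversion tables above are consulted: an
// fptosi from <4 x float> to <4 x i16> maps to ISD::FP_TO_SINT with
// Dst = v4i16 and Src = v4f32, so ConvertCostTableLookup returns the entry
// with cost 2 (roughly an fcvtzs plus a narrowing xtn). Costs like these can
// usually be inspected with the cost-model printer, e.g.
//   opt -mtriple=aarch64 -passes='print<cost-model>' -disable-output test.ll
// (the exact flags and output format depend on the LLVM version).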
3158
3159InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
3160 Type *Dst,
3161 VectorType *VecTy,
3162 unsigned Index) {
3163
3164 // Make sure we were given a valid extend opcode.
3165 assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
3166 "Invalid opcode");
3167
3168 // We are extending an element we extract from a vector, so the source type
3169 // of the extend is the element type of the vector.
3170 auto *Src = VecTy->getElementType();
3171
3172 // Sign- and zero-extends are for integer types only.
3173 assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
3174
3175 // Get the cost for the extract. We compute the cost (if any) for the extend
3176  // below.
3177  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
3178 InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy,
3179 CostKind, Index, nullptr, nullptr);
3180
3181 // Legalize the types.
3182 auto VecLT = getTypeLegalizationCost(VecTy);
3183 auto DstVT = TLI->getValueType(DL, Dst);
3184 auto SrcVT = TLI->getValueType(DL, Src);
3185
3186 // If the resulting type is still a vector and the destination type is legal,
3187 // we may get the extension for free. If not, get the default cost for the
3188 // extend.
3189 if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
3190 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
3191 CostKind);
3192
3193 // The destination type should be larger than the element type. If not, get
3194 // the default cost for the extend.
3195 if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
3196 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
3197 CostKind);
3198
3199 switch (Opcode) {
3200 default:
3201 llvm_unreachable("Opcode should be either SExt or ZExt");
3202
3203 // For sign-extends, we only need a smov, which performs the extension
3204 // automatically.
3205 case Instruction::SExt:
3206 return Cost;
3207
3208 // For zero-extends, the extend is performed automatically by a umov unless
3209 // the destination type is i64 and the element type is i8 or i16.
3210 case Instruction::ZExt:
3211 if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
3212 return Cost;
3213 }
3214
3215 // If we are unable to perform the extend for free, get the default cost.
3216 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
3217 CostKind);
3218}
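// Example (illustrative): for
//   %e = extractelement <4 x i32> %v, i32 1
//   %s = sext i32 %e to i64
// the sign extension folds into the lane move (smov), so only the extract is
// charged. A zext of an i8 or i16 lane to i64, by contrast, is charged the
// extract plus the default cast cost, per the ZExt case above.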
3219
3220InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
3221                                               TTI::TargetCostKind CostKind,
3222 const Instruction *I) {
3223  if (CostKind != TTI::TCK_RecipThroughput)
3224    return Opcode == Instruction::PHI ? 0 : 1;
3225 assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
3226 // Branches are assumed to be predicted.
3227 return 0;
3228}
3229
3230InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
3231 unsigned Opcode, Type *Val, unsigned Index, bool HasRealUse,
3232 const Instruction *I, Value *Scalar,
3233 ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
3234 assert(Val->isVectorTy() && "This must be a vector type");
3235
3236 if (Index != -1U) {
3237 // Legalize the type.
3238 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
3239
3240 // This type is legalized to a scalar type.
3241 if (!LT.second.isVector())
3242 return 0;
3243
3244 // The type may be split. For fixed-width vectors we can normalize the
3245 // index to the new type.
3246 if (LT.second.isFixedLengthVector()) {
3247 unsigned Width = LT.second.getVectorNumElements();
3248 Index = Index % Width;
3249 }
3250
3251 // The element at index zero is already inside the vector.
3252 // - For a physical (HasRealUse==true) insert-element or extract-element
3253 // instruction that extracts integers, an explicit FPR -> GPR move is
3254 // needed. So it has non-zero cost.
3255 // - For the rest of cases (virtual instruction or element type is float),
3256 // consider the instruction free.
3257 if (Index == 0 && (!HasRealUse || !Val->getScalarType()->isIntegerTy()))
3258 return 0;
3259
3260    // This is recognising an LD1 single-element structure to one lane of one
3261    // register instruction. I.e., if this is an `insertelement` instruction,
3262    // and its second operand is a load, then we will generate an LD1, which
3263    // is a comparatively expensive instruction.
3264 if (I && dyn_cast<LoadInst>(I->getOperand(1)))
3265 return ST->getVectorInsertExtractBaseCost() + 1;
3266
3267    // i1 inserts and extracts will include an extra cset or cmp of the vector
3268    // value. Increase the cost by 1 to account for this.
3269 if (Val->getScalarSizeInBits() == 1)
3270 return ST->getVectorInsertExtractBaseCost() + 1;
3271
3272 // FIXME:
3273 // If the extract-element and insert-element instructions could be
3274 // simplified away (e.g., could be combined into users by looking at use-def
3275 // context), they have no cost. This is not done in the first place for
3276 // compile-time considerations.
3277 }
3278
3279  // For NEON, if there exists an extractelement from lane != 0 such that
3280 // 1. extractelement does not necessitate a move from vector_reg -> GPR.
3281 // 2. extractelement result feeds into fmul.
3282 // 3. Other operand of fmul is an extractelement from lane 0 or lane
3283 // equivalent to 0.
3284 // then the extractelement can be merged with fmul in the backend and it
3285 // incurs no cost.
3286 // e.g.
3287 // define double @foo(<2 x double> %a) {
3288 // %1 = extractelement <2 x double> %a, i32 0
3289 // %2 = extractelement <2 x double> %a, i32 1
3290 // %res = fmul double %1, %2
3291 // ret double %res
3292 // }
3293 // %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
3294 auto ExtractCanFuseWithFmul = [&]() {
3295 // We bail out if the extract is from lane 0.
3296 if (Index == 0)
3297 return false;
3298
3299 // Check if the scalar element type of the vector operand of ExtractElement
3300 // instruction is one of the allowed types.
3301 auto IsAllowedScalarTy = [&](const Type *T) {
3302 return T->isFloatTy() || T->isDoubleTy() ||
3303 (T->isHalfTy() && ST->hasFullFP16());
3304 };
3305
3306 // Check if the extractelement user is scalar fmul.
3307 auto IsUserFMulScalarTy = [](const Value *EEUser) {
3308 // Check if the user is scalar fmul.
3309 const auto *BO = dyn_cast<BinaryOperator>(EEUser);
3310 return BO && BO->getOpcode() == BinaryOperator::FMul &&
3311 !BO->getType()->isVectorTy();
3312 };
3313
3314 // Check if the extract index is from lane 0 or lane equivalent to 0 for a
3315 // certain scalar type and a certain vector register width.
3316 auto IsExtractLaneEquivalentToZero = [&](unsigned Idx, unsigned EltSz) {
3317      auto RegWidth =
3318          getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
3319 .getFixedValue();
3320 return Idx == 0 || (RegWidth != 0 && (Idx * EltSz) % RegWidth == 0);
3321 };
3322
3323 // Check if the type constraints on input vector type and result scalar type
3324 // of extractelement instruction are satisfied.
3325 if (!isa<FixedVectorType>(Val) || !IsAllowedScalarTy(Val->getScalarType()))
3326 return false;
3327
3328 if (Scalar) {
3329 DenseMap<User *, unsigned> UserToExtractIdx;
3330 for (auto *U : Scalar->users()) {
3331 if (!IsUserFMulScalarTy(U))
3332 return false;
3333 // Recording entry for the user is important. Index value is not
3334 // important.
3335 UserToExtractIdx[U];
3336 }
3337 if (UserToExtractIdx.empty())
3338 return false;
3339 for (auto &[S, U, L] : ScalarUserAndIdx) {
3340 for (auto *U : S->users()) {
3341 if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
3342 auto *FMul = cast<BinaryOperator>(U);
3343 auto *Op0 = FMul->getOperand(0);
3344 auto *Op1 = FMul->getOperand(1);
3345 if ((Op0 == S && Op1 == S) || Op0 != S || Op1 != S) {
3346 UserToExtractIdx[U] = L;
3347 break;
3348 }
3349 }
3350 }
3351 }
3352 for (auto &[U, L] : UserToExtractIdx) {
3353 if (!IsExtractLaneEquivalentToZero(Index, Val->getScalarSizeInBits()) &&
3354 !IsExtractLaneEquivalentToZero(L, Val->getScalarSizeInBits()))
3355 return false;
3356 }
3357 } else {
3358 const auto *EE = cast<ExtractElementInst>(I);
3359
3360 const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
3361 if (!IdxOp)
3362 return false;
3363
3364 return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
3365 if (!IsUserFMulScalarTy(U))
3366 return false;
3367
3368 // Check if the other operand of extractelement is also extractelement
3369 // from lane equivalent to 0.
3370 const auto *BO = cast<BinaryOperator>(U);
3371 const auto *OtherEE = dyn_cast<ExtractElementInst>(
3372 BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
3373 if (OtherEE) {
3374 const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
3375 if (!IdxOp)
3376 return false;
3377 return IsExtractLaneEquivalentToZero(
3378 cast<ConstantInt>(OtherEE->getIndexOperand())
3379 ->getValue()
3380 .getZExtValue(),
3381 OtherEE->getType()->getScalarSizeInBits());
3382 }
3383 return true;
3384 });
3385 }
3386 return true;
3387 };
3388
3389 if (Opcode == Instruction::ExtractElement && (I || Scalar) &&
3390 ExtractCanFuseWithFmul())
3391 return 0;
3392
3393 // All other insert/extracts cost this much.
3394 return ST->getVectorInsertExtractBaseCost();
3395}
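// Illustrative consequences of the rules above: extracting lane 0 of a
// <2 x double> is free, and extracting lane 1 directly feeding a scalar fmul
// (as in the IR example above) is also free because it can fuse into
// fmul d0, d0, v1.d[1]. Extracting lane 0 of a <4 x i32> with a real use
// still pays the base cost, since it requires an FPR -> GPR move.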
3396
3397InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
3398                                                   TTI::TargetCostKind CostKind,
3399 unsigned Index, Value *Op0,
3400 Value *Op1) {
3401 bool HasRealUse =
3402 Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
3403 return getVectorInstrCostHelper(Opcode, Val, Index, HasRealUse);
3404}
3405
3406InstructionCost AArch64TTIImpl::getVectorInstrCost(
3407 unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3408 Value *Scalar,
3409 ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
3410 return getVectorInstrCostHelper(Opcode, Val, Index, false, nullptr, Scalar,
3411 ScalarUserAndIdx);
3412}
3413
3414InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
3415                                                   Type *Val,
3416                                                   TTI::TargetCostKind CostKind,
3417 unsigned Index) {
3418 return getVectorInstrCostHelper(I.getOpcode(), Val, Index,
3419 true /* HasRealUse */, &I);
3420}
3421
3422InstructionCost AArch64TTIImpl::getScalarizationOverhead(
3423 VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
3425  if (isa<ScalableVectorType>(Ty))
3426    return InstructionCost::getInvalid();
3427 if (Ty->getElementType()->isFloatingPointTy())
3428 return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
3429 CostKind);
3430  return DemandedElts.popcount() * (Insert + Extract) *
3431         ST->getVectorInsertExtractBaseCost();
3432}
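// Worked example (illustrative): scalarizing all four lanes of a <4 x i32>
// for both insertion and extraction costs
//   popcount(DemandedElts) * (Insert + Extract) * base = 4 * 2 * base,
// where base is ST->getVectorInsertExtractBaseCost(). Floating-point element
// types instead take the BasicTTIImpl path above.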
3433
3434InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
3435    unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
3436    TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
3437    ArrayRef<const Value *> Args,
3438 const Instruction *CxtI) {
3439
3440 // The code-generator is currently not able to handle scalable vectors
3441 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3442 // it. This change will be removed when code-generation for these types is
3443 // sufficiently reliable.
3444 if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
3445    if (VTy->getElementCount() == ElementCount::getScalable(1))
3446      return InstructionCost::getInvalid();
3447
3448  // TODO: Handle more cost kinds.
3449  if (CostKind != TTI::TCK_RecipThroughput)
3450 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
3451 Op2Info, Args, CxtI);
3452
3453 // Legalize the type.
3454 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
3455 int ISD = TLI->InstructionOpcodeToISD(Opcode);
3456
3457 switch (ISD) {
3458 default:
3459 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
3460 Op2Info);
3461 case ISD::SDIV:
3462 if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
3463 // On AArch64, scalar signed division by constants power-of-two are
3464 // normally expanded to the sequence ADD + CMP + SELECT + SRA.
3465      // The OperandValue properties may not be the same as those of the
3466      // previous operation; conservatively assume OP_None.
3467      InstructionCost Cost = getArithmeticInstrCost(
3468 Instruction::Add, Ty, CostKind,
3469 Op1Info.getNoProps(), Op2Info.getNoProps());
3470 Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
3471                                     Op1Info.getNoProps(), Op2Info.getNoProps());
3472      Cost += getArithmeticInstrCost(
3473 Instruction::Select, Ty, CostKind,
3474 Op1Info.getNoProps(), Op2Info.getNoProps());
3475 Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
3476 Op1Info.getNoProps(), Op2Info.getNoProps());
3477 return Cost;
3478 }
3479 [[fallthrough]];
3480 case ISD::UDIV: {
3481 auto VT = TLI->getValueType(DL, Ty);
3482 if (Op2Info.isConstant() && Op2Info.isUniform()) {
3483 if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
3484        // Vector signed division by a constant is expanded to the
3485        // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
3486        // to MULHS + SUB + SRL + ADD + SRL.
3487        InstructionCost MulCost = getArithmeticInstrCost(
3488            Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
3489        InstructionCost AddCost = getArithmeticInstrCost(
3490            Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
3491        InstructionCost ShrCost = getArithmeticInstrCost(
3492 Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
3493 return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
3494 }
3495 }
3496
3497 // div i128's are lowered as libcalls. Pass nullptr as (u)divti3 calls are
3498 // emitted by the backend even when those functions are not declared in the
3499 // module.
3500 if (!VT.isVector() && VT.getSizeInBits() > 64)
3501 return getCallInstrCost(/*Function*/ nullptr, Ty, {Ty, Ty}, CostKind);
3502
3503    InstructionCost Cost = BaseT::getArithmeticInstrCost(
3504 Opcode, Ty, CostKind, Op1Info, Op2Info);
3505 if (Ty->isVectorTy()) {
3506 if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) {
3507        // If SDIV/UDIV operations are lowered using SVE, the cost is
3508        // lower.
3509 if (isa<FixedVectorType>(Ty) && cast<FixedVectorType>(Ty)
3510 ->getPrimitiveSizeInBits()
3511 .getFixedValue() < 128) {
3512 EVT VT = TLI->getValueType(DL, Ty);
3513 static const CostTblEntry DivTbl[]{
3514 {ISD::SDIV, MVT::v2i8, 5}, {ISD::SDIV, MVT::v4i8, 8},
3515 {ISD::SDIV, MVT::v8i8, 8}, {ISD::SDIV, MVT::v2i16, 5},
3516 {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1},
3517 {ISD::UDIV, MVT::v2i8, 5}, {ISD::UDIV, MVT::v4i8, 8},
3518 {ISD::UDIV, MVT::v8i8, 8}, {ISD::UDIV, MVT::v2i16, 5},
3519 {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}};
3520
3521 const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT());
3522 if (nullptr != Entry)
3523 return Entry->Cost;
3524 }
3525 // For 8/16-bit elements, the cost is higher because the type
3526 // requires promotion and possibly splitting:
3527 if (LT.second.getScalarType() == MVT::i8)
3528 Cost *= 8;
3529 else if (LT.second.getScalarType() == MVT::i16)
3530 Cost *= 4;
3531 return Cost;
3532 } else {
3533 // If one of the operands is a uniform constant then the cost for each
3534 // element is Cost for insertion, extraction and division.
3535 // Insertion cost = 2, Extraction Cost = 2, Division = cost for the
3536 // operation with scalar type
3537 if ((Op1Info.isConstant() && Op1Info.isUniform()) ||
3538 (Op2Info.isConstant() && Op2Info.isUniform())) {
3539        if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) {
3540          InstructionCost DivCost = BaseT::getArithmeticInstrCost(
3541 Opcode, Ty->getScalarType(), CostKind, Op1Info, Op2Info);
3542 return (4 + DivCost) * VTy->getNumElements();
3543 }
3544 }
3545 // On AArch64, without SVE, vector divisions are expanded
3546 // into scalar divisions of each pair of elements.
3547 Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty,
3548 CostKind, Op1Info, Op2Info);
3549 Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
3550 Op1Info, Op2Info);
3551 }
3552
3553 // TODO: if one of the arguments is scalar, then it's not necessary to
3554 // double the cost of handling the vector elements.
3555 Cost += Cost;
3556 }
3557 return Cost;
3558 }
3559 case ISD::MUL:
3560 // When SVE is available, then we can lower the v2i64 operation using
3561 // the SVE mul instruction, which has a lower cost.
3562 if (LT.second == MVT::v2i64 && ST->hasSVE())
3563 return LT.first;
3564
3565 // When SVE is not available, there is no MUL.2d instruction,
3566 // which means mul <2 x i64> is expensive as elements are extracted
3567 // from the vectors and the muls scalarized.
3568 // As getScalarizationOverhead is a bit too pessimistic, we
3569 // estimate the cost for a i64 vector directly here, which is:
3570 // - four 2-cost i64 extracts,
3571 // - two 2-cost i64 inserts, and
3572 // - two 1-cost muls.
3573 // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
3574 // LT.first = 2 the cost is 28. If both operands are extensions it will not
3575 // need to scalarize so the cost can be cheaper (smull or umull).
3577 if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
3578 return LT.first;
3579 return LT.first * 14;
3580 case ISD::ADD:
3581 case ISD::XOR:
3582 case ISD::OR:
3583 case ISD::AND:
3584 case ISD::SRL:
3585 case ISD::SRA:
3586 case ISD::SHL:
3587 // These nodes are marked as 'custom' for combining purposes only.
3588 // We know that they are legal. See LowerAdd in ISelLowering.
3589 return LT.first;
3590
3591 case ISD::FNEG:
3592 // Scalar fmul(fneg) or fneg(fmul) can be converted to fnmul
3593 if ((Ty->isFloatTy() || Ty->isDoubleTy() ||
3594 (Ty->isHalfTy() && ST->hasFullFP16())) &&
3595 CxtI &&
3596 ((CxtI->hasOneUse() &&
3597 match(*CxtI->user_begin(), m_FMul(m_Value(), m_Value()))) ||
3598 match(CxtI->getOperand(0), m_FMul(m_Value(), m_Value()))))
3599 return 0;
3600 [[fallthrough]];
3601 case ISD::FADD:
3602 case ISD::FSUB:
3603 // Increase the cost for half and bfloat types if not architecturally
3604 // supported.
3605 if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
3606 (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
3607 return 2 * LT.first;
3608 if (!Ty->getScalarType()->isFP128Ty())
3609 return LT.first;
3610 [[fallthrough]];
3611 case ISD::FMUL:
3612 case ISD::FDIV:
3613 // These nodes are marked as 'custom' just to lower them to SVE.
3614 // We know said lowering will incur no additional cost.
3615 if (!Ty->getScalarType()->isFP128Ty())
3616 return 2 * LT.first;
3617
3618 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
3619 Op2Info);
3620 case ISD::FREM:
3621 // Pass nullptr as fmod/fmodf calls are emitted by the backend even when
3622 // those functions are not declared in the module.
3623 if (!Ty->isVectorTy())
3624 return getCallInstrCost(/*Function*/ nullptr, Ty, {Ty, Ty}, CostKind);
3625 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
3626 Op2Info);
3627 }
3628}
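// Illustrative examples of the cases handled above (assuming default
// subtarget settings):
//  - 'sdiv <4 x i32> %x, splat(4)' is costed as the ADD + SUB + SELECT + ASHR
//    expansion used for uniform power-of-two divisors.
//  - 'mul <2 x i64>' without SVE is costed 14 (extracts, inserts and two
//    scalar multiplies); with SVE, or when both operands are extended
//    (smull/umull), it is just LT.first.
//  - a scalar 'frem float' is costed as a call to fmodf.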
3629
3631 ScalarEvolution *SE,
3632 const SCEV *Ptr) {
3633 // Address computations in vectorized code with non-consecutive addresses will
3634 // likely result in more instructions compared to scalar code where the
3635 // computation can more often be merged into the index mode. The resulting
3636 // extra micro-ops can significantly decrease throughput.
3637 unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead;
3638 int MaxMergeDistance = 64;
3639
3640 if (Ty->isVectorTy() && SE &&
3641 !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
3642 return NumVectorInstToHideOverhead;
3643
3644 // In many cases the address computation is not merged into the instruction
3645 // addressing mode.
3646 return 1;
3647}
3648
3649InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
3650    unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
3651    TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
3652    TTI::OperandValueInfo Op2Info, const Instruction *I) {
3653  // TODO: Handle other cost kinds.
3654  if (CostKind != TTI::TCK_RecipThroughput)
3655 return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
3656 Op1Info, Op2Info, I);
3657
3658 int ISD = TLI->InstructionOpcodeToISD(Opcode);
3659  // We don't lower vector selects well when they are wider than the register
3660  // width.
3661 if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
3662 // We would need this many instructions to hide the scalarization happening.
3663 const int AmortizationCost = 20;
3664
3665 // If VecPred is not set, check if we can get a predicate from the context
3666 // instruction, if its type matches the requested ValTy.
3667 if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
3668 CmpPredicate CurrentPred;
3669 if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
3670 m_Value())))
3671 VecPred = CurrentPred;
3672 }
3673 // Check if we have a compare/select chain that can be lowered using
3674 // a (F)CMxx & BFI pair.
3675 if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
3676 VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
3677 VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
3678 VecPred == CmpInst::FCMP_UNE) {
3679 static const auto ValidMinMaxTys = {
3680 MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
3681 MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
3682 static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};
3683
3684 auto LT = getTypeLegalizationCost(ValTy);
3685 if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }) ||
3686 (ST->hasFullFP16() &&
3687 any_of(ValidFP16MinMaxTys, [&LT](MVT M) { return M == LT.second; })))
3688 return LT.first;
3689 }
3690
3691 static const TypeConversionCostTblEntry
3692 VectorSelectTbl[] = {
3693 { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
3694 { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
3695 { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
3696 { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
3697 { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
3698 { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
3699 { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
3700 { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
3701 { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
3702 { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
3703 { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
3704 };
3705
3706 EVT SelCondTy = TLI->getValueType(DL, CondTy);
3707 EVT SelValTy = TLI->getValueType(DL, ValTy);
3708 if (SelCondTy.isSimple() && SelValTy.isSimple()) {
3709 if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
3710 SelCondTy.getSimpleVT(),
3711 SelValTy.getSimpleVT()))
3712 return Entry->Cost;
3713 }
3714 }
3715
3716 if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
3717 auto LT = getTypeLegalizationCost(ValTy);
3718 // Cost v4f16 FCmp without FP16 support via converting to v4f32 and back.
3719 if (LT.second == MVT::v4f16 && !ST->hasFullFP16())
3720 return LT.first * 4; // fcvtl + fcvtl + fcmp + xtn
3721 }
3722
3723 // Treat the icmp in icmp(and, 0) as free, as we can make use of ands.
3724 // FIXME: This can apply to more conditions and add/sub if it can be shown to
3725 // be profitable.
3726 if (ValTy->isIntegerTy() && ISD == ISD::SETCC && I &&
3727 ICmpInst::isEquality(VecPred) &&
3728 TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) &&
3729 match(I->getOperand(1), m_Zero()) &&
3730 match(I->getOperand(0), m_And(m_Value(), m_Value())))
3731 return 0;
3732
3733 // The base case handles scalable vectors fine for now, since it treats the
3734 // cost as 1 * legalization cost.
3735 return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
3736 Op1Info, Op2Info, I);
3737}
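// Illustrative: a 'select <4 x i1> %c, <4 x float> %a, <4 x float> %b' whose
// condition is not a recognisable compare falls through to the
// VectorSelectTbl entry {SELECT, v4i1, v4f32, 2}, while a scalar
// 'icmp eq (and %x, %y), 0' on a legal type is treated as free because it
// folds into an ANDS-style compare.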
3738
3739TTI::MemCmpExpansionOptions
3740AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
3741  TTI::MemCmpExpansionOptions Options;
3742 if (ST->requiresStrictAlign()) {
3743 // TODO: Add cost modeling for strict align. Misaligned loads expand to
3744 // a bunch of instructions when strict align is enabled.
3745 return Options;
3746 }
3747 Options.AllowOverlappingLoads = true;
3748 Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
3749 Options.NumLoadsPerBlock = Options.MaxNumLoads;
3750 // TODO: Though vector loads usually perform well on AArch64, in some targets
3751 // they may wake up the FP unit, which raises the power consumption. Perhaps
3752 // they could be used with no holds barred (-O3).
3753 Options.LoadSizes = {8, 4, 2, 1};
3754 Options.AllowedTailExpansions = {3, 5, 6};
3755 return Options;
3756}
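// Illustrative sketch of what these options allow: a 15-byte memcmp can be
// expanded inline as two overlapping 8-byte loads (AllowOverlappingLoads),
// and the AllowedTailExpansions entries {3, 5, 6} permit odd-sized tails
// (e.g. 3 bytes as a 2-byte plus a 1-byte access) instead of falling back to
// a libcall. The actual expansion is decided by the generic ExpandMemCmp
// logic using these options.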
3757
3758bool AArch64TTIImpl::prefersVectorizedAddressing() const {
3759 return ST->hasSVE();
3760}
3761
3762InstructionCost
3763AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
3764                                      Align Alignment, unsigned AddressSpace,
3765                                      TTI::TargetCostKind CostKind) {
3766 if (useNeonVector(Src))
3767 return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
3768 CostKind);
3769 auto LT = getTypeLegalizationCost(Src);
3770  if (!LT.first.isValid())
3771    return InstructionCost::getInvalid();
3772
3773 // Return an invalid cost for element types that we are unable to lower.
3774 auto *VT = cast<VectorType>(Src);
3775  if (VT->getElementType()->isIntegerTy(1))
3776    return InstructionCost::getInvalid();
3777
3778 // The code-generator is currently not able to handle scalable vectors
3779 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3780 // it. This change will be removed when code-generation for these types is
3781 // sufficiently reliable.
3782  if (VT->getElementCount() == ElementCount::getScalable(1))
3783    return InstructionCost::getInvalid();
3784
3785 return LT.first;
3786}
3787
3788// This function returns the gather/scatter overhead, either from the
3789// user-provided value or from the per-target specialized values in \p ST.
3790static unsigned getSVEGatherScatterOverhead(unsigned Opcode,
3791 const AArch64Subtarget *ST) {
3792 assert((Opcode == Instruction::Load || Opcode == Instruction::Store) &&
3793 "Should be called on only load or stores.");
3794 switch (Opcode) {
3795 case Instruction::Load:
3796 if (SVEGatherOverhead.getNumOccurrences() > 0)
3797 return SVEGatherOverhead;
3798 return ST->getGatherOverhead();
3799 break;
3800 case Instruction::Store:
3801 if (SVEScatterOverhead.getNumOccurrences() > 0)
3802 return SVEScatterOverhead;
3803 return ST->getScatterOverhead();
3804 break;
3805 default:
3806 llvm_unreachable("Shouldn't have reached here");
3807 }
3808}
3809
3810InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
3811 unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
3812 Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
3813 if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
3814 return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
3815 Alignment, CostKind, I);
3816 auto *VT = cast<VectorType>(DataTy);
3817 auto LT = getTypeLegalizationCost(DataTy);
3818  if (!LT.first.isValid())
3819    return InstructionCost::getInvalid();
3820
3821 // Return an invalid cost for element types that we are unable to lower.
3822 if (!LT.second.isVector() ||
3823 !isElementTypeLegalForScalableVector(VT->getElementType()) ||
3824      VT->getElementType()->isIntegerTy(1))
3825    return InstructionCost::getInvalid();
3826
3827 // The code-generator is currently not able to handle scalable vectors
3828 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3829 // it. This change will be removed when code-generation for these types is
3830 // sufficiently reliable.
3831  if (VT->getElementCount() == ElementCount::getScalable(1))
3832    return InstructionCost::getInvalid();
3833
3834 ElementCount LegalVF = LT.second.getVectorElementCount();
3835 InstructionCost MemOpCost =
3836 getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind,
3837 {TTI::OK_AnyValue, TTI::OP_None}, I);
3838 // Add on an overhead cost for using gathers/scatters.
3839 MemOpCost *= getSVEGatherScatterOverhead(Opcode, ST);
3840 return LT.first * MemOpCost * getMaxNumElements(LegalVF);
3841}
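// Worked example (illustrative): a masked gather of <vscale x 4 x i32>
// legalizes to nxv4i32, so the returned cost is roughly
//   LT.first * (per-element memory cost * gather overhead) *
//   getMaxNumElements(LegalVF),
// i.e. it scales with the maximum number of lanes the scalable type may hold
// rather than with a fixed vector width.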
3842
3843bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
3844 return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
3845}
3846
3847InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
3848 MaybeAlign Alignment,
3849                                                unsigned AddressSpace,
3850                                                TTI::TargetCostKind CostKind,
3851 TTI::OperandValueInfo OpInfo,
3852 const Instruction *I) {
3853 EVT VT = TLI->getValueType(DL, Ty, true);
3854 // Type legalization can't handle structs
3855 if (VT == MVT::Other)
3856 return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
3857 CostKind);
3858
3859 auto LT = getTypeLegalizationCost(Ty);
3860  if (!LT.first.isValid())
3861    return InstructionCost::getInvalid();
3862
3863 // The code-generator is currently not able to handle scalable vectors
3864 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3865 // it. This change will be removed when code-generation for these types is
3866 // sufficiently reliable.
3867 // We also only support full register predicate loads and stores.
3868 if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
3869 if (VTy->getElementCount() == ElementCount::getScalable(1) ||
3870 (VTy->getElementType()->isIntegerTy(1) &&
3871         !VTy->getElementCount().isKnownMultipleOf(
3872             ElementCount::getScalable(16))))
3873      return InstructionCost::getInvalid();
3874
3875  // TODO: consider latency as well for TCK_SizeAndLatency.
3876  if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
3877 return LT.first;
3878
3879  if (CostKind != TTI::TCK_RecipThroughput)
3880 return 1;
3881
3882 if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
3883 LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
3884    // Unaligned stores are extremely inefficient. We don't split all
3885    // unaligned 128-bit stores because of the negative impact that doing so
3886    // has shown in practice on inlined block copy code.
3887 // We make such stores expensive so that we will only vectorize if there
3888 // are 6 other instructions getting vectorized.
3889 const int AmortizationCost = 6;
3890
3891 return LT.first * 2 * AmortizationCost;
3892 }
3893
3894 // Opaque ptr or ptr vector types are i64s and can be lowered to STP/LDPs.
3895 if (Ty->isPtrOrPtrVectorTy())
3896 return LT.first;
3897
3898 if (useNeonVector(Ty)) {
3899 // Check truncating stores and extending loads.
3900 if (Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
3901      // v4i8 types are lowered to a scalar load/store and sshll/xtn.
3902 if (VT == MVT::v4i8)
3903 return 2;
3904 // Otherwise we need to scalarize.
3905 return cast<FixedVectorType>(Ty)->getNumElements() * 2;
3906 }
3907 EVT EltVT = VT.getVectorElementType();
3908 unsigned EltSize = EltVT.getScalarSizeInBits();
3909 if (!isPowerOf2_32(EltSize) || EltSize < 8 || EltSize > 64 ||
3910 VT.getVectorNumElements() >= (128 / EltSize) || !Alignment ||
3911 *Alignment != Align(1))
3912 return LT.first;
3913 // FIXME: v3i8 lowering currently is very inefficient, due to automatic
3914 // widening to v4i8, which produces suboptimal results.
3915 if (VT.getVectorNumElements() == 3 && EltVT == MVT::i8)
3916 return LT.first;
3917
3918 // Check non-power-of-2 loads/stores for legal vector element types with
3919 // NEON. Non-power-of-2 memory ops will get broken down to a set of
3920 // operations on smaller power-of-2 ops, including ld1/st1.
3921      LLVMContext &C = Ty->getContext();
3922      InstructionCost Cost = 0;
3923 SmallVector<EVT> TypeWorklist;
3924 TypeWorklist.push_back(VT);
3925 while (!TypeWorklist.empty()) {
3926 EVT CurrVT = TypeWorklist.pop_back_val();
3927 unsigned CurrNumElements = CurrVT.getVectorNumElements();
3928 if (isPowerOf2_32(CurrNumElements)) {
3929 Cost += 1;
3930 continue;
3931 }
3932
3933 unsigned PrevPow2 = NextPowerOf2(CurrNumElements) / 2;
3934 TypeWorklist.push_back(EVT::getVectorVT(C, EltVT, PrevPow2));
3935 TypeWorklist.push_back(
3936 EVT::getVectorVT(C, EltVT, CurrNumElements - PrevPow2));
3937 }
3938 return Cost;
3939 }
3940
3941 return LT.first;
3942}
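// Illustrative: the non-power-of-2 NEON path above only fires for small
// (sub-128-bit), align-1 vectors with legal element types. An align-1 store
// of <7 x i8> is decomposed as <4 x i8> + <2 x i8> + <1 x i8> and costed 3,
// and <3 x i16> becomes <2 x i16> + <1 x i16> with cost 2, matching the
// worklist that repeatedly peels off the largest power-of-two prefix.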
3943
3944InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
3945 unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
3946 Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
3947 bool UseMaskForCond, bool UseMaskForGaps) {
3948 assert(Factor >= 2 && "Invalid interleave factor");
3949 auto *VecVTy = cast<VectorType>(VecTy);
3950
3951  if (VecTy->isScalableTy() && !ST->hasSVE())
3952    return InstructionCost::getInvalid();
3953
3954 // Vectorization for masked interleaved accesses is only enabled for scalable
3955 // VF.
3956  if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps))
3957    return InstructionCost::getInvalid();
3958
3959 if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) {
3960 unsigned MinElts = VecVTy->getElementCount().getKnownMinValue();
3961 auto *SubVecTy =
3962 VectorType::get(VecVTy->getElementType(),
3963 VecVTy->getElementCount().divideCoefficientBy(Factor));
3964
3965 // ldN/stN only support legal vector types of size 64 or 128 in bits.
3966 // Accesses having vector types that are a multiple of 128 bits can be
3967 // matched to more than one ldN/stN instruction.
3968 bool UseScalable;
3969 if (MinElts % Factor == 0 &&
3970 TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
3971 return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
3972 }
3973
3974 return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
3975 Alignment, AddressSpace, CostKind,
3976 UseMaskForCond, UseMaskForGaps);
3977}
3978
3983 for (auto *I : Tys) {
3984 if (!I->isVectorTy())
3985 continue;
3986 if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
3987 128)
3988 Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
3989 getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
3990 }
3991 return Cost;
3992}
3993
3995 return ST->getMaxInterleaveFactor();
3996}
3997
3998// For Falkor, we want to avoid having too many strided loads in a loop since
3999// that can exhaust the HW prefetcher resources. We adjust the unroller
4000// MaxCount preference below to attempt to ensure unrolling doesn't create too
4001// many strided loads.
4002static void
4003getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
4004                              TargetTransformInfo::UnrollingPreferences &UP) {
4005 enum { MaxStridedLoads = 7 };
4006 auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
4007 int StridedLoads = 0;
4008 // FIXME? We could make this more precise by looking at the CFG and
4009 // e.g. not counting loads in each side of an if-then-else diamond.
4010 for (const auto BB : L->blocks()) {
4011 for (auto &I : *BB) {
4012 LoadInst *LMemI = dyn_cast<LoadInst>(&I);
4013 if (!LMemI)
4014 continue;
4015
4016 Value *PtrValue = LMemI->getPointerOperand();
4017 if (L->isLoopInvariant(PtrValue))
4018 continue;
4019
4020 const SCEV *LSCEV = SE.getSCEV(PtrValue);
4021 const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
4022 if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
4023 continue;
4024
4025 // FIXME? We could take pairing of unrolled load copies into account
4026 // by looking at the AddRec, but we would probably have to limit this
4027 // to loops with no stores or other memory optimization barriers.
4028 ++StridedLoads;
4029 // We've seen enough strided loads that seeing more won't make a
4030 // difference.
4031 if (StridedLoads > MaxStridedLoads / 2)
4032 return StridedLoads;
4033 }
4034 }
4035 return StridedLoads;
4036 };
4037
4038 int StridedLoads = countStridedLoads(L, SE);
4039 LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
4040 << " strided loads\n");
4041 // Pick the largest power of 2 unroll count that won't result in too many
4042 // strided loads.
4043 if (StridedLoads) {
4044 UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
4045 LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
4046 << UP.MaxCount << '\n');
4047 }
4048}
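// Worked example (illustrative): a loop body containing 3 strided loads gets
// UP.MaxCount = 1 << Log2_32(7 / 3) = 2, capping unrolling so that the total
// number of strided-load copies stays at or below MaxStridedLoads.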
4049
4050/// For Apple CPUs, we want to runtime-unroll loops to make better use of the
4051/// OOO engine's wide instruction window and various predictors.
4052static void
4053getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE,
4054                                 TargetTransformInfo::UnrollingPreferences &UP,
4055                                 AArch64TTIImpl &TTI) {
4056 // Limit loops with structure that is highly likely to benefit from runtime
4057 // unrolling; that is we exclude outer loops, loops with multiple exits and
4058 // many blocks (i.e. likely with complex control flow). Note that the
4059 // heuristics here may be overly conservative and we err on the side of
4060 // avoiding runtime unrolling rather than unroll excessively. They are all
4061 // subject to further refinement.
4062 if (!L->isInnermost() || !L->getExitBlock() || L->getNumBlocks() > 8)
4063 return;
4064
4065 const SCEV *BTC = SE.getBackedgeTakenCount(L);
4066 if (isa<SCEVConstant>(BTC) || isa<SCEVCouldNotCompute>(BTC) ||
4067 (SE.getSmallConstantMaxTripCount(L) > 0 &&
4068 SE.getSmallConstantMaxTripCount(L) <= 32))
4069 return;
4070 if (findStringMetadataForLoop(L, "llvm.loop.isvectorized"))
4071 return;
4072
4073 int64_t Size = 0;
4074 for (auto *BB : L->getBlocks()) {
4075 for (auto &I : *BB) {
4076 if (!isa<IntrinsicInst>(&I) && isa<CallBase>(&I))
4077 return;
4078 SmallVector<const Value *, 4> Operands(I.operand_values());
4079 Size +=
4081 }
4082 }
4083
4084 // Limit to loops with trip counts that are cheap to expand.
4085 UP.SCEVExpansionBudget = 1;
4086
4087 // Try to unroll small, single block loops, if they have load/store
4088 // dependencies, to expose more parallel memory access streams.
4089 BasicBlock *Header = L->getHeader();
4090 if (Header == L->getLoopLatch()) {
4091 if (Size > 8)
4092 return;
4093
4094    SmallPtrSet<Value *, 8> LoadedValues;
4095    SmallVector<StoreInst *> Stores;
4096 for (auto *BB : L->blocks()) {
4097      for (auto &I : *BB) {
4098        Value *Ptr = getLoadStorePointerOperand(&I);
4099 if (!Ptr)
4100 continue;
4101 const SCEV *PtrSCEV = SE.getSCEV(Ptr);
4102 if (SE.isLoopInvariant(PtrSCEV, L))
4103 continue;
4104 if (isa<LoadInst>(&I))
4105 LoadedValues.insert(&I);
4106 else
4107 Stores.push_back(cast<StoreInst>(&I));
4108 }
4109 }
4110
4111 // Try to find an unroll count that maximizes the use of the instruction
4112 // window, i.e. trying to fetch as many instructions per cycle as possible.
4113 unsigned MaxInstsPerLine = 16;
4114 unsigned UC = 1;
4115 unsigned BestUC = 1;
4116 unsigned SizeWithBestUC = BestUC * Size;
4117 while (UC <= 8) {
4118 unsigned SizeWithUC = UC * Size;
4119 if (SizeWithUC > 48)
4120 break;
4121 if ((SizeWithUC % MaxInstsPerLine) == 0 ||
4122 (SizeWithBestUC % MaxInstsPerLine) < (SizeWithUC % MaxInstsPerLine)) {
4123 BestUC = UC;
4124 SizeWithBestUC = BestUC * Size;
4125 }
4126 UC++;
4127 }
4128
4129 if (BestUC == 1 || none_of(Stores, [&LoadedValues](StoreInst *SI) {
4130 return LoadedValues.contains(SI->getOperand(0));
4131 }))
4132 return;
4133
4134 UP.Runtime = true;
4135 UP.DefaultUnrollRuntimeCount = BestUC;
4136 return;
4137 }
4138
4139 // Try to runtime-unroll loops with early-continues depending on loop-varying
4140 // loads; this helps with branch-prediction for the early-continues.
4141 auto *Term = dyn_cast<BranchInst>(Header->getTerminator());
4142  auto *Latch = L->getLoopLatch();
4143  SmallVector<BasicBlock *> Preds(predecessors(Latch));
4144 if (!Term || !Term->isConditional() || Preds.size() == 1 ||
4145 none_of(Preds, [Header](BasicBlock *Pred) { return Header == Pred; }) ||
4146 none_of(Preds, [L](BasicBlock *Pred) { return L->contains(Pred); }))
4147 return;
4148
4149 std::function<bool(Instruction *, unsigned)> DependsOnLoopLoad =
4150 [&](Instruction *I, unsigned Depth) -> bool {
4151 if (isa<PHINode>(I) || L->isLoopInvariant(I) || Depth > 8)
4152 return false;
4153
4154 if (isa<LoadInst>(I))
4155 return true;
4156
4157 return any_of(I->operands(), [&](Value *V) {
4158 auto *I = dyn_cast<Instruction>(V);
4159 return I && DependsOnLoopLoad(I, Depth + 1);
4160 });
4161 };
4162 CmpPredicate Pred;
4163 Instruction *I;
4164 if (match(Term, m_Br(m_ICmp(Pred, m_Instruction(I), m_Value()), m_Value(),
4165 m_Value())) &&
4166 DependsOnLoopLoad(I, 0)) {
4167 UP.Runtime = true;
4168 }
4169}
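// Worked example (illustrative) of the unroll-count search above: for a
// single-block loop body of Size = 6, the candidates UC = 1..8 give unrolled
// sizes 6, 12, ..., 48, and UC = 8 wins because 48 is an exact multiple of
// MaxInstsPerLine (16). The loop is then runtime-unrolled only if some store
// in the body stores a value produced by one of the loop-varying loads.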
4170
4171void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
4172                                             TTI::UnrollingPreferences &UP,
4173                                             OptimizationRemarkEmitter *ORE) {
4174 // Enable partial unrolling and runtime unrolling.
4175 BaseT::getUnrollingPreferences(L, SE, UP, ORE);
4176
4177 UP.UpperBound = true;
4178
4179  // Inner loops are more likely to be hot, and the runtime check can be
4180  // promoted out by the LICM pass, so the overhead is lower; try a larger
4181  // threshold to unroll more loops.
4182 if (L->getLoopDepth() > 1)
4183 UP.PartialThreshold *= 2;
4184
4185  // Disable partial & runtime unrolling on -Os.
4186  UP.PartialOptSizeThreshold = 0;
4187
4188 // Apply subtarget-specific unrolling preferences.
4189 switch (ST->getProcFamily()) {
4190 case AArch64Subtarget::AppleA14:
4191 case AArch64Subtarget::AppleA15:
4192 case AArch64Subtarget::AppleA16:
4193 case AArch64Subtarget::AppleM4:
4194 getAppleRuntimeUnrollPreferences(L, SE, UP, *this);
4195 break;
4196  case AArch64Subtarget::Falkor:
4197    if (EnableFalkorHWPFUnrollFix)
4198      getFalkorUnrollingPreferences(L, SE, UP);
4199 break;
4200 default:
4201 break;
4202 }
4203
4204 // Scan the loop: don't unroll loops with calls as this could prevent
4205 // inlining. Don't unroll vector loops either, as they don't benefit much from
4206 // unrolling.
4207 for (auto *BB : L->getBlocks()) {
4208 for (auto &I : *BB) {
4209 // Don't unroll vectorised loop.
4210 if (I.getType()->isVectorTy())
4211 return;
4212
4213 if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
4214 if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
4215 if (!isLoweredToCall(F))
4216 continue;
4217 }
4218 return;
4219 }
4220 }
4221 }
4222
4223 // Enable runtime unrolling for in-order models
4224 // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
4225 // checking for that case, we can ensure that the default behaviour is
4226 // unchanged
4228 !ST->getSchedModel().isOutOfOrder()) {
4229 UP.Runtime = true;
4230 UP.Partial = true;
4231    UP.UnrollRemainder = true;
4232    UP.DefaultUnrollRuntimeCount = 4;
4233
4234    UP.UnrollAndJam = true;
4235    UP.UnrollAndJamInnerLoopThreshold = 60;
4236 }
4237}
4238
4239void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
4240                                           TTI::PeelingPreferences &PP) {
4241  BaseT::getPeelingPreferences(L, SE, PP);
4242}
4243
4244Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
4245 Type *ExpectedType) {
4246 switch (Inst->getIntrinsicID()) {
4247 default:
4248 return nullptr;
4249 case Intrinsic::aarch64_neon_st2:
4250 case Intrinsic::aarch64_neon_st3:
4251 case Intrinsic::aarch64_neon_st4: {
4252 // Create a struct type
4253 StructType *ST = dyn_cast<StructType>(ExpectedType);
4254 if (!ST)
4255 return nullptr;
4256 unsigned NumElts = Inst->arg_size() - 1;
4257 if (ST->getNumElements() != NumElts)
4258 return nullptr;
4259 for (unsigned i = 0, e = NumElts; i != e; ++i) {
4260 if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
4261 return nullptr;
4262 }
4263 Value *Res = PoisonValue::get(ExpectedType);
4264 IRBuilder<> Builder(Inst);
4265 for (unsigned i = 0, e = NumElts; i != e; ++i) {
4266 Value *L = Inst->getArgOperand(i);
4267 Res = Builder.CreateInsertValue(Res, L, i);
4268 }
4269 return Res;
4270 }
4271 case Intrinsic::aarch64_neon_ld2:
4272 case Intrinsic::aarch64_neon_ld3:
4273 case Intrinsic::aarch64_neon_ld4:
4274 if (Inst->getType() == ExpectedType)
4275 return Inst;
4276 return nullptr;
4277 }
4278}
4279
4280bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
4281 MemIntrinsicInfo &Info) {
4282 switch (Inst->getIntrinsicID()) {
4283 default:
4284 break;
4285 case Intrinsic::aarch64_neon_ld2:
4286 case Intrinsic::aarch64_neon_ld3:
4287 case Intrinsic::aarch64_neon_ld4:
4288 Info.ReadMem = true;
4289 Info.WriteMem = false;
4290 Info.PtrVal = Inst->getArgOperand(0);
4291 break;
4292 case Intrinsic::aarch64_neon_st2:
4293 case Intrinsic::aarch64_neon_st3:
4294 case Intrinsic::aarch64_neon_st4:
4295 Info.ReadMem = false;
4296 Info.WriteMem = true;
4297 Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
4298 break;
4299 }
4300
4301 switch (Inst->getIntrinsicID()) {
4302 default:
4303 return false;
4304 case Intrinsic::aarch64_neon_ld2:
4305 case Intrinsic::aarch64_neon_st2:
4306 Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
4307 break;
4308 case Intrinsic::aarch64_neon_ld3:
4309 case Intrinsic::aarch64_neon_st3:
4310 Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
4311 break;
4312 case Intrinsic::aarch64_neon_ld4:
4313 case Intrinsic::aarch64_neon_st4:
4314 Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
4315 break;
4316 }
4317 return true;
4318}
4319
4320/// See if \p I should be considered for address type promotion. We check if
4321/// \p I is a sext with the right type that is used in memory accesses. If it
4322/// is used in a "complex" getelementptr, we allow it to be promoted without
4323/// finding other sext instructions that sign extended the same initial value.
4324/// A getelementptr is considered "complex" if it has more than 2 operands.
4325bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
4326 const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
4327 bool Considerable = false;
4328 AllowPromotionWithoutCommonHeader = false;
4329 if (!isa<SExtInst>(&I))
4330 return false;
4331 Type *ConsideredSExtType =
4332 Type::getInt64Ty(I.getParent()->getParent()->getContext());
4333 if (I.getType() != ConsideredSExtType)
4334 return false;
4335 // See if the sext is the one with the right type and used in at least one
4336 // GetElementPtrInst.
4337 for (const User *U : I.users()) {
4338 if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
4339 Considerable = true;
4340 // A getelementptr is considered as "complex" if it has more than 2
4341 // operands. We will promote a SExt used in such complex GEP as we
4342 // expect some computation to be merged if they are done on 64 bits.
4343 if (GEPInst->getNumOperands() > 2) {
4344 AllowPromotionWithoutCommonHeader = true;
4345 break;
4346 }
4347 }
4348 }
4349 return Considerable;
4350}
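// Example (illustrative): for
//   %idx64 = sext i32 %idx to i64
//   %p = getelementptr [256 x [256 x i32]], ptr %base, i64 0, i64 %idx64, i64 %j
// the sext feeds a GEP with more than two operands, so promotion is allowed
// even without finding another sext of the same value, on the expectation
// that the address arithmetic folds into 64-bit addressing.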
4351
4352bool AArch64TTIImpl::isLegalToVectorizeReduction(
4353 const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
4354 if (!VF.isScalable())
4355 return true;
4356
4357  Type *Ty = RdxDesc.getRecurrenceType();
4358  if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
4359 return false;
4360
4361 switch (RdxDesc.getRecurrenceKind()) {
4362 case RecurKind::Add:
4363 case RecurKind::FAdd:
4364 case RecurKind::And:
4365 case RecurKind::Or:
4366 case RecurKind::Xor:
4367 case RecurKind::SMin:
4368 case RecurKind::SMax:
4369 case RecurKind::UMin:
4370 case RecurKind::UMax:
4371 case RecurKind::FMin:
4372 case RecurKind::FMax:
4373 case RecurKind::FMulAdd:
4374 case RecurKind::IAnyOf:
4375 case RecurKind::FAnyOf:
4376 return true;
4377 default:
4378 return false;
4379 }
4380}
4381
4382InstructionCost
4383AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
4384                                       FastMathFlags FMF,
4385                                       TTI::TargetCostKind CostKind) {
4386 // The code-generator is currently not able to handle scalable vectors
4387 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
4388 // it. This change will be removed when code-generation for these types is
4389 // sufficiently reliable.
4390 if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
4391    if (VTy->getElementCount() == ElementCount::getScalable(1))
4392      return InstructionCost::getInvalid();
4393
4394 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
4395
4396 if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
4397 return BaseT::getMinMaxReductionCost(IID, Ty, FMF, CostKind);
4398
4399 InstructionCost LegalizationCost = 0;
4400 if (LT.first > 1) {
4401 Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
4402 IntrinsicCostAttributes Attrs(IID, LegalVTy, {LegalVTy, LegalVTy}, FMF);
4403 LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
4404 }
4405
4406 return LegalizationCost + /*Cost of horizontal reduction*/ 2;
4407}
4408
4409InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
4410 unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
4411 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
4412 InstructionCost LegalizationCost = 0;
4413 if (LT.first > 1) {
4414 Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
4415 LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
4416 LegalizationCost *= LT.first - 1;
4417 }
4418
4419 int ISD = TLI->InstructionOpcodeToISD(Opcode);
4420 assert(ISD && "Invalid opcode");
4421 // Add the final reduction cost for the legal horizontal reduction
4422 switch (ISD) {
4423 case ISD::ADD:
4424 case ISD::AND:
4425 case ISD::OR:
4426 case ISD::XOR:
4427 case ISD::FADD:
4428 return LegalizationCost + 2;
4429  default:
4430    return InstructionCost::getInvalid();
4431 }
4432}
4433
4434InstructionCost
4435AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
4436                                           std::optional<FastMathFlags> FMF,
4437                                           TTI::TargetCostKind CostKind) {
4438 // The code-generator is currently not able to handle scalable vectors
4439 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
4440 // it. This change will be removed when code-generation for these types is
4441 // sufficiently reliable.
4442 if (auto *VTy = dyn_cast<ScalableVectorType>(ValTy))
4443    if (VTy->getElementCount() == ElementCount::getScalable(1))
4444      return InstructionCost::getInvalid();
4445
4446  if (TTI::requiresOrderedReduction(FMF)) {
4447 if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
4448 InstructionCost BaseCost =
4449 BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
4450 // Add on extra cost to reflect the extra overhead on some CPUs. We still
4451 // end up vectorizing for more computationally intensive loops.
4452 return BaseCost + FixedVTy->getNumElements();
4453 }
4454
4455    if (Opcode != Instruction::FAdd)
4456      return InstructionCost::getInvalid();
4457
4458    auto *VTy = cast<ScalableVectorType>(ValTy);
4459    InstructionCost Cost =
4460 getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
4461 Cost *= getMaxNumElements(VTy->getElementCount());
4462 return Cost;
4463 }
4464
4465 if (isa<ScalableVectorType>(ValTy))
4466 return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
4467
4468 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
4469 MVT MTy = LT.second;
4470 int ISD = TLI->InstructionOpcodeToISD(Opcode);
4471 assert(ISD && "Invalid opcode");
4472
4473 // Horizontal adds can use the 'addv' instruction. We model the cost of these
4474 // instructions as twice a normal vector add, plus 1 for each legalization
4475 // step (LT.first). This is the only arithmetic vector reduction operation for
4476 // which we have an instruction.
4477 // OR, XOR and AND costs should match the codegen from:
4478 // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
4479 // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
4480 // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
4481 static const CostTblEntry CostTblNoPairwise[]{
4482 {ISD::ADD, MVT::v8i8, 2},
4483 {ISD::ADD, MVT::v16i8, 2},
4484 {ISD::ADD, MVT::v4i16, 2},
4485 {ISD::ADD, MVT::v8i16, 2},
4486 {ISD::ADD, MVT::v4i32, 2},
4487 {ISD::ADD, MVT::v2i64, 2},
4488 {ISD::OR, MVT::v8i8, 15},
4489 {ISD::OR, MVT::v16i8, 17},
4490 {ISD::OR, MVT::v4i16, 7},
4491 {ISD::OR, MVT::v8i16, 9},
4492 {ISD::OR, MVT::v2i32, 3},
4493 {ISD::OR, MVT::v4i32, 5},
4494 {ISD::OR, MVT::v2i64, 3},
4495 {ISD::XOR, MVT::v8i8, 15},
4496 {ISD::XOR, MVT::v16i8, 17},
4497 {ISD::XOR, MVT::v4i16, 7},
4498 {ISD::XOR, MVT::v8i16, 9},
4499 {ISD::XOR, MVT::v2i32, 3},
4500 {ISD::XOR, MVT::v4i32, 5},
4501 {ISD::XOR, MVT::v2i64, 3},
4502 {ISD::AND, MVT::v8i8, 15},
4503 {ISD::AND, MVT::v16i8, 17},
4504 {ISD::AND, MVT::v4i16, 7},
4505 {ISD::AND, MVT::v8i16, 9},
4506 {ISD::AND, MVT::v2i32, 3},
4507 {ISD::AND, MVT::v4i32, 5},
4508 {ISD::AND, MVT::v2i64, 3},
4509 };
4510 switch (ISD) {
4511 default:
4512 break;
4513 case ISD::FADD:
4514 if (Type *EltTy = ValTy->getScalarType();
4515 // FIXME: For half types without fullfp16 support, this could extend and
4516 // use a fp32 faddp reduction but current codegen unrolls.
4517 MTy.isVector() && (EltTy->isFloatTy() || EltTy->isDoubleTy() ||
4518 (EltTy->isHalfTy() && ST->hasFullFP16()))) {
4519 const unsigned NElts = MTy.getVectorNumElements();
4520 if (ValTy->getElementCount().getFixedValue() >= 2 && NElts >= 2 &&
4521 isPowerOf2_32(NElts))
4522 // A reduction corresponding to a series of fadd instructions is lowered to
4523 // a series of faddp instructions. faddp has latency/throughput that
4524 // matches the fadd instruction, so every faddp instruction can be
4525 // considered to have a relative cost of 1 with
4526 // CostKind = TCK_RecipThroughput.
4527 // An faddp will pairwise add vector elements, so the size of the input
4528 // vector halves on every step, requiring
4529 // #(faddp instructions) = Log2_32(NElts); see the worked example after this function.
4530 return (LT.first - 1) + /*No of faddp instructions*/ Log2_32(NElts);
4531 }
4532 break;
4533 case ISD::ADD:
4534 if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
4535 return (LT.first - 1) + Entry->Cost;
4536 break;
4537 case ISD::XOR:
4538 case ISD::AND:
4539 case ISD::OR:
4540 const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
4541 if (!Entry)
4542 break;
4543 auto *ValVTy = cast<FixedVectorType>(ValTy);
4544 if (MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
4545 isPowerOf2_32(ValVTy->getNumElements())) {
4546 InstructionCost ExtraCost = 0;
4547 if (LT.first != 1) {
4548 // Type needs to be split, so there is an extra cost of LT.first - 1
4549 // arithmetic ops.
4550 auto *Ty = FixedVectorType::get(ValTy->getElementType(),
4551 MTy.getVectorNumElements());
4552 ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
4553 ExtraCost *= LT.first - 1;
4554 }
4555 // All and/or/xor of i1 will be lowered with maxv/minv/addv + fmov
4556 auto Cost = ValVTy->getElementType()->isIntegerTy(1) ? 2 : Entry->Cost;
4557 return Cost + ExtraCost;
4558 }
4559 break;
4560 }
4561 return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
4562}
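// Worked example for the FADD case above (not part of the upstream file): each
// faddp halves the number of live lanes, so a power-of-two NElts needs
// Log2_32(NElts) faddp instructions, plus (LT.first - 1) extra vector ops when
// the type was split during legalization. A minimal sketch of that arithmetic,
// using a hypothetical helper name:
//
//   static unsigned exampleFaddpReductionCost(unsigned LTFirst, unsigned NElts) {
//     // (LTFirst - 1) split ops, then one faddp per halving step.
//     return (LTFirst - 1) + Log2_32(NElts);
//   }
//
// For a v4f32 fadd reduction with no splitting this gives (1 - 1) + 2 = 2
// (faddp, faddp), while the i1 AND/OR/XOR reductions above are instead charged
// a flat 2 for the maxv/minv/addv + fmov sequence, plus any split cost.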
4563
4564 InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
4565 static const CostTblEntry ShuffleTbl[] = {
4566 { TTI::SK_Splice, MVT::nxv16i8, 1 },
4567 { TTI::SK_Splice, MVT::nxv8i16, 1 },
4568 { TTI::SK_Splice, MVT::nxv4i32, 1 },
4569 { TTI::SK_Splice, MVT::nxv2i64, 1 },
4570 { TTI::SK_Splice, MVT::nxv2f16, 1 },
4571 { TTI::SK_Splice, MVT::nxv4f16, 1 },
4572 { TTI::SK_Splice, MVT::nxv8f16, 1 },
4573 { TTI::SK_Splice, MVT::nxv2bf16, 1 },
4574 { TTI::SK_Splice, MVT::nxv4bf16, 1 },
4575 { TTI::SK_Splice, MVT::nxv8bf16, 1 },
4576 { TTI::SK_Splice, MVT::nxv2f32, 1 },
4577 { TTI::SK_Splice, MVT::nxv4f32, 1 },
4578 { TTI::SK_Splice, MVT::nxv2f64, 1 },
4579 };
4580
4581 // The code-generator is currently not able to handle scalable vectors
4582 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
4583 // it. This change will be removed when code-generation for these types is
4584 // sufficiently reliable.
4585 if (Tp->getElementCount() == ElementCount::getScalable(1))
4586 return InstructionCost::getInvalid();
4587
4588 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
4589 Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
4590 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
4591 EVT PromotedVT = LT.second.getScalarType() == MVT::i1
4592 ? TLI->getPromotedVTForPredicate(EVT(LT.second))
4593 : LT.second;
4594 Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
4595 InstructionCost LegalizationCost = 0;
4596 if (Index < 0) {
4597 LegalizationCost =
4598 getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
4599 CmpInst::BAD_ICMP_PREDICATE, CostKind) +
4600 getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
4601 CmpInst::BAD_ICMP_PREDICATE, CostKind);
4602 }
4603
4604 // Predicated splices are promoted when lowering; see AArch64ISelLowering.cpp.
4605 // The cost is computed on the promoted type (see the example after this function).
4606 if (LT.second.getScalarType() == MVT::i1) {
4607 LegalizationCost +=
4608 getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
4609 TTI::CastContextHint::None, CostKind) +
4610 getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
4611 TTI::CastContextHint::None, CostKind);
4612 }
4613 const auto *Entry =
4614 CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
4615 assert(Entry && "Illegal Type for Splice");
4616 LegalizationCost += Entry->Cost;
4617 return LegalizationCost * LT.first;
4618}
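// Illustrative example (not part of the upstream file): splicing two
// <vscale x 4 x i1> predicate vectors with a negative Index is costed on the
// promoted nxv4i32 type: the Index < 0 path above adds an icmp plus a select
// on that type, the i1 element type adds a zext/trunc pair, and the SK_Splice
// entry from the table above contributes 1, all scaled by the number of
// legalization steps LT.first.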
4619
4620 InstructionCost AArch64TTIImpl::getShuffleCost(
4621 TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef<int> Mask,
4622 TTI::TargetCostKind CostKind, int Index, VectorType *SubTp,
4623 ArrayRef<const Value *> Args, const Instruction *CxtI) {
4624 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
4625
4626 // If we have a Mask and the type is split during legalization, break the Mask
4627 // into legal-vector-sized chunks and sum the cost of each resulting shuffle.
4628 if (!Mask.empty() && isa<FixedVectorType>(Tp) && LT.second.isVector() &&
4629 Tp->getScalarSizeInBits() == LT.second.getScalarSizeInBits() &&
4630 Mask.size() > LT.second.getVectorNumElements() && !Index && !SubTp) {
4631
4632 // Check for LD3/LD4 instructions, which are represented in llvm IR as
4633 // deinterleaving-shuffle(load). The shuffle cost could potentially be free,
4634 // but we model it with a cost of LT.first so that LD3/LD4 have a higher
4635 // cost than just the load.
4636 if (Args.size() >= 1 && isa<LoadInst>(Args[0]) &&
4637 (ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, 4) ||
4638 ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, 3)))
4639 return std::max<InstructionCost>(1, LT.first / 4);
4640
4641 // Check for ST3/ST4 instructions, which are represented in llvm IR as
4642 // store(interleaving-shuffle). The shuffle cost could potentially be free,
4643 // but we model it with a cost of LT.first so that ST3/ST4 have a higher
4644 // cost than just the store.
4645 if (CxtI && CxtI->hasOneUse() && isa<StoreInst>(*CxtI->user_begin()) &&
4646 (ShuffleVectorInst::isInterleaveMask(
4647 Mask, 4, Tp->getElementCount().getKnownMinValue() * 2) ||
4648 ShuffleVectorInst::isInterleaveMask(
4649 Mask, 3, Tp->getElementCount().getKnownMinValue() * 2)))
4650 return LT.first;
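// Illustrative example (not part of the upstream file): for a <16 x i8> load,
// the factor-4 deinterleaving mask that extracts lane 0 of every group of 4 is
// <0, 4, 8, 12>; a shuffle(load) using such a mask is the pattern an ld4
// covers, and the store(interleaving-shuffle) form recognised just above is
// the corresponding st3/st4 pattern.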
4651
4652 unsigned TpNumElts = Mask.size();
4653 unsigned LTNumElts = LT.second.getVectorNumElements();
4654 unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts;
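// Illustrative example (not part of the upstream file): the ceil-division
// above splits the original mask into legal-vector-sized chunks. For a
// 24-element mask legalized to v16i8 (LTNumElts == 16), NumVecs ==
// (24 + 15) / 16 == 2, so the cost is the sum of two 16-element shuffles,
// with the tail of the second chunk padded with poison lanes in the loop below.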
4655 VectorType *NTp =
4656 VectorType::get(Tp->getScalarType(), LT.second.getVectorElementCount());
4657 InstructionCost Cost;
4658 for (unsigned N = 0; N < NumVecs; N++) {
4659 SmallVector<int> NMask;
4660 // Split the existing mask into chunks of size LTNumElts. Track the source
4661 // sub-vectors so we know whether each chunk can be expressed as a shuffle of at most 2 inputs.
4662 unsigned Source1, Source2;
4663 unsigned NumSources = 0;
4664 for (unsigned E = 0; E < LTNumElts; E++) {
4665 int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E]
4666 : PoisonMaskElem;
4667 if (MaskElt < 0) {
4668 NMask.push_back(PoisonMaskElem);
4669 continue;
4670 }
4671
4672 // Calculate which source from the input this comes from and whether it
4673 // is new to us.
4674 unsigned Source = MaskElt / LTNumElts;
4675 if (NumSources == 0) {
4676 Source1 = Source;
4677 NumSources = 1;
4678 } else if (NumSources == 1 && Source != Source1) {