LLVM 20.0.0git
TargetLowering.cpp
1//===-- TargetLowering.cpp - Implement the TargetLowering class -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This implements the TargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
14#include "llvm/ADT/STLExtras.h"
24#include "llvm/IR/DataLayout.h"
27#include "llvm/IR/LLVMContext.h"
28#include "llvm/MC/MCAsmInfo.h"
29#include "llvm/MC/MCExpr.h"
35#include <cctype>
36using namespace llvm;
37
38/// NOTE: The TargetMachine owns TLOF.
39TargetLowering::TargetLowering(const TargetMachine &tm)
40 : TargetLoweringBase(tm) {}
41
42const char *TargetLowering::getTargetNodeName(unsigned Opcode) const {
43 return nullptr;
44}
45
46bool TargetLowering::isPositionIndependent() const {
47 return getTargetMachine().isPositionIndependent();
48}
49
50/// Check whether a given call node is in tail position within its function. If
51/// so, it sets Chain to the input chain of the tail call.
52bool TargetLowering::isInTailCallPosition(SelectionDAG &DAG, SDNode *Node,
53 SDValue &Chain) const {
54 const Function &F = DAG.getMachineFunction().getFunction();
55
56 // First, check if tail calls have been disabled in this function.
57 if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
58 return false;
59
60 // Conservatively require the attributes of the call to match those of
61 // the return. Ignore following attributes because they don't affect the
62 // call sequence.
63 AttrBuilder CallerAttrs(F.getContext(), F.getAttributes().getRetAttrs());
64 for (const auto &Attr : {Attribute::Alignment, Attribute::Dereferenceable,
65 Attribute::DereferenceableOrNull, Attribute::NoAlias,
66 Attribute::NonNull, Attribute::NoUndef,
67 Attribute::Range, Attribute::NoFPClass})
68 CallerAttrs.removeAttribute(Attr);
69
70 if (CallerAttrs.hasAttributes())
71 return false;
72
73 // It's not safe to eliminate the sign / zero extension of the return value.
74 if (CallerAttrs.contains(Attribute::ZExt) ||
75 CallerAttrs.contains(Attribute::SExt))
76 return false;
77
78 // Check if the only use is a function return node.
79 return isUsedByReturnOnly(Node, Chain);
80}
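// Illustrative IR sketch (added for exposition, not part of the original file):
// a function carrying the "disable-tail-calls" attribute trips the early-out
// above for every call in its body.
//
//   define i32 @f(ptr %p) "disable-tail-calls"="true" {
//     %r = call i32 @g(ptr %p)
//     ret i32 %r      ; never emitted as a tail call
//   }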
81
82bool TargetLowering::parametersInCSRMatch(const MachineRegisterInfo &MRI,
83 const uint32_t *CallerPreservedMask,
84 const SmallVectorImpl<CCValAssign> &ArgLocs,
85 const SmallVectorImpl<SDValue> &OutVals) const {
86 for (unsigned I = 0, E = ArgLocs.size(); I != E; ++I) {
87 const CCValAssign &ArgLoc = ArgLocs[I];
88 if (!ArgLoc.isRegLoc())
89 continue;
90 MCRegister Reg = ArgLoc.getLocReg();
91 // Only look at callee saved registers.
92 if (MachineOperand::clobbersPhysReg(CallerPreservedMask, Reg))
93 continue;
94 // Check that we pass the value used for the caller.
95 // (We look for a CopyFromReg reading a virtual register that is used
96 // for the function live-in value of register Reg)
97 SDValue Value = OutVals[I];
98 if (Value->getOpcode() == ISD::AssertZext)
99 Value = Value.getOperand(0);
100 if (Value->getOpcode() != ISD::CopyFromReg)
101 return false;
102 Register ArgReg = cast<RegisterSDNode>(Value->getOperand(1))->getReg();
103 if (MRI.getLiveInPhysReg(ArgReg) != Reg)
104 return false;
105 }
106 return true;
107}
108
109/// Set CallLoweringInfo attribute flags based on a call instruction
110/// and called function attributes.
111void TargetLoweringBase::ArgListEntry::setAttributes(const CallBase *Call,
112 unsigned ArgIdx) {
113 IsSExt = Call->paramHasAttr(ArgIdx, Attribute::SExt);
114 IsZExt = Call->paramHasAttr(ArgIdx, Attribute::ZExt);
115 IsNoExt = Call->paramHasAttr(ArgIdx, Attribute::NoExt);
116 IsInReg = Call->paramHasAttr(ArgIdx, Attribute::InReg);
117 IsSRet = Call->paramHasAttr(ArgIdx, Attribute::StructRet);
118 IsNest = Call->paramHasAttr(ArgIdx, Attribute::Nest);
119 IsByVal = Call->paramHasAttr(ArgIdx, Attribute::ByVal);
120 IsPreallocated = Call->paramHasAttr(ArgIdx, Attribute::Preallocated);
121 IsInAlloca = Call->paramHasAttr(ArgIdx, Attribute::InAlloca);
122 IsReturned = Call->paramHasAttr(ArgIdx, Attribute::Returned);
123 IsSwiftSelf = Call->paramHasAttr(ArgIdx, Attribute::SwiftSelf);
124 IsSwiftAsync = Call->paramHasAttr(ArgIdx, Attribute::SwiftAsync);
125 IsSwiftError = Call->paramHasAttr(ArgIdx, Attribute::SwiftError);
126 Alignment = Call->getParamStackAlign(ArgIdx);
127 IndirectType = nullptr;
128 assert(IsByVal + IsPreallocated + IsInAlloca + IsSRet <= 1 &&
129 "multiple ABI attributes?");
130 if (IsByVal) {
131 IndirectType = Call->getParamByValType(ArgIdx);
132 if (!Alignment)
133 Alignment = Call->getParamAlign(ArgIdx);
134 }
135 if (IsPreallocated)
136 IndirectType = Call->getParamPreallocatedType(ArgIdx);
137 if (IsInAlloca)
138 IndirectType = Call->getParamInAllocaType(ArgIdx);
139 if (IsSRet)
140 IndirectType = Call->getParamStructRetType(ArgIdx);
141}
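// Example (sketch, not from the original file): for a call such as
//   call void @take(ptr byval(%struct.S) align 8 %tmp)
// the entry for argument 0 ends up with IsByVal = true and IndirectType set to
// %struct.S; since no explicit stack alignment was requested, Alignment falls
// back to the parameter's align(8).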
142
143/// Generate a libcall taking the given operands as arguments and returning a
144/// result of type RetVT.
145std::pair<SDValue, SDValue>
146TargetLowering::makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC, EVT RetVT,
147 ArrayRef<SDValue> Ops,
148 MakeLibCallOptions CallOptions,
149 const SDLoc &dl,
150 SDValue InChain) const {
151 if (!InChain)
152 InChain = DAG.getEntryNode();
153
154 TargetLowering::ArgListTy Args;
155 Args.reserve(Ops.size());
156
157 TargetLowering::ArgListEntry Entry;
158 for (unsigned i = 0; i < Ops.size(); ++i) {
159 SDValue NewOp = Ops[i];
160 Entry.Node = NewOp;
161 Entry.Ty = Entry.Node.getValueType().getTypeForEVT(*DAG.getContext());
162 Entry.IsSExt =
163 shouldSignExtendTypeInLibCall(Entry.Ty, CallOptions.IsSigned);
164 Entry.IsZExt = !Entry.IsSExt;
165
166 if (CallOptions.IsSoften &&
167 !shouldExtendTypeInLibCall(CallOptions.OpsVTBeforeSoften[i])) {
168 Entry.IsSExt = Entry.IsZExt = false;
169 }
170 Args.push_back(Entry);
171 }
172
173 if (LC == RTLIB::UNKNOWN_LIBCALL)
174 report_fatal_error("Unsupported library call operation!");
175 SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC),
176 getPointerTy(DAG.getDataLayout()));
177
178 Type *RetTy = RetVT.getTypeForEVT(*DAG.getContext());
179 TargetLowering::CallLoweringInfo CLI(DAG);
180 bool signExtend = shouldSignExtendTypeInLibCall(RetTy, CallOptions.IsSigned);
181 bool zeroExtend = !signExtend;
182
183 if (CallOptions.IsSoften &&
184 !shouldExtendTypeInLibCall(CallOptions.RetVTBeforeSoften)) {
185 signExtend = zeroExtend = false;
186 }
187
188 CLI.setDebugLoc(dl)
189 .setChain(InChain)
190 .setLibCallee(getLibcallCallingConv(LC), RetTy, Callee, std::move(Args))
191 .setNoReturn(CallOptions.DoesNotReturn)
192 .setDiscardResult(!CallOptions.IsReturnValueUsed)
193 .setIsPostTypeLegalization(CallOptions.IsPostTypeLegalization)
194 .setSExtResult(signExtend)
195 .setZExtResult(zeroExtend);
196 return LowerCallTo(CLI);
197}
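// Minimal usage sketch (hypothetical call site, not part of this file):
// lowering an FREM node through the fmod libcall.
//
//   TargetLowering::MakeLibCallOptions CallOptions;
//   SDValue Ops[2] = {N->getOperand(0), N->getOperand(1)};
//   std::pair<SDValue, SDValue> Res =
//       TLI.makeLibCall(DAG, RTLIB::REM_F64, MVT::f64, Ops, CallOptions, dl);
//   // Res.first is the call result, Res.second is the updated chain.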
198
199bool TargetLowering::findOptimalMemOpLowering(
200 std::vector<EVT> &MemOps, unsigned Limit, const MemOp &Op, unsigned DstAS,
201 unsigned SrcAS, const AttributeList &FuncAttributes) const {
202 if (Limit != ~unsigned(0) && Op.isMemcpyWithFixedDstAlign() &&
203 Op.getSrcAlign() < Op.getDstAlign())
204 return false;
205
206 EVT VT = getOptimalMemOpType(Op, FuncAttributes);
207
208 if (VT == MVT::Other) {
209 // Use the largest integer type whose alignment constraints are satisfied.
210 // We only need to check DstAlign here as SrcAlign is always greater or
211 // equal to DstAlign (or zero).
212 VT = MVT::LAST_INTEGER_VALUETYPE;
213 if (Op.isFixedDstAlign())
214 while (Op.getDstAlign() < (VT.getSizeInBits() / 8) &&
215 !allowsMisalignedMemoryAccesses(VT, DstAS, Op.getDstAlign()))
216 VT = (MVT::SimpleValueType)(VT.getSimpleVT().SimpleTy - 1);
217 assert(VT.isInteger());
218
219 // Find the largest legal integer type.
220 MVT LVT = MVT::LAST_INTEGER_VALUETYPE;
221 while (!isTypeLegal(LVT))
222 LVT = (MVT::SimpleValueType)(LVT.SimpleTy - 1);
223 assert(LVT.isInteger());
224
225 // If the type we've chosen is larger than the largest legal integer type
226 // then use that instead.
227 if (VT.bitsGT(LVT))
228 VT = LVT;
229 }
230
231 unsigned NumMemOps = 0;
232 uint64_t Size = Op.size();
233 while (Size) {
234 unsigned VTSize = VT.getSizeInBits() / 8;
235 while (VTSize > Size) {
236 // For now, only use non-vector loads / stores for the left-over pieces.
237 EVT NewVT = VT;
238 unsigned NewVTSize;
239
240 bool Found = false;
241 if (VT.isVector() || VT.isFloatingPoint()) {
242 NewVT = (VT.getSizeInBits() > 64) ? MVT::i64 : MVT::i32;
243 if (isOperationLegalOrCustom(ISD::STORE, NewVT) &&
244 isSafeMemOpType(NewVT.getSimpleVT()))
245 Found = true;
246 else if (NewVT == MVT::i64 &&
247 isOperationLegalOrCustom(ISD::STORE, MVT::f64) &&
248 isSafeMemOpType(MVT::f64)) {
249 // i64 is usually not legal on 32-bit targets, but f64 may be.
250 NewVT = MVT::f64;
251 Found = true;
252 }
253 }
254
255 if (!Found) {
256 do {
257 NewVT = (MVT::SimpleValueType)(NewVT.getSimpleVT().SimpleTy - 1);
258 if (NewVT == MVT::i8)
259 break;
260 } while (!isSafeMemOpType(NewVT.getSimpleVT()));
261 }
262 NewVTSize = NewVT.getSizeInBits() / 8;
263
264 // If the new VT cannot cover all of the remaining bits, then consider
265 // issuing a (or a pair of) unaligned and overlapping load / store.
266 unsigned Fast;
267 if (NumMemOps && Op.allowOverlap() && NewVTSize < Size &&
268 allowsMisalignedMemoryAccesses(
269 VT, DstAS, Op.isFixedDstAlign() ? Op.getDstAlign() : Align(1),
270 MachineMemOperand::MONone, &Fast) &&
271 Fast)
272 VTSize = Size;
273 else {
274 VT = NewVT;
275 VTSize = NewVTSize;
276 }
277 }
278
279 if (++NumMemOps > Limit)
280 return false;
281
282 MemOps.push_back(VT);
283 Size -= VTSize;
284 }
285
286 return true;
287}
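// Usage sketch (hypothetical values, not from this file): asking how to lower a
// 16-byte non-volatile memcpy with 4-byte-aligned operands.
//
//   std::vector<EVT> MemOps;
//   bool OK = TLI.findOptimalMemOpLowering(
//       MemOps, /*Limit=*/8,
//       MemOp::Copy(/*Size=*/16, /*DstAlignCanChange=*/false, Align(4),
//                   Align(4), /*IsVolatile=*/false),
//       /*DstAS=*/0, /*SrcAS=*/0, F.getAttributes());
//   // On success, MemOps lists the value type to use for each load/store pair.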
288
289/// Soften the operands of a comparison. This code is shared among BR_CC,
290/// SELECT_CC, and SETCC handlers.
291void TargetLowering::softenSetCCOperands(SelectionDAG &DAG, EVT VT,
292 SDValue &NewLHS, SDValue &NewRHS,
293 ISD::CondCode &CCCode,
294 const SDLoc &dl, const SDValue OldLHS,
295 const SDValue OldRHS) const {
296 SDValue Chain;
297 return softenSetCCOperands(DAG, VT, NewLHS, NewRHS, CCCode, dl, OldLHS,
298 OldRHS, Chain);
299}
300
301void TargetLowering::softenSetCCOperands(SelectionDAG &DAG, EVT VT,
302 SDValue &NewLHS, SDValue &NewRHS,
303 ISD::CondCode &CCCode,
304 const SDLoc &dl, const SDValue OldLHS,
305 const SDValue OldRHS,
306 SDValue &Chain,
307 bool IsSignaling) const {
308 // FIXME: Currently we cannot really respect all IEEE predicates due to libgcc
309 // not supporting it. We can update this code when libgcc provides such
310 // functions.
311
312 assert((VT == MVT::f32 || VT == MVT::f64 || VT == MVT::f128 || VT == MVT::ppcf128)
313 && "Unsupported setcc type!");
314
315 // Expand into one or more soft-fp libcall(s).
316 RTLIB::Libcall LC1 = RTLIB::UNKNOWN_LIBCALL, LC2 = RTLIB::UNKNOWN_LIBCALL;
317 bool ShouldInvertCC = false;
318 switch (CCCode) {
319 case ISD::SETEQ:
320 case ISD::SETOEQ:
321 LC1 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
322 (VT == MVT::f64) ? RTLIB::OEQ_F64 :
323 (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
324 break;
325 case ISD::SETNE:
326 case ISD::SETUNE:
327 LC1 = (VT == MVT::f32) ? RTLIB::UNE_F32 :
328 (VT == MVT::f64) ? RTLIB::UNE_F64 :
329 (VT == MVT::f128) ? RTLIB::UNE_F128 : RTLIB::UNE_PPCF128;
330 break;
331 case ISD::SETGE:
332 case ISD::SETOGE:
333 LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
334 (VT == MVT::f64) ? RTLIB::OGE_F64 :
335 (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
336 break;
337 case ISD::SETLT:
338 case ISD::SETOLT:
339 LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
340 (VT == MVT::f64) ? RTLIB::OLT_F64 :
341 (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
342 break;
343 case ISD::SETLE:
344 case ISD::SETOLE:
345 LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
346 (VT == MVT::f64) ? RTLIB::OLE_F64 :
347 (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
348 break;
349 case ISD::SETGT:
350 case ISD::SETOGT:
351 LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
352 (VT == MVT::f64) ? RTLIB::OGT_F64 :
353 (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
354 break;
355 case ISD::SETO:
356 ShouldInvertCC = true;
357 [[fallthrough]];
358 case ISD::SETUO:
359 LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
360 (VT == MVT::f64) ? RTLIB::UO_F64 :
361 (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
362 break;
363 case ISD::SETONE:
364 // SETONE = O && UNE
365 ShouldInvertCC = true;
366 [[fallthrough]];
367 case ISD::SETUEQ:
368 LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
369 (VT == MVT::f64) ? RTLIB::UO_F64 :
370 (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
371 LC2 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
372 (VT == MVT::f64) ? RTLIB::OEQ_F64 :
373 (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
374 break;
375 default:
376 // Invert CC for unordered comparisons
377 ShouldInvertCC = true;
378 switch (CCCode) {
379 case ISD::SETULT:
380 LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
381 (VT == MVT::f64) ? RTLIB::OGE_F64 :
382 (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
383 break;
384 case ISD::SETULE:
385 LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
386 (VT == MVT::f64) ? RTLIB::OGT_F64 :
387 (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
388 break;
389 case ISD::SETUGT:
390 LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
391 (VT == MVT::f64) ? RTLIB::OLE_F64 :
392 (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
393 break;
394 case ISD::SETUGE:
395 LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
396 (VT == MVT::f64) ? RTLIB::OLT_F64 :
397 (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
398 break;
399 default: llvm_unreachable("Do not know how to soften this setcc!");
400 }
401 }
402
403 // Use the target specific return value for comparison lib calls.
404 EVT RetVT = getCmpLibcallReturnType();
405 SDValue Ops[2] = {NewLHS, NewRHS};
406 TargetLowering::MakeLibCallOptions CallOptions;
407 EVT OpsVT[2] = { OldLHS.getValueType(),
408 OldRHS.getValueType() };
409 CallOptions.setTypeListBeforeSoften(OpsVT, RetVT, true);
410 auto Call = makeLibCall(DAG, LC1, RetVT, Ops, CallOptions, dl, Chain);
411 NewLHS = Call.first;
412 NewRHS = DAG.getConstant(0, dl, RetVT);
413
414 CCCode = getCmpLibcallCC(LC1);
415 if (ShouldInvertCC) {
416 assert(RetVT.isInteger());
417 CCCode = getSetCCInverse(CCCode, RetVT);
418 }
419
420 if (LC2 == RTLIB::UNKNOWN_LIBCALL) {
421 // Update Chain.
422 Chain = Call.second;
423 } else {
424 EVT SetCCVT =
425 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), RetVT);
426 SDValue Tmp = DAG.getSetCC(dl, SetCCVT, NewLHS, NewRHS, CCCode);
427 auto Call2 = makeLibCall(DAG, LC2, RetVT, Ops, CallOptions, dl, Chain);
428 CCCode = getCmpLibcallCC(LC2);
429 if (ShouldInvertCC)
430 CCCode = getSetCCInverse(CCCode, RetVT);
431 NewLHS = DAG.getSetCC(dl, SetCCVT, Call2.first, NewRHS, CCCode);
432 if (Chain)
433 Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Call.second,
434 Call2.second);
435 NewLHS = DAG.getNode(ShouldInvertCC ? ISD::AND : ISD::OR, dl,
436 Tmp.getValueType(), Tmp, NewLHS);
437 NewRHS = SDValue();
438 }
439}
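// Example of the two-libcall case above (sketch): softening
//   setcc f32 %a, %b, setueq
// produces roughly
//   %uo  = __unordsf2(%a, %b)   ; LC1 = UO_F32
//   %oeq = __eqsf2(%a, %b)      ; LC2 = OEQ_F32
//   (setcc %uo, 0, setne) or (setcc %oeq, 0, seteq)
// with the comparison codes taken from getCmpLibcallCC() for each call.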
440
441/// Return the entry encoding for a jump table in the current function. The
442/// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
443unsigned TargetLowering::getJumpTableEncoding() const {
444 // In non-pic modes, just use the address of a block.
445 if (!isPositionIndependent())
446 return MachineJumpTableInfo::EK_BlockAddress;
447
448 // In PIC mode, if the target supports a GPRel32 directive, use it.
449 if (getTargetMachine().getMCAsmInfo()->getGPRel32Directive() != nullptr)
450 return MachineJumpTableInfo::EK_GPRel32BlockAddress;
451
452 // Otherwise, use a label difference.
453 return MachineJumpTableInfo::EK_LabelDifference32;
454}
455
456SDValue TargetLowering::getPICJumpTableRelocBase(SDValue Table,
457 SelectionDAG &DAG) const {
458 // If our PIC model is GP relative, use the global offset table as the base.
459 unsigned JTEncoding = getJumpTableEncoding();
460
461 if ((JTEncoding == MachineJumpTableInfo::EK_GPRel64BlockAddress) ||
462 (JTEncoding == MachineJumpTableInfo::EK_GPRel32BlockAddress))
463 return DAG.getGLOBAL_OFFSET_TABLE(getPointerTy(DAG.getDataLayout()));
464
465 return Table;
466}
467
468/// This returns the relocation base for the given PIC jumptable, the same as
469/// getPICJumpTableRelocBase, but as an MCExpr.
470const MCExpr *
472 unsigned JTI,MCContext &Ctx) const{
473 // The normal PIC reloc base is the label at the start of the jump table.
474 return MCSymbolRefExpr::create(MF->getJTISymbol(JTI, Ctx), Ctx);
475}
476
477SDValue TargetLowering::expandIndirectJTBranch(const SDLoc &dl, SDValue Value,
478 SDValue Addr, int JTI,
479 SelectionDAG &DAG) const {
480 SDValue Chain = Value;
481 // Jump table debug info is only needed if CodeView is enabled.
482 if (DAG.getTarget().getTargetTriple().isOSBinFormatCOFF()) {
483 Chain = DAG.getJumpTableDebugInfo(JTI, Chain, dl);
484 }
485 return DAG.getNode(ISD::BRIND, dl, MVT::Other, Chain, Addr);
486}
487
488bool
489TargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const {
490 const TargetMachine &TM = getTargetMachine();
491 const GlobalValue *GV = GA->getGlobal();
492
493 // If the address is not even local to this DSO we will have to load it from
494 // a got and then add the offset.
495 if (!TM.shouldAssumeDSOLocal(GV))
496 return false;
497
498 // If the code is position independent we will have to add a base register.
499 if (isPositionIndependent())
500 return false;
501
502 // Otherwise we can do it.
503 return true;
504}
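// Example (sketch): when this returns true, the DAG combiner is allowed to fold
//   (add (GlobalAddress @g), 16)  -->  (GlobalAddress @g, offset 16)
// which typically lowers to a single "g+16" relocation rather than an explicit
// address computation.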
505
506//===----------------------------------------------------------------------===//
507// Optimization Methods
508//===----------------------------------------------------------------------===//
509
510/// If the specified instruction has a constant integer operand and there are
511/// bits set in that constant that are not demanded, then clear those bits and
512/// return true.
513bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
514 const APInt &DemandedBits,
515 const APInt &DemandedElts,
516 TargetLoweringOpt &TLO) const {
517 SDLoc DL(Op);
518 unsigned Opcode = Op.getOpcode();
519
520 // Early-out if we've ended up calling an undemanded node, leave this to
521 // constant folding.
522 if (DemandedBits.isZero() || DemandedElts.isZero())
523 return false;
524
525 // Do target-specific constant optimization.
526 if (targetShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
527 return TLO.New.getNode();
528
529 // FIXME: ISD::SELECT, ISD::SELECT_CC
530 switch (Opcode) {
531 default:
532 break;
533 case ISD::XOR:
534 case ISD::AND:
535 case ISD::OR: {
536 auto *Op1C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
537 if (!Op1C || Op1C->isOpaque())
538 return false;
539
540 // If this is a 'not' op, don't touch it because that's a canonical form.
541 const APInt &C = Op1C->getAPIntValue();
542 if (Opcode == ISD::XOR && DemandedBits.isSubsetOf(C))
543 return false;
544
545 if (!C.isSubsetOf(DemandedBits)) {
546 EVT VT = Op.getValueType();
547 SDValue NewC = TLO.DAG.getConstant(DemandedBits & C, DL, VT);
548 SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC,
549 Op->getFlags());
550 return TLO.CombineTo(Op, NewOp);
551 }
552
553 break;
554 }
555 }
556
557 return false;
558}
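// Worked example (sketch): with DemandedBits = 0x0F,
//   (and X, 0xFF)  -->  (and X, 0x0F)
// because 0xFF is not a subset of the demanded bits, so the constant is
// narrowed to DemandedBits & C. An XOR whose constant already covers all
// demanded bits is left alone, since that is the canonical 'not' form.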
559
560bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
561 const APInt &DemandedBits,
562 TargetLoweringOpt &TLO) const {
563 EVT VT = Op.getValueType();
564 APInt DemandedElts = VT.isVector()
565 ? APInt::getAllOnes(VT.getVectorNumElements())
566 : APInt(1, 1);
567 return ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO);
568}
569
570/// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free.
571/// This uses isTruncateFree/isZExtFree and ANY_EXTEND for the widening cast,
572/// but it could be generalized for targets with other types of implicit
573/// widening casts.
574bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth,
575 const APInt &DemandedBits,
576 TargetLoweringOpt &TLO) const {
577 assert(Op.getNumOperands() == 2 &&
578 "ShrinkDemandedOp only supports binary operators!");
579 assert(Op.getNode()->getNumValues() == 1 &&
580 "ShrinkDemandedOp only supports nodes with one result!");
581
582 EVT VT = Op.getValueType();
583 SelectionDAG &DAG = TLO.DAG;
584 SDLoc dl(Op);
585
586 // Early return, as this function cannot handle vector types.
587 if (VT.isVector())
588 return false;
589
590 assert(Op.getOperand(0).getValueType().getScalarSizeInBits() == BitWidth &&
591 Op.getOperand(1).getValueType().getScalarSizeInBits() == BitWidth &&
592 "ShrinkDemandedOp only supports operands that have the same size!");
593
594 // Don't do this if the node has another user, which may require the
595 // full value.
596 if (!Op.getNode()->hasOneUse())
597 return false;
598
599 // Search for the smallest integer type with free casts to and from
600 // Op's type. For expedience, just check power-of-2 integer types.
601 unsigned DemandedSize = DemandedBits.getActiveBits();
602 for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
603 SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
604 EVT SmallVT = EVT::getIntegerVT(*DAG.getContext(), SmallVTBits);
605 if (isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT)) {
606 // We found a type with free casts.
607
608 // If the operation has the 'disjoint' flag, then the
609 // operands on the new node are also disjoint.
610 SDNodeFlags Flags(Op->getFlags().hasDisjoint() ? SDNodeFlags::Disjoint
611 : SDNodeFlags::None);
612 SDValue X = DAG.getNode(
613 Op.getOpcode(), dl, SmallVT,
614 DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
615 DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)), Flags);
616 assert(DemandedSize <= SmallVTBits && "Narrowed below demanded bits?");
617 SDValue Z = DAG.getNode(ISD::ANY_EXTEND, dl, VT, X);
618 return TLO.CombineTo(Op, Z);
619 }
620 }
621 return false;
622}
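// Worked example (sketch): if only the low 16 bits of an i64 add are demanded
// and i16 <-> i64 truncate/extend casts are free on the target, the add becomes
//   (any_extend (add (trunc X), (trunc Y)))
// Power-of-two widths from bit_ceil(DemandedSize) up to BitWidth are tried in
// turn until one with free casts is found.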
623
624bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
625 DAGCombinerInfo &DCI) const {
626 SelectionDAG &DAG = DCI.DAG;
627 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
628 !DCI.isBeforeLegalizeOps());
629 KnownBits Known;
630
631 bool Simplified = SimplifyDemandedBits(Op, DemandedBits, Known, TLO);
632 if (Simplified) {
633 DCI.AddToWorklist(Op.getNode());
634 DCI.CommitTargetLoweringOpt(TLO);
635 }
636 return Simplified;
637}
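// Usage sketch (hypothetical target combine, not part of this file): a target's
// PerformDAGCombine hook can use this wrapper to drop bits it knows are never
// observed, e.g. keeping only the low 24 bits of an operand:
//
//   if (TLI.SimplifyDemandedBits(N->getOperand(0),
//                                APInt::getLowBitsSet(32, 24), DCI))
//     return SDValue(N, 0);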
638
639bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
640 const APInt &DemandedElts,
641 DAGCombinerInfo &DCI) const {
642 SelectionDAG &DAG = DCI.DAG;
643 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
644 !DCI.isBeforeLegalizeOps());
645 KnownBits Known;
646
647 bool Simplified =
648 SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO);
649 if (Simplified) {
650 DCI.AddToWorklist(Op.getNode());
651 DCI.CommitTargetLoweringOpt(TLO);
652 }
653 return Simplified;
654}
655
656bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
657 KnownBits &Known,
658 TargetLoweringOpt &TLO,
659 unsigned Depth,
660 bool AssumeSingleUse) const {
661 EVT VT = Op.getValueType();
662
663 // Since the number of lanes in a scalable vector is unknown at compile time,
664 // we track one bit which is implicitly broadcast to all lanes. This means
665 // that all lanes in a scalable vector are considered demanded.
666 APInt DemandedElts = VT.isFixedLengthVector()
667 ? APInt::getAllOnes(VT.getVectorNumElements())
668 : APInt(1, 1);
669 return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth,
670 AssumeSingleUse);
671}
672
673// TODO: Under what circumstances can we create nodes? Constant folding?
674SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
675 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
676 SelectionDAG &DAG, unsigned Depth) const {
677 EVT VT = Op.getValueType();
678
679 // Limit search depth.
680 if (Depth >= SelectionDAG::MaxRecursionDepth)
681 return SDValue();
682
683 // Ignore UNDEFs.
684 if (Op.isUndef())
685 return SDValue();
686
687 // Not demanding any bits/elts from Op.
688 if (DemandedBits == 0 || DemandedElts == 0)
689 return DAG.getUNDEF(VT);
690
691 bool IsLE = DAG.getDataLayout().isLittleEndian();
692 unsigned NumElts = DemandedElts.getBitWidth();
693 unsigned BitWidth = DemandedBits.getBitWidth();
694 KnownBits LHSKnown, RHSKnown;
695 switch (Op.getOpcode()) {
696 case ISD::BITCAST: {
697 if (VT.isScalableVector())
698 return SDValue();
699
700 SDValue Src = peekThroughBitcasts(Op.getOperand(0));
701 EVT SrcVT = Src.getValueType();
702 EVT DstVT = Op.getValueType();
703 if (SrcVT == DstVT)
704 return Src;
705
706 unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
707 unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
708 if (NumSrcEltBits == NumDstEltBits)
709 if (SDValue V = SimplifyMultipleUseDemandedBits(
710 Src, DemandedBits, DemandedElts, DAG, Depth + 1))
711 return DAG.getBitcast(DstVT, V);
712
713 if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
714 unsigned Scale = NumDstEltBits / NumSrcEltBits;
715 unsigned NumSrcElts = SrcVT.getVectorNumElements();
716 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
717 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
718 for (unsigned i = 0; i != Scale; ++i) {
719 unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
720 unsigned BitOffset = EltOffset * NumSrcEltBits;
721 APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
722 if (!Sub.isZero()) {
723 DemandedSrcBits |= Sub;
724 for (unsigned j = 0; j != NumElts; ++j)
725 if (DemandedElts[j])
726 DemandedSrcElts.setBit((j * Scale) + i);
727 }
728 }
729
730 if (SDValue V = SimplifyMultipleUseDemandedBits(
731 Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
732 return DAG.getBitcast(DstVT, V);
733 }
734
735 // TODO - bigendian once we have test coverage.
736 if (IsLE && (NumSrcEltBits % NumDstEltBits) == 0) {
737 unsigned Scale = NumSrcEltBits / NumDstEltBits;
738 unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
739 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
740 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
741 for (unsigned i = 0; i != NumElts; ++i)
742 if (DemandedElts[i]) {
743 unsigned Offset = (i % Scale) * NumDstEltBits;
744 DemandedSrcBits.insertBits(DemandedBits, Offset);
745 DemandedSrcElts.setBit(i / Scale);
746 }
747
748 if (SDValue V = SimplifyMultipleUseDemandedBits(
749 Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
750 return DAG.getBitcast(DstVT, V);
751 }
752
753 break;
754 }
755 case ISD::FREEZE: {
756 SDValue N0 = Op.getOperand(0);
757 if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
758 /*PoisonOnly=*/false))
759 return N0;
760 break;
761 }
762 case ISD::AND: {
763 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
764 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
765
766 // If all of the demanded bits are known 1 on one side, return the other.
767 // These bits cannot contribute to the result of the 'and' in this
768 // context.
769 if (DemandedBits.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
770 return Op.getOperand(0);
771 if (DemandedBits.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
772 return Op.getOperand(1);
773 break;
774 }
775 case ISD::OR: {
776 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
777 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
778
779 // If all of the demanded bits are known zero on one side, return the
780 // other. These bits cannot contribute to the result of the 'or' in this
781 // context.
782 if (DemandedBits.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
783 return Op.getOperand(0);
784 if (DemandedBits.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
785 return Op.getOperand(1);
786 break;
787 }
788 case ISD::XOR: {
789 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
790 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
791
792 // If all of the demanded bits are known zero on one side, return the
793 // other.
794 if (DemandedBits.isSubsetOf(RHSKnown.Zero))
795 return Op.getOperand(0);
796 if (DemandedBits.isSubsetOf(LHSKnown.Zero))
797 return Op.getOperand(1);
798 break;
799 }
800 case ISD::ADD: {
801 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
802 if (RHSKnown.isZero())
803 return Op.getOperand(0);
804
805 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
806 if (LHSKnown.isZero())
807 return Op.getOperand(1);
808 break;
809 }
810 case ISD::SHL: {
811 // If we are only demanding sign bits then we can use the shift source
812 // directly.
813 if (std::optional<uint64_t> MaxSA =
814 DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
815 SDValue Op0 = Op.getOperand(0);
816 unsigned ShAmt = *MaxSA;
817 unsigned NumSignBits =
818 DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
819 unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
820 if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
821 return Op0;
822 }
823 break;
824 }
825 case ISD::SRL: {
826 // If we are only demanding sign bits then we can use the shift source
827 // directly.
828 if (std::optional<uint64_t> MaxSA =
829 DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
830 SDValue Op0 = Op.getOperand(0);
831 unsigned ShAmt = *MaxSA;
832 // Must already be signbits in DemandedBits bounds, and can't demand any
833 // shifted in zeroes.
834 if (DemandedBits.countl_zero() >= ShAmt) {
835 unsigned NumSignBits =
836 DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
837 if (DemandedBits.countr_zero() >= (BitWidth - NumSignBits))
838 return Op0;
839 }
840 }
841 break;
842 }
843 case ISD::SETCC: {
844 SDValue Op0 = Op.getOperand(0);
845 SDValue Op1 = Op.getOperand(1);
846 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
847 // If (1) we only need the sign-bit, (2) the setcc operands are the same
848 // width as the setcc result, and (3) the result of a setcc conforms to 0 or
849 // -1, we may be able to bypass the setcc.
850 if (DemandedBits.isSignMask() &&
851 Op0.getScalarValueSizeInBits() == BitWidth &&
852 getBooleanContents(Op0.getValueType()) ==
853 BooleanContent::ZeroOrNegativeOneBooleanContent) {
854 // If we're testing X < 0, then this compare isn't needed - just use X!
855 // FIXME: We're limiting to integer types here, but this should also work
856 // if we don't care about FP signed-zero. The use of SETLT with FP means
857 // that we don't care about NaNs.
858 if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
859 (isNullConstant(Op1) || ISD::isBuildVectorAllZeros(Op1.getNode())))
860 return Op0;
861 }
862 break;
863 }
864 case ISD::SIGN_EXTEND_INREG: {
865 // If none of the extended bits are demanded, eliminate the sextinreg.
866 SDValue Op0 = Op.getOperand(0);
867 EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
868 unsigned ExBits = ExVT.getScalarSizeInBits();
869 if (DemandedBits.getActiveBits() <= ExBits &&
871 return Op0;
872 // If the input is already sign extended, just drop the extension.
873 unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
874 if (NumSignBits >= (BitWidth - ExBits + 1))
875 return Op0;
876 break;
877 }
878 case ISD::ANY_EXTEND_VECTOR_INREG:
879 case ISD::SIGN_EXTEND_VECTOR_INREG:
880 case ISD::ZERO_EXTEND_VECTOR_INREG: {
881 if (VT.isScalableVector())
882 return SDValue();
883
884 // If we only want the lowest element and none of extended bits, then we can
885 // return the bitcasted source vector.
886 SDValue Src = Op.getOperand(0);
887 EVT SrcVT = Src.getValueType();
888 EVT DstVT = Op.getValueType();
889 if (IsLE && DemandedElts == 1 &&
890 DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
891 DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
892 return DAG.getBitcast(DstVT, Src);
893 }
894 break;
895 }
896 case ISD::INSERT_VECTOR_ELT: {
897 if (VT.isScalableVector())
898 return SDValue();
899
900 // If we don't demand the inserted element, return the base vector.
901 SDValue Vec = Op.getOperand(0);
902 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
903 EVT VecVT = Vec.getValueType();
904 if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
905 !DemandedElts[CIdx->getZExtValue()])
906 return Vec;
907 break;
908 }
909 case ISD::INSERT_SUBVECTOR: {
910 if (VT.isScalableVector())
911 return SDValue();
912
913 SDValue Vec = Op.getOperand(0);
914 SDValue Sub = Op.getOperand(1);
915 uint64_t Idx = Op.getConstantOperandVal(2);
916 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
917 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
918 // If we don't demand the inserted subvector, return the base vector.
919 if (DemandedSubElts == 0)
920 return Vec;
921 break;
922 }
923 case ISD::VECTOR_SHUFFLE: {
924 assert(!VT.isScalableVector());
925 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
926
927 // If all the demanded elts are from one operand and are inline,
928 // then we can use the operand directly.
929 bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
930 for (unsigned i = 0; i != NumElts; ++i) {
931 int M = ShuffleMask[i];
932 if (M < 0 || !DemandedElts[i])
933 continue;
934 AllUndef = false;
935 IdentityLHS &= (M == (int)i);
936 IdentityRHS &= ((M - NumElts) == i);
937 }
938
939 if (AllUndef)
940 return DAG.getUNDEF(Op.getValueType());
941 if (IdentityLHS)
942 return Op.getOperand(0);
943 if (IdentityRHS)
944 return Op.getOperand(1);
945 break;
946 }
947 default:
948 // TODO: Probably okay to remove after audit; here to reduce change size
949 // in initial enablement patch for scalable vectors
950 if (VT.isScalableVector())
951 return SDValue();
952
953 if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
954 if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
955 Op, DemandedBits, DemandedElts, DAG, Depth))
956 return V;
957 break;
958 }
959 return SDValue();
960}
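// Example (sketch) of the multi-use bypass implemented above: for (and X, Y)
// where every demanded bit is already known to be one in Y, the routine returns
// X directly, letting the caller use X without rewriting the (possibly
// multi-use) AND node itself.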
961
962SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
963 SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
964 unsigned Depth) const {
965 EVT VT = Op.getValueType();
966 // Since the number of lanes in a scalable vector is unknown at compile time,
967 // we track one bit which is implicitly broadcast to all lanes. This means
968 // that all lanes in a scalable vector are considered demanded.
969 APInt DemandedElts = VT.isFixedLengthVector()
970 ? APInt::getAllOnes(VT.getVectorNumElements())
971 : APInt(1, 1);
972 return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
973 Depth);
974}
975
976SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
977 SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
978 unsigned Depth) const {
979 APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
980 return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
981 Depth);
982}
983
984// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
985// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
986static SDValue combineShiftToAVG(SDValue Op,
987 TargetLowering::TargetLoweringOpt &TLO,
988 const TargetLowering &TLI,
989 const APInt &DemandedBits,
990 const APInt &DemandedElts, unsigned Depth) {
991 assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
992 "SRL or SRA node is required here!");
993 // Is the right shift using an immediate value of 1?
994 ConstantSDNode *N1C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
995 if (!N1C || !N1C->isOne())
996 return SDValue();
997
998 // We are looking for an avgfloor
999 // add(ext, ext)
1000 // or one of these as a avgceil
1001 // add(add(ext, ext), 1)
1002 // add(add(ext, 1), ext)
1003 // add(ext, add(ext, 1))
1004 SDValue Add = Op.getOperand(0);
1005 if (Add.getOpcode() != ISD::ADD)
1006 return SDValue();
1007
1008 SDValue ExtOpA = Add.getOperand(0);
1009 SDValue ExtOpB = Add.getOperand(1);
1010 SDValue Add2;
1011 auto MatchOperands = [&](SDValue Op1, SDValue Op2, SDValue Op3, SDValue A) {
1012 ConstantSDNode *ConstOp;
1013 if ((ConstOp = isConstOrConstSplat(Op2, DemandedElts)) &&
1014 ConstOp->isOne()) {
1015 ExtOpA = Op1;
1016 ExtOpB = Op3;
1017 Add2 = A;
1018 return true;
1019 }
1020 if ((ConstOp = isConstOrConstSplat(Op3, DemandedElts)) &&
1021 ConstOp->isOne()) {
1022 ExtOpA = Op1;
1023 ExtOpB = Op2;
1024 Add2 = A;
1025 return true;
1026 }
1027 return false;
1028 };
1029 bool IsCeil =
1030 (ExtOpA.getOpcode() == ISD::ADD &&
1031 MatchOperands(ExtOpA.getOperand(0), ExtOpA.getOperand(1), ExtOpB, ExtOpA)) ||
1032 (ExtOpB.getOpcode() == ISD::ADD &&
1033 MatchOperands(ExtOpB.getOperand(0), ExtOpB.getOperand(1), ExtOpA, ExtOpB));
1034
1035 // If the shift is signed (sra):
1036 // - Needs >= 2 sign bit for both operands.
1037 // - Needs >= 2 zero bits.
1038 // If the shift is unsigned (srl):
1039 // - Needs >= 1 zero bit for both operands.
1040 // - Needs 1 demanded bit zero and >= 2 sign bits.
1041 SelectionDAG &DAG = TLO.DAG;
1042 unsigned ShiftOpc = Op.getOpcode();
1043 bool IsSigned = false;
1044 unsigned KnownBits;
1045 unsigned NumSignedA = DAG.ComputeNumSignBits(ExtOpA, DemandedElts, Depth);
1046 unsigned NumSignedB = DAG.ComputeNumSignBits(ExtOpB, DemandedElts, Depth);
1047 unsigned NumSigned = std::min(NumSignedA, NumSignedB) - 1;
1048 unsigned NumZeroA =
1049 DAG.computeKnownBits(ExtOpA, DemandedElts, Depth).countMinLeadingZeros();
1050 unsigned NumZeroB =
1051 DAG.computeKnownBits(ExtOpB, DemandedElts, Depth).countMinLeadingZeros();
1052 unsigned NumZero = std::min(NumZeroA, NumZeroB);
1053
1054 switch (ShiftOpc) {
1055 default:
1056 llvm_unreachable("Unexpected ShiftOpc in combineShiftToAVG");
1057 case ISD::SRA: {
1058 if (NumZero >= 2 && NumSigned < NumZero) {
1059 IsSigned = false;
1060 KnownBits = NumZero;
1061 break;
1062 }
1063 if (NumSigned >= 1) {
1064 IsSigned = true;
1065 KnownBits = NumSigned;
1066 break;
1067 }
1068 return SDValue();
1069 }
1070 case ISD::SRL: {
1071 if (NumZero >= 1 && NumSigned < NumZero) {
1072 IsSigned = false;
1073 KnownBits = NumZero;
1074 break;
1075 }
1076 if (NumSigned >= 1 && DemandedBits.isSignBitClear()) {
1077 IsSigned = true;
1078 KnownBits = NumSigned;
1079 break;
1080 }
1081 return SDValue();
1082 }
1083 }
1084
1085 unsigned AVGOpc = IsCeil ? (IsSigned ? ISD::AVGCEILS : ISD::AVGCEILU)
1086 : (IsSigned ? ISD::AVGFLOORS : ISD::AVGFLOORU);
1087
1088 // Find the smallest power-2 type that is legal for this vector size and
1089 // operation, given the original type size and the number of known sign/zero
1090 // bits.
1091 EVT VT = Op.getValueType();
1092 unsigned MinWidth =
1093 std::max<unsigned>(VT.getScalarSizeInBits() - KnownBits, 8);
1094 EVT NVT = EVT::getIntegerVT(*DAG.getContext(), llvm::bit_ceil(MinWidth));
1096 return SDValue();
1097 if (VT.isVector())
1098 NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
1099 if (TLO.LegalTypes() && !TLI.isOperationLegal(AVGOpc, NVT)) {
1100 // If we could not transform, and (both) adds are nuw/nsw, we can use the
1101 // larger type size to do the transform.
1102 if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, VT))
1103 return SDValue();
1104 if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0),
1105 Add.getOperand(1)) &&
1106 (!Add2 || DAG.willNotOverflowAdd(IsSigned, Add2.getOperand(0),
1107 Add2.getOperand(1))))
1108 NVT = VT;
1109 else
1110 return SDValue();
1111 }
1112
1113 // Don't create a AVGFLOOR node with a scalar constant unless its legal as
1114 // this is likely to stop other folds (reassociation, value tracking etc.)
1115 if (!IsCeil && !TLI.isOperationLegal(AVGOpc, NVT) &&
1116 (isa<ConstantSDNode>(ExtOpA) || isa<ConstantSDNode>(ExtOpB)))
1117 return SDValue();
1118
1119 SDLoc DL(Op);
1120 SDValue ResultAVG =
1121 DAG.getNode(AVGOpc, DL, NVT, DAG.getExtOrTrunc(IsSigned, ExtOpA, DL, NVT),
1122 DAG.getExtOrTrunc(IsSigned, ExtOpB, DL, NVT));
1123 return DAG.getExtOrTrunc(IsSigned, ResultAVG, DL, VT);
1124}
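// Example (sketch) of the patterns this matches, with i8 values zero-extended
// to i16:
//   srl (add (zext A), (zext B)), 1           --> zext (avgflooru A, B)
//   srl (add (add (zext A), (zext B)), 1), 1  --> zext (avgceilu A, B)
// provided the AVG opcode is legal at the narrowed type (or the adds are known
// not to overflow, in which case the original type is used).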
1125
1126/// Look at Op. At this point, we know that only the OriginalDemandedBits of the
1127/// result of Op are ever used downstream. If we can use this information to
1128/// simplify Op, create a new simplified DAG node and return true, returning the
1129/// original and new nodes in Old and New. Otherwise, analyze the expression and
1130/// return a mask of Known bits for the expression (used to simplify the
1131/// caller). The Known bits may only be accurate for those bits in the
1132/// OriginalDemandedBits and OriginalDemandedElts.
1134 SDValue Op, const APInt &OriginalDemandedBits,
1135 const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
1136 unsigned Depth, bool AssumeSingleUse) const {
1137 unsigned BitWidth = OriginalDemandedBits.getBitWidth();
1138 assert(Op.getScalarValueSizeInBits() == BitWidth &&
1139 "Mask size mismatches value type size!");
1140
1141 // Don't know anything.
1142 Known = KnownBits(BitWidth);
1143
1144 EVT VT = Op.getValueType();
1145 bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
1146 unsigned NumElts = OriginalDemandedElts.getBitWidth();
1147 assert((!VT.isFixedLengthVector() || NumElts == VT.getVectorNumElements()) &&
1148 "Unexpected vector size");
1149
1150 APInt DemandedBits = OriginalDemandedBits;
1151 APInt DemandedElts = OriginalDemandedElts;
1152 SDLoc dl(Op);
1153
1154 // Undef operand.
1155 if (Op.isUndef())
1156 return false;
1157
1158 // We can't simplify target constants.
1159 if (Op.getOpcode() == ISD::TargetConstant)
1160 return false;
1161
1162 if (Op.getOpcode() == ISD::Constant) {
1163 // We know all of the bits for a constant!
1164 Known = KnownBits::makeConstant(Op->getAsAPIntVal());
1165 return false;
1166 }
1167
1168 if (Op.getOpcode() == ISD::ConstantFP) {
1169 // We know all of the bits for a floating point constant!
1170 Known = KnownBits::makeConstant(
1171 cast<ConstantFPSDNode>(Op)->getValueAPF().bitcastToAPInt());
1172 return false;
1173 }
1174
1175 // Other users may use these bits.
1176 bool HasMultiUse = false;
1177 if (!AssumeSingleUse && !Op.getNode()->hasOneUse()) {
1178 if (Depth >= SelectionDAG::MaxRecursionDepth) {
1179 // Limit search depth.
1180 return false;
1181 }
1182 // Allow multiple uses, just set the DemandedBits/Elts to all bits.
1183 DemandedBits = APInt::getAllOnes(BitWidth);
1184 DemandedElts = APInt::getAllOnes(NumElts);
1185 HasMultiUse = true;
1186 } else if (OriginalDemandedBits == 0 || OriginalDemandedElts == 0) {
1187 // Not demanding any bits/elts from Op.
1188 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1189 } else if (Depth >= SelectionDAG::MaxRecursionDepth) {
1190 // Limit search depth.
1191 return false;
1192 }
1193
1194 KnownBits Known2;
1195 switch (Op.getOpcode()) {
1196 case ISD::SCALAR_TO_VECTOR: {
1197 if (VT.isScalableVector())
1198 return false;
1199 if (!DemandedElts[0])
1200 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1201
1202 KnownBits SrcKnown;
1203 SDValue Src = Op.getOperand(0);
1204 unsigned SrcBitWidth = Src.getScalarValueSizeInBits();
1205 APInt SrcDemandedBits = DemandedBits.zext(SrcBitWidth);
1206 if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcKnown, TLO, Depth + 1))
1207 return true;
1208
1209 // Upper elements are undef, so only get the knownbits if we just demand
1210 // the bottom element.
1211 if (DemandedElts == 1)
1212 Known = SrcKnown.anyextOrTrunc(BitWidth);
1213 break;
1214 }
1215 case ISD::BUILD_VECTOR:
1216 // Collect the known bits that are shared by every demanded element.
1217 // TODO: Call SimplifyDemandedBits for non-constant demanded elements.
1218 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1219 return false; // Don't fall through, will infinitely loop.
1220 case ISD::SPLAT_VECTOR: {
1221 SDValue Scl = Op.getOperand(0);
1222 APInt DemandedSclBits = DemandedBits.zextOrTrunc(Scl.getValueSizeInBits());
1223 KnownBits KnownScl;
1224 if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
1225 return true;
1226
1227 // Implicitly truncate the bits to match the official semantics of
1228 // SPLAT_VECTOR.
1229 Known = KnownScl.trunc(BitWidth);
1230 break;
1231 }
1232 case ISD::LOAD: {
1233 auto *LD = cast<LoadSDNode>(Op);
1234 if (getTargetConstantFromLoad(LD)) {
1235 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1236 return false; // Don't fall through, will infinitely loop.
1237 }
1238 if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) {
1239 // If this is a ZEXTLoad and we are looking at the loaded value.
1240 EVT MemVT = LD->getMemoryVT();
1241 unsigned MemBits = MemVT.getScalarSizeInBits();
1242 Known.Zero.setBitsFrom(MemBits);
1243 return false; // Don't fall through, will infinitely loop.
1244 }
1245 break;
1246 }
1247 case ISD::INSERT_VECTOR_ELT: {
1248 if (VT.isScalableVector())
1249 return false;
1250 SDValue Vec = Op.getOperand(0);
1251 SDValue Scl = Op.getOperand(1);
1252 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
1253 EVT VecVT = Vec.getValueType();
1254
1255 // If index isn't constant, assume we need all vector elements AND the
1256 // inserted element.
1257 APInt DemandedVecElts(DemandedElts);
1258 if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements())) {
1259 unsigned Idx = CIdx->getZExtValue();
1260 DemandedVecElts.clearBit(Idx);
1261
1262 // Inserted element is not required.
1263 if (!DemandedElts[Idx])
1264 return TLO.CombineTo(Op, Vec);
1265 }
1266
1267 KnownBits KnownScl;
1268 unsigned NumSclBits = Scl.getScalarValueSizeInBits();
1269 APInt DemandedSclBits = DemandedBits.zextOrTrunc(NumSclBits);
1270 if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
1271 return true;
1272
1273 Known = KnownScl.anyextOrTrunc(BitWidth);
1274
1275 KnownBits KnownVec;
1276 if (SimplifyDemandedBits(Vec, DemandedBits, DemandedVecElts, KnownVec, TLO,
1277 Depth + 1))
1278 return true;
1279
1280 if (!!DemandedVecElts)
1281 Known = Known.intersectWith(KnownVec);
1282
1283 return false;
1284 }
1285 case ISD::INSERT_SUBVECTOR: {
1286 if (VT.isScalableVector())
1287 return false;
1288 // Demand any elements from the subvector and the remainder from the src
1289 // it's inserted into.
1290 SDValue Src = Op.getOperand(0);
1291 SDValue Sub = Op.getOperand(1);
1292 uint64_t Idx = Op.getConstantOperandVal(2);
1293 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
1294 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
1295 APInt DemandedSrcElts = DemandedElts;
1296 DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
1297
1298 KnownBits KnownSub, KnownSrc;
1299 if (SimplifyDemandedBits(Sub, DemandedBits, DemandedSubElts, KnownSub, TLO,
1300 Depth + 1))
1301 return true;
1302 if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, KnownSrc, TLO,
1303 Depth + 1))
1304 return true;
1305
1306 Known.Zero.setAllBits();
1307 Known.One.setAllBits();
1308 if (!!DemandedSubElts)
1309 Known = Known.intersectWith(KnownSub);
1310 if (!!DemandedSrcElts)
1311 Known = Known.intersectWith(KnownSrc);
1312
1313 // Attempt to avoid multi-use src if we don't need anything from it.
1314 if (!DemandedBits.isAllOnes() || !DemandedSubElts.isAllOnes() ||
1315 !DemandedSrcElts.isAllOnes()) {
1316 SDValue NewSub = SimplifyMultipleUseDemandedBits(
1317 Sub, DemandedBits, DemandedSubElts, TLO.DAG, Depth + 1);
1318 SDValue NewSrc = SimplifyMultipleUseDemandedBits(
1319 Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1320 if (NewSub || NewSrc) {
1321 NewSub = NewSub ? NewSub : Sub;
1322 NewSrc = NewSrc ? NewSrc : Src;
1323 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc, NewSub,
1324 Op.getOperand(2));
1325 return TLO.CombineTo(Op, NewOp);
1326 }
1327 }
1328 break;
1329 }
1330 case ISD::EXTRACT_SUBVECTOR: {
1331 if (VT.isScalableVector())
1332 return false;
1333 // Offset the demanded elts by the subvector index.
1334 SDValue Src = Op.getOperand(0);
1335 if (Src.getValueType().isScalableVector())
1336 break;
1337 uint64_t Idx = Op.getConstantOperandVal(1);
1338 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
1339 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
1340
1341 if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, Known, TLO,
1342 Depth + 1))
1343 return true;
1344
1345 // Attempt to avoid multi-use src if we don't need anything from it.
1346 if (!DemandedBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
1347 SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
1348 Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1349 if (DemandedSrc) {
1350 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc,
1351 Op.getOperand(1));
1352 return TLO.CombineTo(Op, NewOp);
1353 }
1354 }
1355 break;
1356 }
1357 case ISD::CONCAT_VECTORS: {
1358 if (VT.isScalableVector())
1359 return false;
1360 Known.Zero.setAllBits();
1361 Known.One.setAllBits();
1362 EVT SubVT = Op.getOperand(0).getValueType();
1363 unsigned NumSubVecs = Op.getNumOperands();
1364 unsigned NumSubElts = SubVT.getVectorNumElements();
1365 for (unsigned i = 0; i != NumSubVecs; ++i) {
1366 APInt DemandedSubElts =
1367 DemandedElts.extractBits(NumSubElts, i * NumSubElts);
1368 if (SimplifyDemandedBits(Op.getOperand(i), DemandedBits, DemandedSubElts,
1369 Known2, TLO, Depth + 1))
1370 return true;
1371 // Known bits are shared by every demanded subvector element.
1372 if (!!DemandedSubElts)
1373 Known = Known.intersectWith(Known2);
1374 }
1375 break;
1376 }
1377 case ISD::VECTOR_SHUFFLE: {
1378 assert(!VT.isScalableVector());
1379 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
1380
1381 // Collect demanded elements from shuffle operands..
1382 APInt DemandedLHS, DemandedRHS;
1383 if (!getShuffleDemandedElts(NumElts, ShuffleMask, DemandedElts, DemandedLHS,
1384 DemandedRHS))
1385 break;
1386
1387 if (!!DemandedLHS || !!DemandedRHS) {
1388 SDValue Op0 = Op.getOperand(0);
1389 SDValue Op1 = Op.getOperand(1);
1390
1391 Known.Zero.setAllBits();
1392 Known.One.setAllBits();
1393 if (!!DemandedLHS) {
1394 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedLHS, Known2, TLO,
1395 Depth + 1))
1396 return true;
1397 Known = Known.intersectWith(Known2);
1398 }
1399 if (!!DemandedRHS) {
1400 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedRHS, Known2, TLO,
1401 Depth + 1))
1402 return true;
1403 Known = Known.intersectWith(Known2);
1404 }
1405
1406 // Attempt to avoid multi-use ops if we don't need anything from them.
1407 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1408 Op0, DemandedBits, DemandedLHS, TLO.DAG, Depth + 1);
1409 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1410 Op1, DemandedBits, DemandedRHS, TLO.DAG, Depth + 1);
1411 if (DemandedOp0 || DemandedOp1) {
1412 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1413 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1414 SDValue NewOp = TLO.DAG.getVectorShuffle(VT, dl, Op0, Op1, ShuffleMask);
1415 return TLO.CombineTo(Op, NewOp);
1416 }
1417 }
1418 break;
1419 }
1420 case ISD::AND: {
1421 SDValue Op0 = Op.getOperand(0);
1422 SDValue Op1 = Op.getOperand(1);
1423
1424 // If the RHS is a constant, check to see if the LHS would be zero without
1425 // using the bits from the RHS. Below, we use knowledge about the RHS to
1426 // simplify the LHS, here we're using information from the LHS to simplify
1427 // the RHS.
1428 if (ConstantSDNode *RHSC = isConstOrConstSplat(Op1, DemandedElts)) {
1429 // Do not increment Depth here; that can cause an infinite loop.
1430 KnownBits LHSKnown = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth);
1431 // If the LHS already has zeros where RHSC does, this 'and' is dead.
1432 if ((LHSKnown.Zero & DemandedBits) ==
1433 (~RHSC->getAPIntValue() & DemandedBits))
1434 return TLO.CombineTo(Op, Op0);
1435
1436 // If any of the set bits in the RHS are known zero on the LHS, shrink
1437 // the constant.
1438 if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits,
1439 DemandedElts, TLO))
1440 return true;
1441
1442 // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its
1443 // constant, but if this 'and' is only clearing bits that were just set by
1444 // the xor, then this 'and' can be eliminated by shrinking the mask of
1445 // the xor. For example, for a 32-bit X:
1446 // and (xor (srl X, 31), -1), 1 --> xor (srl X, 31), 1
1447 if (isBitwiseNot(Op0) && Op0.hasOneUse() &&
1448 LHSKnown.One == ~RHSC->getAPIntValue()) {
1449 SDValue Xor = TLO.DAG.getNode(ISD::XOR, dl, VT, Op0.getOperand(0), Op1);
1450 return TLO.CombineTo(Op, Xor);
1451 }
1452 }
1453
1454 // AND(INSERT_SUBVECTOR(C,X,I),M) -> INSERT_SUBVECTOR(AND(C,M),X,I)
1455 // iff 'C' is Undef/Constant and AND(X,M) == X (for DemandedBits).
1456 if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR && !VT.isScalableVector() &&
1457 (Op0.getOperand(0).isUndef() ||
1458 ISD::isBuildVectorOfConstantSDNodes(Op0.getOperand(0).getNode())) &&
1459 Op0->hasOneUse()) {
1460 unsigned NumSubElts =
1461 Op0.getOperand(1).getValueType().getVectorNumElements();
1462 unsigned SubIdx = Op0.getConstantOperandVal(2);
1463 APInt DemandedSub =
1464 APInt::getBitsSet(NumElts, SubIdx, SubIdx + NumSubElts);
1465 KnownBits KnownSubMask =
1466 TLO.DAG.computeKnownBits(Op1, DemandedSub & DemandedElts, Depth + 1);
1467 if (DemandedBits.isSubsetOf(KnownSubMask.One)) {
1468 SDValue NewAnd =
1469 TLO.DAG.getNode(ISD::AND, dl, VT, Op0.getOperand(0), Op1);
1470 SDValue NewInsert =
1471 TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VT, NewAnd,
1472 Op0.getOperand(1), Op0.getOperand(2));
1473 return TLO.CombineTo(Op, NewInsert);
1474 }
1475 }
1476
1477 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1478 Depth + 1))
1479 return true;
1480 if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts,
1481 Known2, TLO, Depth + 1))
1482 return true;
1483
1484 // If all of the demanded bits are known one on one side, return the other.
1485 // These bits cannot contribute to the result of the 'and'.
1486 if (DemandedBits.isSubsetOf(Known2.Zero | Known.One))
1487 return TLO.CombineTo(Op, Op0);
1488 if (DemandedBits.isSubsetOf(Known.Zero | Known2.One))
1489 return TLO.CombineTo(Op, Op1);
1490 // If all of the demanded bits in the inputs are known zeros, return zero.
1491 if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
1492 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT));
1493 // If the RHS is a constant, see if we can simplify it.
1494 if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, DemandedElts,
1495 TLO))
1496 return true;
1497 // If the operation can be done in a smaller type, do so.
1498 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1499 return true;
1500
1501 // Attempt to avoid multi-use ops if we don't need anything from them.
1502 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1503 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1504 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1505 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1506 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1507 if (DemandedOp0 || DemandedOp1) {
1508 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1509 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1510 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1511 return TLO.CombineTo(Op, NewOp);
1512 }
1513 }
1514
1515 Known &= Known2;
1516 break;
1517 }
1518 case ISD::OR: {
1519 SDValue Op0 = Op.getOperand(0);
1520 SDValue Op1 = Op.getOperand(1);
1521 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1522 Depth + 1)) {
1523 Op->dropFlags(SDNodeFlags::Disjoint);
1524 return true;
1525 }
1526
1527 if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts,
1528 Known2, TLO, Depth + 1)) {
1529 Op->dropFlags(SDNodeFlags::Disjoint);
1530 return true;
1531 }
1532
1533 // If all of the demanded bits are known zero on one side, return the other.
1534 // These bits cannot contribute to the result of the 'or'.
1535 if (DemandedBits.isSubsetOf(Known2.One | Known.Zero))
1536 return TLO.CombineTo(Op, Op0);
1537 if (DemandedBits.isSubsetOf(Known.One | Known2.Zero))
1538 return TLO.CombineTo(Op, Op1);
1539 // If the RHS is a constant, see if we can simplify it.
1540 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1541 return true;
1542 // If the operation can be done in a smaller type, do so.
1543 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1544 return true;
1545
1546 // Attempt to avoid multi-use ops if we don't need anything from them.
1547 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1548 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1549 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1550 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1551 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1552 if (DemandedOp0 || DemandedOp1) {
1553 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1554 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1555 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1556 return TLO.CombineTo(Op, NewOp);
1557 }
1558 }
1559
1560 // (or (and X, C1), (and (or X, Y), C2)) -> (or (and X, C1|C2), (and Y, C2))
1561 // TODO: Use SimplifyMultipleUseDemandedBits to peek through masks.
1562 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::AND &&
1563 Op0->hasOneUse() && Op1->hasOneUse()) {
1564 // Attempt to match all commutations - m_c_Or would've been useful!
1565 for (int I = 0; I != 2; ++I) {
1566 SDValue X = Op.getOperand(I).getOperand(0);
1567 SDValue C1 = Op.getOperand(I).getOperand(1);
1568 SDValue Alt = Op.getOperand(1 - I).getOperand(0);
1569 SDValue C2 = Op.getOperand(1 - I).getOperand(1);
1570 if (Alt.getOpcode() == ISD::OR) {
1571 for (int J = 0; J != 2; ++J) {
1572 if (X == Alt.getOperand(J)) {
1573 SDValue Y = Alt.getOperand(1 - J);
1574 if (SDValue C12 = TLO.DAG.FoldConstantArithmetic(ISD::OR, dl, VT,
1575 {C1, C2})) {
1576 SDValue MaskX = TLO.DAG.getNode(ISD::AND, dl, VT, X, C12);
1577 SDValue MaskY = TLO.DAG.getNode(ISD::AND, dl, VT, Y, C2);
1578 return TLO.CombineTo(
1579 Op, TLO.DAG.getNode(ISD::OR, dl, VT, MaskX, MaskY));
1580 }
1581 }
1582 }
1583 }
1584 }
1585 }
1586
1587 Known |= Known2;
1588 break;
1589 }
1590 case ISD::XOR: {
1591 SDValue Op0 = Op.getOperand(0);
1592 SDValue Op1 = Op.getOperand(1);
1593
1594 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1595 Depth + 1))
1596 return true;
1597 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO,
1598 Depth + 1))
1599 return true;
1600
1601 // If all of the demanded bits are known zero on one side, return the other.
1602 // These bits cannot contribute to the result of the 'xor'.
1603 if (DemandedBits.isSubsetOf(Known.Zero))
1604 return TLO.CombineTo(Op, Op0);
1605 if (DemandedBits.isSubsetOf(Known2.Zero))
1606 return TLO.CombineTo(Op, Op1);
1607 // If the operation can be done in a smaller type, do so.
1608 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1609 return true;
1610
1611 // If all of the unknown bits are known to be zero on one side or the other
1612 // turn this into an *inclusive* or.
1613 // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
1614 if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
1615 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
1616
1617 ConstantSDNode *C = isConstOrConstSplat(Op1, DemandedElts);
1618 if (C) {
1619 // If one side is a constant, and all of the set bits in the constant are
1620 // also known set on the other side, turn this into an AND, as we know
1621 // the bits will be cleared.
1622 // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
1623 // NB: it is okay if more bits are known than are requested
1624 if (C->getAPIntValue() == Known2.One) {
1625 SDValue ANDC =
1626 TLO.DAG.getConstant(~C->getAPIntValue() & DemandedBits, dl, VT);
1627 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, ANDC));
1628 }
1629
1630 // If the RHS is a constant, see if we can change it. Don't alter a -1
1631 // constant because that's a 'not' op, and that is better for combining
1632 // and codegen.
1633 if (!C->isAllOnes() && DemandedBits.isSubsetOf(C->getAPIntValue())) {
1634 // We're flipping all demanded bits. Flip the undemanded bits too.
1635 SDValue New = TLO.DAG.getNOT(dl, Op0, VT);
1636 return TLO.CombineTo(Op, New);
1637 }
1638
1639 unsigned Op0Opcode = Op0.getOpcode();
1640 if ((Op0Opcode == ISD::SRL || Op0Opcode == ISD::SHL) && Op0.hasOneUse()) {
1641 if (ConstantSDNode *ShiftC =
1642 isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
1643 // Don't crash on an oversized shift. We can not guarantee that a
1644 // bogus shift has been simplified to undef.
1645 if (ShiftC->getAPIntValue().ult(BitWidth)) {
1646 uint64_t ShiftAmt = ShiftC->getZExtValue();
1647 APInt Ones = APInt::getAllOnes(BitWidth);
1648 Ones = Op0Opcode == ISD::SHL ? Ones.shl(ShiftAmt)
1649 : Ones.lshr(ShiftAmt);
1650 if ((DemandedBits & C->getAPIntValue()) == (DemandedBits & Ones) &&
1651 isDesirableToCommuteXorWithShift(Op.getNode())) {
1652 // If the xor constant is a demanded mask, do a 'not' before the
1653 // shift:
1654 // xor (X << ShiftC), XorC --> (not X) << ShiftC
1655 // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
1656 SDValue Not = TLO.DAG.getNOT(dl, Op0.getOperand(0), VT);
1657 return TLO.CombineTo(Op, TLO.DAG.getNode(Op0Opcode, dl, VT, Not,
1658 Op0.getOperand(1)));
1659 }
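// Worked example (illustrative i8 values): for (X << 4) ^ 0xF0 the low four
// bits are zero on both sides of the xor and the high four bits are the
// complement of X's low nibble, which is exactly what (~X) << 4 produces,
// so the 'not' can be hoisted before the shift.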
1660 }
1661 }
1662 }
1663 }
1664
1665 // If we can't turn this into a 'not', try to shrink the constant.
1666 if (!C || !C->isAllOnes())
1667 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1668 return true;
1669
1670 // Attempt to avoid multi-use ops if we don't need anything from them.
1671 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1672 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1673 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1674 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1675 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1676 if (DemandedOp0 || DemandedOp1) {
1677 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1678 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1679 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1680 return TLO.CombineTo(Op, NewOp);
1681 }
1682 }
1683
1684 Known ^= Known2;
1685 break;
1686 }
1687 case ISD::SELECT:
1688 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1689 Known, TLO, Depth + 1))
1690 return true;
1691 if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedElts,
1692 Known2, TLO, Depth + 1))
1693 return true;
1694
1695 // If the operands are constants, see if we can simplify them.
1696 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1697 return true;
1698
1699 // Only known if known in both the LHS and RHS.
1700 Known = Known.intersectWith(Known2);
1701 break;
1702 case ISD::VSELECT:
1703 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1704 Known, TLO, Depth + 1))
1705 return true;
1706 if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedElts,
1707 Known2, TLO, Depth + 1))
1708 return true;
1709
1710 // Only known if known in both the LHS and RHS.
1711 Known = Known.intersectWith(Known2);
1712 break;
1713 case ISD::SELECT_CC:
1714 if (SimplifyDemandedBits(Op.getOperand(3), DemandedBits, DemandedElts,
1715 Known, TLO, Depth + 1))
1716 return true;
1717 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1718 Known2, TLO, Depth + 1))
1719 return true;
1720
1721 // If the operands are constants, see if we can simplify them.
1722 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1723 return true;
1724
1725 // Only known if known in both the LHS and RHS.
1726 Known = Known.intersectWith(Known2);
1727 break;
1728 case ISD::SETCC: {
1729 SDValue Op0 = Op.getOperand(0);
1730 SDValue Op1 = Op.getOperand(1);
1731 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
1732 // If (1) we only need the sign-bit, (2) the setcc operands are the same
1733 // width as the setcc result, and (3) the result of a setcc conforms to 0 or
1734 // -1, we may be able to bypass the setcc.
1735 if (DemandedBits.isSignMask() &&
1736 Op0.getScalarValueSizeInBits() == BitWidth &&
1737 getBooleanContents(Op0.getValueType()) ==
1738 BooleanContent::ZeroOrNegativeOneBooleanContent) {
1739 // If we're testing X < 0, then this compare isn't needed - just use X!
1740 // FIXME: We're limiting to integer types here, but this should also work
1741 // if we don't care about FP signed-zero. The use of SETLT with FP means
1742 // that we don't care about NaNs.
1743 if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
1744 (isNullConstant(Op1) || ISD::isBuildVectorAllZeros(Op1.getNode())))
1745 return TLO.CombineTo(Op, Op0);
1746
1747 // TODO: Should we check for other forms of sign-bit comparisons?
1748 // Examples: X <= -1, X >= 0
1749 }
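// Worked example (illustrative i32 values): with 0/-1 boolean contents,
// (setlt X, 0) produces -1 exactly when X is negative and 0 otherwise, so
// its sign bit always equals the sign bit of X; when only that bit is
// demanded the setcc can be dropped and X used directly.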
1750 if (getBooleanContents(Op0.getValueType()) ==
1751 TargetLowering::ZeroOrNegativeOneBooleanContent &&
1752 BitWidth > 1)
1753 Known.Zero.setBitsFrom(1);
1754 break;
1755 }
1756 case ISD::SHL: {
1757 SDValue Op0 = Op.getOperand(0);
1758 SDValue Op1 = Op.getOperand(1);
1759 EVT ShiftVT = Op1.getValueType();
1760
1761 if (std::optional<uint64_t> KnownSA =
1762 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
1763 unsigned ShAmt = *KnownSA;
1764 if (ShAmt == 0)
1765 return TLO.CombineTo(Op, Op0);
1766
1767 // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
1768 // single shift. We can do this if the bottom bits (which are shifted
1769 // out) are never demanded.
1770 // TODO - support non-uniform vector amounts.
1771 if (Op0.getOpcode() == ISD::SRL) {
1772 if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
1773 if (std::optional<uint64_t> InnerSA =
1774 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
1775 unsigned C1 = *InnerSA;
1776 unsigned Opc = ISD::SHL;
1777 int Diff = ShAmt - C1;
1778 if (Diff < 0) {
1779 Diff = -Diff;
1780 Opc = ISD::SRL;
1781 }
1782 SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
1783 return TLO.CombineTo(
1784 Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
1785 }
1786 }
1787 }
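// Worked example (illustrative i8 values): ((x >>u 3) << 5) and (x << 2)
// agree on bits 7..5 (both hold x[5:3]); they differ only in bits 4..0,
// so when those low ShAmt bits are not demanded the pair of shifts can be
// replaced by the single net shift (here Diff = 5 - 3 = 2, Opc = ISD::SHL).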
1788
1789 // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits
1790 // are not demanded. This will likely allow the anyext to be folded away.
1791 // TODO - support non-uniform vector amounts.
1792 if (Op0.getOpcode() == ISD::ANY_EXTEND) {
1793 SDValue InnerOp = Op0.getOperand(0);
1794 EVT InnerVT = InnerOp.getValueType();
1795 unsigned InnerBits = InnerVT.getScalarSizeInBits();
1796 if (ShAmt < InnerBits && DemandedBits.getActiveBits() <= InnerBits &&
1797 isTypeDesirableForOp(ISD::SHL, InnerVT)) {
1798 SDValue NarrowShl = TLO.DAG.getNode(
1799 ISD::SHL, dl, InnerVT, InnerOp,
1800 TLO.DAG.getShiftAmountConstant(ShAmt, InnerVT, dl));
1801 return TLO.CombineTo(
1802 Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
1803 }
1804
1805 // Repeat the SHL optimization above in cases where an extension
1806 // intervenes: (shl (anyext (shr x, c1)), c2) to
1807 // (shl (anyext x), c2-c1). This requires that the bottom c1 bits
1808 // aren't demanded (as above) and that the shifted upper c1 bits of
1809 // x aren't demanded.
1810 // TODO - support non-uniform vector amounts.
1811 if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
1812 InnerOp.hasOneUse()) {
1813 if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
1814 InnerOp, DemandedElts, Depth + 2)) {
1815 unsigned InnerShAmt = *SA2;
1816 if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
1817 DemandedBits.getActiveBits() <=
1818 (InnerBits - InnerShAmt + ShAmt) &&
1819 DemandedBits.countr_zero() >= ShAmt) {
1820 SDValue NewSA =
1821 TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, ShiftVT);
1822 SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT,
1823 InnerOp.getOperand(0));
1824 return TLO.CombineTo(
1825 Op, TLO.DAG.getNode(ISD::SHL, dl, VT, NewExt, NewSA));
1826 }
1827 }
1828 }
1829 }
1830
1831 APInt InDemandedMask = DemandedBits.lshr(ShAmt);
1832 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
1833 Depth + 1)) {
1834 // Disable the nsw and nuw flags. We can no longer guarantee that we
1835 // won't wrap after simplification.
1836 Op->dropFlags(SDNodeFlags::NoWrap);
1837 return true;
1838 }
1839 Known.Zero <<= ShAmt;
1840 Known.One <<= ShAmt;
1841 // low bits known zero.
1842 Known.Zero.setLowBits(ShAmt);
1843
1844 // Attempt to avoid multi-use ops if we don't need anything from them.
1845 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
1846 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1847 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
1848 if (DemandedOp0) {
1849 SDValue NewOp = TLO.DAG.getNode(ISD::SHL, dl, VT, DemandedOp0, Op1);
1850 return TLO.CombineTo(Op, NewOp);
1851 }
1852 }
1853
1854 // TODO: Can we merge this fold with the one below?
1855 // Try shrinking the operation as long as the shift amount will still be
1856 // in range.
1857 if (ShAmt < DemandedBits.getActiveBits() && !VT.isVector() &&
1858 Op.getNode()->hasOneUse()) {
1859 // Search for the smallest integer type with free casts to and from
1860 // Op's type. For expedience, just check power-of-2 integer types.
1861 unsigned DemandedSize = DemandedBits.getActiveBits();
1862 for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
1863 SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
1864 EVT SmallVT = EVT::getIntegerVT(*TLO.DAG.getContext(), SmallVTBits);
1865 if (isNarrowingProfitable(Op.getNode(), VT, SmallVT) &&
1866 isTypeDesirableForOp(ISD::SHL, SmallVT) &&
1867 isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT) &&
1868 (!TLO.LegalOperations() || isOperationLegal(ISD::SHL, SmallVT))) {
1869 assert(DemandedSize <= SmallVTBits &&
1870 "Narrowed below demanded bits?");
1871 // We found a type with free casts.
1872 SDValue NarrowShl = TLO.DAG.getNode(
1873 ISD::SHL, dl, SmallVT,
1874 TLO.DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
1875 TLO.DAG.getShiftAmountConstant(ShAmt, SmallVT, dl));
1876 return TLO.CombineTo(
1877 Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
1878 }
1879 }
1880 }
1881
1882 // Narrow shift to lower half - similar to ShrinkDemandedOp.
1883 // (shl i64:x, K) -> (i64 zero_extend (shl (i32 (trunc i64:x)), K))
1884 // Only do this if we demand the upper half so the knownbits are correct.
1885 unsigned HalfWidth = BitWidth / 2;
1886 if ((BitWidth % 2) == 0 && !VT.isVector() && ShAmt < HalfWidth &&
1887 DemandedBits.countLeadingOnes() >= HalfWidth) {
1888 EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), HalfWidth);
1889 if (isNarrowingProfitable(Op.getNode(), VT, HalfVT) &&
1890 isTypeDesirableForOp(ISD::SHL, HalfVT) &&
1891 isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
1892 (!TLO.LegalOperations() || isOperationLegal(ISD::SHL, HalfVT))) {
1893 // If we're demanding the upper bits at all, we must ensure
1894 // that the upper bits of the shift result are known to be zero,
1895 // which is equivalent to the narrow shift being NUW.
1896 if (bool IsNUW = (Known.countMinLeadingZeros() >= HalfWidth)) {
1897 bool IsNSW = Known.countMinSignBits() > HalfWidth;
1898 SDNodeFlags Flags;
1899 Flags.setNoSignedWrap(IsNSW);
1900 Flags.setNoUnsignedWrap(IsNUW);
1901 SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0);
1902 SDValue NewShiftAmt =
1903 TLO.DAG.getShiftAmountConstant(ShAmt, HalfVT, dl);
1904 SDValue NewShift = TLO.DAG.getNode(ISD::SHL, dl, HalfVT, NewOp,
1905 NewShiftAmt, Flags);
1906 SDValue NewExt =
1907 TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, NewShift);
1908 return TLO.CombineTo(Op, NewExt);
1909 }
1910 }
1911 }
1912 } else {
1913 // This is a variable shift, so we can't shift the demand mask by a known
1914 // amount. But if we are not demanding high bits, then we are not
1915 // demanding those bits from the pre-shifted operand either.
1916 if (unsigned CTLZ = DemandedBits.countl_zero()) {
1917 APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
1918 if (SimplifyDemandedBits(Op0, DemandedFromOp, DemandedElts, Known, TLO,
1919 Depth + 1)) {
1920 // Disable the nsw and nuw flags. We can no longer guarantee that we
1921 // won't wrap after simplification.
1922 Op->dropFlags(SDNodeFlags::NoWrap);
1923 return true;
1924 }
1925 Known.resetAll();
1926 }
1927 }
1928
1929 // If we are only demanding sign bits then we can use the shift source
1930 // directly.
1931 if (std::optional<uint64_t> MaxSA =
1932 TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
1933 unsigned ShAmt = *MaxSA;
1934 unsigned NumSignBits =
1935 TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
1936 unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
1937 if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
1938 return TLO.CombineTo(Op, Op0);
1939 }
1940 break;
1941 }
1942 case ISD::SRL: {
1943 SDValue Op0 = Op.getOperand(0);
1944 SDValue Op1 = Op.getOperand(1);
1945 EVT ShiftVT = Op1.getValueType();
1946
1947 if (std::optional<uint64_t> KnownSA =
1948 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
1949 unsigned ShAmt = *KnownSA;
1950 if (ShAmt == 0)
1951 return TLO.CombineTo(Op, Op0);
1952
1953 // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
1954 // single shift. We can do this if the top bits (which are shifted out)
1955 // are never demanded.
1956 // TODO - support non-uniform vector amounts.
1957 if (Op0.getOpcode() == ISD::SHL) {
1958 if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
1959 if (std::optional<uint64_t> InnerSA =
1960 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
1961 unsigned C1 = *InnerSA;
1962 unsigned Opc = ISD::SRL;
1963 int Diff = ShAmt - C1;
1964 if (Diff < 0) {
1965 Diff = -Diff;
1966 Opc = ISD::SHL;
1967 }
1968 SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
1969 return TLO.CombineTo(
1970 Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
1971 }
1972 }
1973 }
1974
1975 // If this is (srl (sra X, C1), ShAmt), see if we can combine this into a
1976 // single sra. We can do this if the top bits are never demanded.
1977 if (Op0.getOpcode() == ISD::SRA && Op0.hasOneUse()) {
1978 if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
1979 if (std::optional<uint64_t> InnerSA =
1980 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
1981 unsigned C1 = *InnerSA;
1982 // Clamp the combined shift amount if it exceeds the bit width.
1983 unsigned Combined = std::min(C1 + ShAmt, BitWidth - 1);
1984 SDValue NewSA = TLO.DAG.getConstant(Combined, dl, ShiftVT);
1985 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRA, dl, VT,
1986 Op0.getOperand(0), NewSA));
1987 }
1988 }
1989 }
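// Worked example (illustrative i8 values): (srl (sra x, 2), 3) keeps
// x[7:5] (plus two copies of the sign) in bits 4..0 and zeros bits 7..5,
// while (sra x, 5) produces the same bits 4..0 with sign bits above; when
// the high 3 bits are not demanded the two are interchangeable, with the
// combined amount clamped to BitWidth - 1.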
1990
1991 APInt InDemandedMask = (DemandedBits << ShAmt);
1992
1993 // If the shift is exact, then it does demand the low bits (and knows that
1994 // they are zero).
1995 if (Op->getFlags().hasExact())
1996 InDemandedMask.setLowBits(ShAmt);
1997
1998 // Narrow shift to lower half - similar to ShrinkDemandedOp.
1999 // (srl i64:x, K) -> (i64 zero_extend (srl (i32 (trunc i64:x)), K))
2000 if ((BitWidth % 2) == 0 && !VT.isVector()) {
2001 APInt HiBits = APInt::getHighBitsSet(BitWidth, BitWidth / 2);
2002 EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), BitWidth / 2);
2003 if (isNarrowingProfitable(Op.getNode(), VT, HalfVT) &&
2004 isTypeDesirableForOp(ISD::SRL, HalfVT) &&
2005 isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
2006 (!TLO.LegalOperations() || isOperationLegal(ISD::SRL, HalfVT)) &&
2007 ((InDemandedMask.countLeadingZeros() >= (BitWidth / 2)) ||
2008 TLO.DAG.MaskedValueIsZero(Op0, HiBits))) {
2009 SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0);
2010 SDValue NewShiftAmt =
2011 TLO.DAG.getShiftAmountConstant(ShAmt, HalfVT, dl);
2012 SDValue NewShift =
2013 TLO.DAG.getNode(ISD::SRL, dl, HalfVT, NewOp, NewShiftAmt);
2014 return TLO.CombineTo(
2015 Op, TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, NewShift));
2016 }
2017 }
2018
2019 // Compute the new bits that are at the top now.
2020 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
2021 Depth + 1))
2022 return true;
2023 Known.Zero.lshrInPlace(ShAmt);
2024 Known.One.lshrInPlace(ShAmt);
2025 // High bits known zero.
2026 Known.Zero.setHighBits(ShAmt);
2027
2028 // Attempt to avoid multi-use ops if we don't need anything from them.
2029 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2030 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2031 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
2032 if (DemandedOp0) {
2033 SDValue NewOp = TLO.DAG.getNode(ISD::SRL, dl, VT, DemandedOp0, Op1);
2034 return TLO.CombineTo(Op, NewOp);
2035 }
2036 }
2037 } else {
2038 // Use generic knownbits computation as it has support for non-uniform
2039 // shift amounts.
2040 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2041 }
2042
2043 // If we are only demanding sign bits then we can use the shift source
2044 // directly.
2045 if (std::optional<uint64_t> MaxSA =
2046 TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
2047 unsigned ShAmt = *MaxSA;
2048 // Must already be signbits in DemandedBits bounds, and can't demand any
2049 // shifted in zeroes.
2050 if (DemandedBits.countl_zero() >= ShAmt) {
2051 unsigned NumSignBits =
2052 TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
2053 if (DemandedBits.countr_zero() >= (BitWidth - NumSignBits))
2054 return TLO.CombineTo(Op, Op0);
2055 }
2056 }
2057
2058 // Try to match AVG patterns (after shift simplification).
2059 if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2060 DemandedElts, Depth + 1))
2061 return TLO.CombineTo(Op, AVG);
2062
2063 break;
2064 }
2065 case ISD::SRA: {
2066 SDValue Op0 = Op.getOperand(0);
2067 SDValue Op1 = Op.getOperand(1);
2068 EVT ShiftVT = Op1.getValueType();
2069
2070 // If we only want bits that already match the signbit then we don't need
2071 // to shift.
2072 unsigned NumHiDemandedBits = BitWidth - DemandedBits.countr_zero();
2073 if (TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1) >=
2074 NumHiDemandedBits)
2075 return TLO.CombineTo(Op, Op0);
2076
2077 // If this is an arithmetic shift right and only the low-bit is set, we can
2078 // always convert this into a logical shr, even if the shift amount is
2079 // variable. The low bit of the shift cannot be an input sign bit unless
2080 // the shift amount is >= the size of the datatype, which is undefined.
2081 if (DemandedBits.isOne())
2082 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
2083
2084 if (std::optional<uint64_t> KnownSA =
2085 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
2086 unsigned ShAmt = *KnownSA;
2087 if (ShAmt == 0)
2088 return TLO.CombineTo(Op, Op0);
2089
2090 // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
2091 // supports sext_inreg.
2092 if (Op0.getOpcode() == ISD::SHL) {
2093 if (std::optional<uint64_t> InnerSA =
2094 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
2095 unsigned LowBits = BitWidth - ShAmt;
2096 EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
2097 if (VT.isVector())
2098 ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtVT,
2099 VT.getVectorElementCount());
2100
2101 if (*InnerSA == ShAmt) {
2102 if (!TLO.LegalOperations() ||
2103 isOperationLegal(ISD::SIGN_EXTEND_INREG, ExtVT))
2104 return TLO.CombineTo(
2105 Op, TLO.DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, VT,
2106 Op0.getOperand(0),
2107 TLO.DAG.getValueType(ExtVT)));
2108
2109 // Even if we can't convert to sext_inreg, we might be able to
2110 // remove this shift pair if the input is already sign extended.
2111 unsigned NumSignBits =
2112 TLO.DAG.ComputeNumSignBits(Op0.getOperand(0), DemandedElts);
2113 if (NumSignBits > ShAmt)
2114 return TLO.CombineTo(Op, Op0.getOperand(0));
2115 }
2116 }
2117 }
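// Worked example (illustrative i32 values): (sra (shl x, 24), 24) copies
// bit 7 of x into bits 31..8 and keeps bits 7..0, which is exactly
// sign_extend_inreg from i8, so ExtVT is i8 (LowBits = 32 - 24 = 8) here.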
2118
2119 APInt InDemandedMask = (DemandedBits << ShAmt);
2120
2121 // If the shift is exact, then it does demand the low bits (and knows that
2122 // they are zero).
2123 if (Op->getFlags().hasExact())
2124 InDemandedMask.setLowBits(ShAmt);
2125
2126 // If any of the demanded bits are produced by the sign extension, we also
2127 // demand the input sign bit.
2128 if (DemandedBits.countl_zero() < ShAmt)
2129 InDemandedMask.setSignBit();
2130
2131 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
2132 Depth + 1))
2133 return true;
2134 Known.Zero.lshrInPlace(ShAmt);
2135 Known.One.lshrInPlace(ShAmt);
2136
2137 // If the input sign bit is known to be zero, or if none of the top bits
2138 // are demanded, turn this into an unsigned shift right.
2139 if (Known.Zero[BitWidth - ShAmt - 1] ||
2140 DemandedBits.countl_zero() >= ShAmt) {
2141 SDNodeFlags Flags;
2142 Flags.setExact(Op->getFlags().hasExact());
2143 return TLO.CombineTo(
2144 Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1, Flags));
2145 }
2146
2147 int Log2 = DemandedBits.exactLogBase2();
2148 if (Log2 >= 0) {
2149 // The bit must come from the sign.
2150 SDValue NewSA = TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, ShiftVT);
2151 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, NewSA));
2152 }
2153
2154 if (Known.One[BitWidth - ShAmt - 1])
2155 // New bits are known one.
2156 Known.One.setHighBits(ShAmt);
2157
2158 // Attempt to avoid multi-use ops if we don't need anything from them.
2159 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2160 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2161 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
2162 if (DemandedOp0) {
2163 SDValue NewOp = TLO.DAG.getNode(ISD::SRA, dl, VT, DemandedOp0, Op1);
2164 return TLO.CombineTo(Op, NewOp);
2165 }
2166 }
2167 }
2168
2169 // Try to match AVG patterns (after shift simplification).
2170 if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2171 DemandedElts, Depth + 1))
2172 return TLO.CombineTo(Op, AVG);
2173
2174 break;
2175 }
2176 case ISD::FSHL:
2177 case ISD::FSHR: {
2178 SDValue Op0 = Op.getOperand(0);
2179 SDValue Op1 = Op.getOperand(1);
2180 SDValue Op2 = Op.getOperand(2);
2181 bool IsFSHL = (Op.getOpcode() == ISD::FSHL);
2182
2183 if (ConstantSDNode *SA = isConstOrConstSplat(Op2, DemandedElts)) {
2184 unsigned Amt = SA->getAPIntValue().urem(BitWidth);
2185
2186 // For fshl, 0-shift returns the 1st arg.
2187 // For fshr, 0-shift returns the 2nd arg.
2188 if (Amt == 0) {
2189 if (SimplifyDemandedBits(IsFSHL ? Op0 : Op1, DemandedBits, DemandedElts,
2190 Known, TLO, Depth + 1))
2191 return true;
2192 break;
2193 }
2194
2195 // fshl: (Op0 << Amt) | (Op1 >> (BW - Amt))
2196 // fshr: (Op0 << (BW - Amt)) | (Op1 >> Amt)
2197 APInt Demanded0 = DemandedBits.lshr(IsFSHL ? Amt : (BitWidth - Amt));
2198 APInt Demanded1 = DemandedBits << (IsFSHL ? (BitWidth - Amt) : Amt);
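// Worked example (illustrative i8 values, Amt = 3): fshl computes
// (Op0 << 3) | (Op1 >> 5), so bit i of the result takes Op0's bit (i - 3)
// for i >= 3 and Op1's bit (i + 5) otherwise; hence Demanded0 is
// DemandedBits shifted right by 3 and Demanded1 is DemandedBits shifted
// left by 5, matching the masks computed above.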
2199 if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
2200 Depth + 1))
2201 return true;
2202 if (SimplifyDemandedBits(Op1, Demanded1, DemandedElts, Known, TLO,
2203 Depth + 1))
2204 return true;
2205
2206 Known2.One <<= (IsFSHL ? Amt : (BitWidth - Amt));
2207 Known2.Zero <<= (IsFSHL ? Amt : (BitWidth - Amt));
2208 Known.One.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
2209 Known.Zero.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
2210 Known = Known.unionWith(Known2);
2211
2212 // Attempt to avoid multi-use ops if we don't need anything from them.
2213 if (!Demanded0.isAllOnes() || !Demanded1.isAllOnes() ||
2214 !DemandedElts.isAllOnes()) {
2215 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2216 Op0, Demanded0, DemandedElts, TLO.DAG, Depth + 1);
2217 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
2218 Op1, Demanded1, DemandedElts, TLO.DAG, Depth + 1);
2219 if (DemandedOp0 || DemandedOp1) {
2220 DemandedOp0 = DemandedOp0 ? DemandedOp0 : Op0;
2221 DemandedOp1 = DemandedOp1 ? DemandedOp1 : Op1;
2222 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedOp0,
2223 DemandedOp1, Op2);
2224 return TLO.CombineTo(Op, NewOp);
2225 }
2226 }
2227 }
2228
2229 // For pow-2 bitwidths we only demand the bottom modulo amt bits.
2230 if (isPowerOf2_32(BitWidth)) {
2231 APInt DemandedAmtBits(Op2.getScalarValueSizeInBits(), BitWidth - 1);
2232 if (SimplifyDemandedBits(Op2, DemandedAmtBits, DemandedElts,
2233 Known2, TLO, Depth + 1))
2234 return true;
2235 }
2236 break;
2237 }
2238 case ISD::ROTL:
2239 case ISD::ROTR: {
2240 SDValue Op0 = Op.getOperand(0);
2241 SDValue Op1 = Op.getOperand(1);
2242 bool IsROTL = (Op.getOpcode() == ISD::ROTL);
2243
2244 // If we're rotating a 0/-1 value, then it stays a 0/-1 value.
2245 if (BitWidth == TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1))
2246 return TLO.CombineTo(Op, Op0);
2247
2248 if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
2249 unsigned Amt = SA->getAPIntValue().urem(BitWidth);
2250 unsigned RevAmt = BitWidth - Amt;
2251
2252 // rotl: (Op0 << Amt) | (Op0 >> (BW - Amt))
2253 // rotr: (Op0 << (BW - Amt)) | (Op0 >> Amt)
2254 APInt Demanded0 = DemandedBits.rotr(IsROTL ? Amt : RevAmt);
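// Worked example (illustrative i8 values, rotl by 3): source bit i ends up
// at position (i + 3) % 8, so the bits needed from Op0 are exactly the
// demanded result bits rotated right by 3, which is what rotr(Amt) above
// computes (rotr(RevAmt) handles the rotr case symmetrically).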
2255 if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
2256 Depth + 1))
2257 return true;
2258
2259 // rot*(x, 0) --> x
2260 if (Amt == 0)
2261 return TLO.CombineTo(Op, Op0);
2262
2263 // See if we don't demand either half of the rotated bits.
2264 if ((!TLO.LegalOperations() || isOperationLegal(ISD::SHL, VT)) &&
2265 DemandedBits.countr_zero() >= (IsROTL ? Amt : RevAmt)) {
2266 Op1 = TLO.DAG.getConstant(IsROTL ? Amt : RevAmt, dl, Op1.getValueType());
2267 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, Op1));
2268 }
2269 if ((!TLO.LegalOperations() || isOperationLegal(ISD::SRL, VT)) &&
2270 DemandedBits.countl_zero() >= (IsROTL ? RevAmt : Amt)) {
2271 Op1 = TLO.DAG.getConstant(IsROTL ? RevAmt : Amt, dl, Op1.getValueType());
2272 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
2273 }
2274 }
2275
2276 // For pow-2 bitwidths we only demand the bottom modulo amt bits.
2277 if (isPowerOf2_32(BitWidth)) {
2278 APInt DemandedAmtBits(Op1.getScalarValueSizeInBits(), BitWidth - 1);
2279 if (SimplifyDemandedBits(Op1, DemandedAmtBits, DemandedElts, Known2, TLO,
2280 Depth + 1))
2281 return true;
2282 }
2283 break;
2284 }
2285 case ISD::SMIN:
2286 case ISD::SMAX:
2287 case ISD::UMIN:
2288 case ISD::UMAX: {
2289 unsigned Opc = Op.getOpcode();
2290 SDValue Op0 = Op.getOperand(0);
2291 SDValue Op1 = Op.getOperand(1);
2292
2293 // If we're only demanding signbits, then we can simplify to OR/AND node.
2294 unsigned BitOp =
2295 (Opc == ISD::SMIN || Opc == ISD::UMAX) ? ISD::OR : ISD::AND;
2296 unsigned NumSignBits =
2297 std::min(TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1),
2298 TLO.DAG.ComputeNumSignBits(Op1, DemandedElts, Depth + 1));
2299 unsigned NumDemandedUpperBits = BitWidth - DemandedBits.countr_zero();
2300 if (NumSignBits >= NumDemandedUpperBits)
2301 return TLO.CombineTo(Op, TLO.DAG.getNode(BitOp, SDLoc(Op), VT, Op0, Op1));
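// Worked example (illustrative): if both operands are known to be all-sign
// bits in every demanded position (i.e. behave like 0 or -1 there), then
// smin(0,-1) = -1 = 0 | -1 and smax(0,-1) = 0 = 0 & -1; likewise umin maps
// to AND and umax maps to OR, which is the BitOp chosen above.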
2302
2303 // Check if one arg is always less/greater than (or equal) to the other arg.
2304 KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
2305 KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
2306 switch (Opc) {
2307 case ISD::SMIN:
2308 if (std::optional<bool> IsSLE = KnownBits::sle(Known0, Known1))
2309 return TLO.CombineTo(Op, *IsSLE ? Op0 : Op1);
2310 if (std::optional<bool> IsSLT = KnownBits::slt(Known0, Known1))
2311 return TLO.CombineTo(Op, *IsSLT ? Op0 : Op1);
2312 Known = KnownBits::smin(Known0, Known1);
2313 break;
2314 case ISD::SMAX:
2315 if (std::optional<bool> IsSGE = KnownBits::sge(Known0, Known1))
2316 return TLO.CombineTo(Op, *IsSGE ? Op0 : Op1);
2317 if (std::optional<bool> IsSGT = KnownBits::sgt(Known0, Known1))
2318 return TLO.CombineTo(Op, *IsSGT ? Op0 : Op1);
2319 Known = KnownBits::smax(Known0, Known1);
2320 break;
2321 case ISD::UMIN:
2322 if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
2323 return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
2324 if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
2325 return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
2326 Known = KnownBits::umin(Known0, Known1);
2327 break;
2328 case ISD::UMAX:
2329 if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
2330 return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
2331 if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
2332 return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
2333 Known = KnownBits::umax(Known0, Known1);
2334 break;
2335 }
2336 break;
2337 }
2338 case ISD::BITREVERSE: {
2339 SDValue Src = Op.getOperand(0);
2340 APInt DemandedSrcBits = DemandedBits.reverseBits();
2341 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2342 Depth + 1))
2343 return true;
2344 Known.One = Known2.One.reverseBits();
2345 Known.Zero = Known2.Zero.reverseBits();
2346 break;
2347 }
2348 case ISD::BSWAP: {
2349 SDValue Src = Op.getOperand(0);
2350
2351 // If the only bits demanded come from one byte of the bswap result,
2352 // just shift the input byte into position to eliminate the bswap.
2353 unsigned NLZ = DemandedBits.countl_zero();
2354 unsigned NTZ = DemandedBits.countr_zero();
2355
2356 // Round NTZ down to the next byte. If we have 11 trailing zeros, then
2357 // we need all the bits down to bit 8. Likewise, round NLZ. If we
2358 // have 14 leading zeros, round to 8.
2359 NLZ = alignDown(NLZ, 8);
2360 NTZ = alignDown(NTZ, 8);
2361 // If we need exactly one byte, we can do this transformation.
2362 if (BitWidth - NLZ - NTZ == 8) {
2363 // Replace this with either a left or right shift to get the byte into
2364 // the right place.
2365 unsigned ShiftOpcode = NLZ > NTZ ? ISD::SRL : ISD::SHL;
2366 if (!TLO.LegalOperations() || isOperationLegal(ShiftOpcode, VT)) {
2367 unsigned ShiftAmount = NLZ > NTZ ? NLZ - NTZ : NTZ - NLZ;
2368 SDValue ShAmt = TLO.DAG.getShiftAmountConstant(ShiftAmount, VT, dl);
2369 SDValue NewOp = TLO.DAG.getNode(ShiftOpcode, dl, VT, Src, ShAmt);
2370 return TLO.CombineTo(Op, NewOp);
2371 }
2372 }
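// Worked example (illustrative i32 values): if only bits 31..24 of
// bswap(x) are demanded, they are byte 0 of x (bits 7..0), so NLZ = 0,
// NTZ = 24 and the bswap can be replaced by (shl x, 24); had only bits
// 7..0 been demanded, the replacement would be (srl x, 24).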
2373
2374 APInt DemandedSrcBits = DemandedBits.byteSwap();
2375 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2376 Depth + 1))
2377 return true;
2378 Known.One = Known2.One.byteSwap();
2379 Known.Zero = Known2.Zero.byteSwap();
2380 break;
2381 }
2382 case ISD::CTPOP: {
2383 // If only 1 bit is demanded, replace with PARITY as long as we're before
2384 // op legalization.
2385 // FIXME: Limit to scalars for now.
2386 if (DemandedBits.isOne() && !TLO.LegalOps && !VT.isVector())
2387 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::PARITY, dl, VT,
2388 Op.getOperand(0)));
2389
2390 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2391 break;
2392 }
2393 case ISD::SIGN_EXTEND_INREG: {
2394 SDValue Op0 = Op.getOperand(0);
2395 EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2396 unsigned ExVTBits = ExVT.getScalarSizeInBits();
2397
2398 // If we only care about the highest bit, don't bother shifting right.
2399 if (DemandedBits.isSignMask()) {
2400 unsigned MinSignedBits =
2401 TLO.DAG.ComputeMaxSignificantBits(Op0, DemandedElts, Depth + 1);
2402 bool AlreadySignExtended = ExVTBits >= MinSignedBits;
2403 // However if the input is already sign extended we expect the sign
2404 // extension to be dropped altogether later and do not simplify.
2405 if (!AlreadySignExtended) {
2406 // Compute the correct shift amount type, which must be getShiftAmountTy
2407 // for scalar types after legalization.
2408 SDValue ShiftAmt =
2409 TLO.DAG.getShiftAmountConstant(BitWidth - ExVTBits, VT, dl);
2410 return TLO.CombineTo(Op,
2411 TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, ShiftAmt));
2412 }
2413 }
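// Worked example (illustrative i32 values): for sign_extend_inreg from i8,
// if only bit 31 of the result is demanded it equals bit 7 of the input,
// and (shl Op0, 24) places that same bit at position 31, so the shift can
// stand in for the in-register sign extension.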
2414
2415 // If none of the extended bits are demanded, eliminate the sextinreg.
2416 if (DemandedBits.getActiveBits() <= ExVTBits)
2417 return TLO.CombineTo(Op, Op0);
2418
2419 APInt InputDemandedBits = DemandedBits.getLoBits(ExVTBits);
2420
2421 // Since the sign extended bits are demanded, we know that the sign
2422 // bit is demanded.
2423 InputDemandedBits.setBit(ExVTBits - 1);
2424
2425 if (SimplifyDemandedBits(Op0, InputDemandedBits, DemandedElts, Known, TLO,
2426 Depth + 1))
2427 return true;
2428
2429 // If the sign bit of the input is known set or clear, then we know the
2430 // top bits of the result.
2431
2432 // If the input sign bit is known zero, convert this into a zero extension.
2433 if (Known.Zero[ExVTBits - 1])
2434 return TLO.CombineTo(Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT));
2435
2436 APInt Mask = APInt::getLowBitsSet(BitWidth, ExVTBits);
2437 if (Known.One[ExVTBits - 1]) { // Input sign bit known set
2438 Known.One.setBitsFrom(ExVTBits);
2439 Known.Zero &= Mask;
2440 } else { // Input sign bit unknown
2441 Known.Zero &= Mask;
2442 Known.One &= Mask;
2443 }
2444 break;
2445 }
2446 case ISD::BUILD_PAIR: {
2447 EVT HalfVT = Op.getOperand(0).getValueType();
2448 unsigned HalfBitWidth = HalfVT.getScalarSizeInBits();
2449
2450 APInt MaskLo = DemandedBits.getLoBits(HalfBitWidth).trunc(HalfBitWidth);
2451 APInt MaskHi = DemandedBits.getHiBits(HalfBitWidth).trunc(HalfBitWidth);
2452
2453 KnownBits KnownLo, KnownHi;
2454
2455 if (SimplifyDemandedBits(Op.getOperand(0), MaskLo, KnownLo, TLO, Depth + 1))
2456 return true;
2457
2458 if (SimplifyDemandedBits(Op.getOperand(1), MaskHi, KnownHi, TLO, Depth + 1))
2459 return true;
2460
2461 Known = KnownHi.concat(KnownLo);
2462 break;
2463 }
2464 case ISD::ZERO_EXTEND_VECTOR_INREG:
2465 if (VT.isScalableVector())
2466 return false;
2467 [[fallthrough]];
2468 case ISD::ZERO_EXTEND: {
2469 SDValue Src = Op.getOperand(0);
2470 EVT SrcVT = Src.getValueType();
2471 unsigned InBits = SrcVT.getScalarSizeInBits();
2472 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2473 bool IsVecInReg = Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
2474
2475 // If none of the top bits are demanded, convert this into an any_extend.
2476 if (DemandedBits.getActiveBits() <= InBits) {
2477 // If we only need the non-extended bits of the bottom element
2478 // then we can just bitcast to the result.
2479 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2480 VT.getSizeInBits() == SrcVT.getSizeInBits())
2481 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2482
2483 unsigned Opc =
2484 IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2485 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2486 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2487 }
2488
2489 APInt InDemandedBits = DemandedBits.trunc(InBits);
2490 APInt InDemandedElts = DemandedElts.zext(InElts);
2491 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2492 Depth + 1)) {
2493 Op->dropFlags(SDNodeFlags::NonNeg);
2494 return true;
2495 }
2496 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2497 Known = Known.zext(BitWidth);
2498
2499 // Attempt to avoid multi-use ops if we don't need anything from them.
2500 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2501 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2502 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2503 break;
2504 }
2505 case ISD::SIGN_EXTEND_VECTOR_INREG:
2506 if (VT.isScalableVector())
2507 return false;
2508 [[fallthrough]];
2509 case ISD::SIGN_EXTEND: {
2510 SDValue Src = Op.getOperand(0);
2511 EVT SrcVT = Src.getValueType();
2512 unsigned InBits = SrcVT.getScalarSizeInBits();
2513 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2514 bool IsVecInReg = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
2515
2516 APInt InDemandedElts = DemandedElts.zext(InElts);
2517 APInt InDemandedBits = DemandedBits.trunc(InBits);
2518
2519 // Since some of the sign extended bits are demanded, we know that the sign
2520 // bit is demanded.
2521 InDemandedBits.setBit(InBits - 1);
2522
2523 // If none of the top bits are demanded, convert this into an any_extend.
2524 if (DemandedBits.getActiveBits() <= InBits) {
2525 // If we only need the non-extended bits of the bottom element
2526 // then we can just bitcast to the result.
2527 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2528 VT.getSizeInBits() == SrcVT.getSizeInBits())
2529 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2530
2531 // Don't lose an all signbits 0/-1 splat on targets with 0/-1 booleans.
2532 if (getBooleanContents(VT) != ZeroOrNegativeOneBooleanContent ||
2533 TLO.DAG.ComputeNumSignBits(Src, InDemandedElts, Depth + 1) !=
2534 InBits) {
2535 unsigned Opc =
2536 IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2537 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2538 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2539 }
2540 }
2541
2542 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2543 Depth + 1))
2544 return true;
2545 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2546
2547 // If the sign bit is known one, the top bits match.
2548 Known = Known.sext(BitWidth);
2549
2550 // If the sign bit is known zero, convert this to a zero extend.
2551 if (Known.isNonNegative()) {
2552 unsigned Opc =
2553 IsVecInReg ? ISD::ZERO_EXTEND_VECTOR_INREG : ISD::ZERO_EXTEND;
2554 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT)) {
2555 SDNodeFlags Flags;
2556 if (!IsVecInReg)
2557 Flags |= SDNodeFlags::NonNeg;
2558 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src, Flags));
2559 }
2560 }
2561
2562 // Attempt to avoid multi-use ops if we don't need anything from them.
2563 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2564 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2565 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2566 break;
2567 }
2568 case ISD::ANY_EXTEND_VECTOR_INREG:
2569 if (VT.isScalableVector())
2570 return false;
2571 [[fallthrough]];
2572 case ISD::ANY_EXTEND: {
2573 SDValue Src = Op.getOperand(0);
2574 EVT SrcVT = Src.getValueType();
2575 unsigned InBits = SrcVT.getScalarSizeInBits();
2576 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2577 bool IsVecInReg = Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG;
2578
2579 // If we only need the bottom element then we can just bitcast.
2580 // TODO: Handle ANY_EXTEND?
2581 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2582 VT.getSizeInBits() == SrcVT.getSizeInBits())
2583 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2584
2585 APInt InDemandedBits = DemandedBits.trunc(InBits);
2586 APInt InDemandedElts = DemandedElts.zext(InElts);
2587 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2588 Depth + 1))
2589 return true;
2590 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2591 Known = Known.anyext(BitWidth);
2592
2593 // Attempt to avoid multi-use ops if we don't need anything from them.
2594 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2595 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2596 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2597 break;
2598 }
2599 case ISD::TRUNCATE: {
2600 SDValue Src = Op.getOperand(0);
2601
2602 // Simplify the input, using demanded bit information, and compute the known
2603 // zero/one bits live out.
2604 unsigned OperandBitWidth = Src.getScalarValueSizeInBits();
2605 APInt TruncMask = DemandedBits.zext(OperandBitWidth);
2606 if (SimplifyDemandedBits(Src, TruncMask, DemandedElts, Known, TLO,
2607 Depth + 1)) {
2608 // Disable the nsw and nuw flags. We can no longer guarantee that we
2609 // won't wrap after simplification.
2610 Op->dropFlags(SDNodeFlags::NoWrap);
2611 return true;
2612 }
2613 Known = Known.trunc(BitWidth);
2614
2615 // Attempt to avoid multi-use ops if we don't need anything from them.
2616 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2617 Src, TruncMask, DemandedElts, TLO.DAG, Depth + 1))
2618 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, NewSrc));
2619
2620 // If the input is only used by this truncate, see if we can shrink it based
2621 // on the known demanded bits.
2622 switch (Src.getOpcode()) {
2623 default:
2624 break;
2625 case ISD::SRL:
2626 // Shrink SRL by a constant if none of the high bits shifted in are
2627 // demanded.
2628 if (TLO.LegalTypes() && !isTypeDesirableForOp(ISD::SRL, VT))
2629 // Do not turn (vt1 truncate (vt2 srl)) into (vt1 srl) if vt1 is
2630 // undesirable.
2631 break;
2632
2633 if (Src.getNode()->hasOneUse()) {
2634 if (isTruncateFree(Src, VT) &&
2635 !isTruncateFree(Src.getValueType(), VT)) {
2636 // If the truncate is only free for trunc(srl) and not for srl(trunc), do
2637 // not turn it into srl(trunc). The check first verifies that the truncate
2638 // is free at Src's opcode (srl), then that the truncate is not done by
2639 // referencing a sub-register. In testing, if both trunc(srl)'s and
2640 // srl(trunc)'s truncates are free, srl(trunc) performs better; if only
2641 // trunc(srl)'s truncate is free, trunc(srl) is better.
2642 break;
2643 }
2644
2645 std::optional<uint64_t> ShAmtC =
2646 TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
2647 if (!ShAmtC || *ShAmtC >= BitWidth)
2648 break;
2649 uint64_t ShVal = *ShAmtC;
2650
2651 APInt HighBits =
2652 APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
2653 HighBits.lshrInPlace(ShVal);
2654 HighBits = HighBits.trunc(BitWidth);
2655 if (!(HighBits & DemandedBits)) {
2656 // None of the shifted in bits are needed. Add a truncate of the
2657 // shift input, then shift it.
2658 SDValue NewShAmt = TLO.DAG.getShiftAmountConstant(ShVal, VT, dl);
2659 SDValue NewTrunc =
2660 TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0));
2661 return TLO.CombineTo(
2662 Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, NewShAmt));
2663 }
2664 }
2665 break;
2666 }
2667
2668 break;
2669 }
2670 case ISD::AssertZext: {
2671 // AssertZext demands all of the high bits, plus any of the low bits
2672 // demanded by its users.
2673 EVT ZVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2674 APInt InMask = APInt::getLowBitsSet(BitWidth, ZVT.getSizeInBits());
2675 if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | DemandedBits, Known,
2676 TLO, Depth + 1))
2677 return true;
2678
2679 Known.Zero |= ~InMask;
2680 Known.One &= (~Known.Zero);
2681 break;
2682 }
2683 case ISD::EXTRACT_VECTOR_ELT: {
2684 SDValue Src = Op.getOperand(0);
2685 SDValue Idx = Op.getOperand(1);
2686 ElementCount SrcEltCnt = Src.getValueType().getVectorElementCount();
2687 unsigned EltBitWidth = Src.getScalarValueSizeInBits();
2688
2689 if (SrcEltCnt.isScalable())
2690 return false;
2691
2692 // Demand the bits from every vector element without a constant index.
2693 unsigned NumSrcElts = SrcEltCnt.getFixedValue();
2694 APInt DemandedSrcElts = APInt::getAllOnes(NumSrcElts);
2695 if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx))
2696 if (CIdx->getAPIntValue().ult(NumSrcElts))
2697 DemandedSrcElts = APInt::getOneBitSet(NumSrcElts, CIdx->getZExtValue());
2698
2699 // If BitWidth > EltBitWidth the value is anyext:ed. So we do not know
2700 // anything about the extended bits.
2701 APInt DemandedSrcBits = DemandedBits;
2702 if (BitWidth > EltBitWidth)
2703 DemandedSrcBits = DemandedSrcBits.trunc(EltBitWidth);
2704
2705 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, Known2, TLO,
2706 Depth + 1))
2707 return true;
2708
2709 // Attempt to avoid multi-use ops if we don't need anything from them.
2710 if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
2711 if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
2712 Src, DemandedSrcBits, DemandedSrcElts, TLO.DAG, Depth + 1)) {
2713 SDValue NewOp =
2714 TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc, Idx);
2715 return TLO.CombineTo(Op, NewOp);
2716 }
2717 }
2718
2719 Known = Known2;
2720 if (BitWidth > EltBitWidth)
2721 Known = Known.anyext(BitWidth);
2722 break;
2723 }
2724 case ISD::BITCAST: {
2725 if (VT.isScalableVector())
2726 return false;
2727 SDValue Src = Op.getOperand(0);
2728 EVT SrcVT = Src.getValueType();
2729 unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
2730
2731 // If this is an FP->Int bitcast and if the sign bit is the only
2732 // thing demanded, turn this into a FGETSIGN.
2733 if (!TLO.LegalOperations() && !VT.isVector() && !SrcVT.isVector() &&
2734 DemandedBits == APInt::getSignMask(Op.getValueSizeInBits()) &&
2735 SrcVT.isFloatingPoint()) {
2736 bool OpVTLegal = isOperationLegalOrCustom(ISD::FGETSIGN, VT);
2737 bool i32Legal = isOperationLegalOrCustom(ISD::FGETSIGN, MVT::i32);
2738 if ((OpVTLegal || i32Legal) && VT.isSimple() && SrcVT != MVT::f16 &&
2739 SrcVT != MVT::f128) {
2740 // Cannot eliminate/lower SHL for f128 yet.
2741 EVT Ty = OpVTLegal ? VT : MVT::i32;
2742 // Make a FGETSIGN + SHL to move the sign bit into the appropriate
2743 // place. We expect the SHL to be eliminated by other optimizations.
2744 SDValue Sign = TLO.DAG.getNode(ISD::FGETSIGN, dl, Ty, Src);
2745 unsigned OpVTSizeInBits = Op.getValueSizeInBits();
2746 if (!OpVTLegal && OpVTSizeInBits > 32)
2747 Sign = TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Sign);
2748 unsigned ShVal = Op.getValueSizeInBits() - 1;
2749 SDValue ShAmt = TLO.DAG.getConstant(ShVal, dl, VT);
2750 return TLO.CombineTo(Op,
2751 TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
2752 }
2753 }
2754
2755 // Bitcast from a vector using SimplifyDemanded Bits/VectorElts.
2756 // Demand the elt/bit if any of the original elts/bits are demanded.
2757 if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
2758 unsigned Scale = BitWidth / NumSrcEltBits;
2759 unsigned NumSrcElts = SrcVT.getVectorNumElements();
2760 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2761 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2762 for (unsigned i = 0; i != Scale; ++i) {
2763 unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
2764 unsigned BitOffset = EltOffset * NumSrcEltBits;
2765 APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
2766 if (!Sub.isZero()) {
2767 DemandedSrcBits |= Sub;
2768 for (unsigned j = 0; j != NumElts; ++j)
2769 if (DemandedElts[j])
2770 DemandedSrcElts.setBit((j * Scale) + i);
2771 }
2772 }
2773
2774 APInt KnownSrcUndef, KnownSrcZero;
2775 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2776 KnownSrcZero, TLO, Depth + 1))
2777 return true;
2778
2779 KnownBits KnownSrcBits;
2780 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2781 KnownSrcBits, TLO, Depth + 1))
2782 return true;
2783 } else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
2784 // TODO - bigendian once we have test coverage.
2785 unsigned Scale = NumSrcEltBits / BitWidth;
2786 unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
2787 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2788 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2789 for (unsigned i = 0; i != NumElts; ++i)
2790 if (DemandedElts[i]) {
2791 unsigned Offset = (i % Scale) * BitWidth;
2792 DemandedSrcBits.insertBits(DemandedBits, Offset);
2793 DemandedSrcElts.setBit(i / Scale);
2794 }
2795
2796 if (SrcVT.isVector()) {
2797 APInt KnownSrcUndef, KnownSrcZero;
2798 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2799 KnownSrcZero, TLO, Depth + 1))
2800 return true;
2801 }
2802
2803 KnownBits KnownSrcBits;
2804 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2805 KnownSrcBits, TLO, Depth + 1))
2806 return true;
2807
2808 // Attempt to avoid multi-use ops if we don't need anything from them.
2809 if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
2810 if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
2811 Src, DemandedSrcBits, DemandedSrcElts, TLO.DAG, Depth + 1)) {
2812 SDValue NewOp = TLO.DAG.getBitcast(VT, DemandedSrc);
2813 return TLO.CombineTo(Op, NewOp);
2814 }
2815 }
2816 }
2817
2818 // If this is a bitcast, let computeKnownBits handle it. Only do this on a
2819 // recursive call where Known may be useful to the caller.
2820 if (Depth > 0) {
2821 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2822 return false;
2823 }
2824 break;
2825 }
2826 case ISD::MUL:
2827 if (DemandedBits.isPowerOf2()) {
2828 // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
2829 // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
2830 // odd (has LSB set), then the left-shifted low bit of X is the answer.
2831 unsigned CTZ = DemandedBits.countr_zero();
2832 ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
2833 if (C && C->getAPIntValue().countr_zero() == CTZ) {
2834 SDValue AmtC = TLO.DAG.getShiftAmountConstant(CTZ, VT, dl);
2835 SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, Op.getOperand(0), AmtC);
2836 return TLO.CombineTo(Op, Shl);
2837 }
2838 }
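// Worked example (illustrative i8 values): demand only bit 3 and let the
// multiplier be 24 = 3 << 3 (odd value shifted by CTZ = 3). Then
// X * 24 = (X * 3) << 3, and bit 3 of that is bit 0 of X * 3, which equals
// bit 0 of X because X * 3 is odd exactly when X is odd; so X << 3 has the
// same demanded bit and the multiply can be replaced by the shift.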
2839 // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
2840 // X * X is odd iff X is odd.
2841 // 'Quadratic Reciprocity': X * X -> 0 for bit[1]
2842 if (Op.getOperand(0) == Op.getOperand(1) && DemandedBits.ult(4)) {
2843 SDValue One = TLO.DAG.getConstant(1, dl, VT);
2844 SDValue And1 = TLO.DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0), One);
2845 return TLO.CombineTo(Op, And1);
2846 }
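// Worked example (illustrative): write X = 2a + b with b = X & 1. Then
// X * X = 4a^2 + 4ab + b^2 = 4*(a^2 + ab) + b, so X * X is congruent to
// X & 1 modulo 4: bit 0 of the square is bit 0 of X and bit 1 is always 0,
// which is why (X & 1) suffices when only the low two bits are demanded.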
2847 [[fallthrough]];
2848 case ISD::ADD:
2849 case ISD::SUB: {
2850 // Add, Sub, and Mul don't demand any bits in positions beyond that
2851 // of the highest bit demanded of them.
2852 SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1);
2853 SDNodeFlags Flags = Op.getNode()->getFlags();
2854 unsigned DemandedBitsLZ = DemandedBits.countl_zero();
2855 APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ);
2856 KnownBits KnownOp0, KnownOp1;
2857 auto GetDemandedBitsLHSMask = [&](APInt Demanded,
2858 const KnownBits &KnownRHS) {
2859 if (Op.getOpcode() == ISD::MUL)
2860 Demanded.clearHighBits(KnownRHS.countMinTrailingZeros());
2861 return Demanded;
2862 };
2863 if (SimplifyDemandedBits(Op1, LoMask, DemandedElts, KnownOp1, TLO,
2864 Depth + 1) ||
2865 SimplifyDemandedBits(Op0, GetDemandedBitsLHSMask(LoMask, KnownOp1),
2866 DemandedElts, KnownOp0, TLO, Depth + 1) ||
2867 // See if the operation should be performed at a smaller bit width.
2868 ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) {
2869 // Disable the nsw and nuw flags. We can no longer guarantee that we
2870 // won't wrap after simplification.
2871 Op->dropFlags(SDNodeFlags::NoWrap);
2872 return true;
2873 }
2874
2875 // neg x with only low bit demanded is simply x.
2876 if (Op.getOpcode() == ISD::SUB && DemandedBits.isOne() &&
2877 isNullConstant(Op0))
2878 return TLO.CombineTo(Op, Op1);
2879
2880 // Attempt to avoid multi-use ops if we don't need anything from them.
2881 if (!LoMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2882 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2883 Op0, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2884 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
2885 Op1, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2886 if (DemandedOp0 || DemandedOp1) {
2887 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
2888 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
2889 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1,
2890 Flags & ~SDNodeFlags::NoWrap);
2891 return TLO.CombineTo(Op, NewOp);
2892 }
2893 }
2894
2895 // If we have a constant operand, we may be able to turn it into -1 if we
2896 // do not demand the high bits. This can make the constant smaller to
2897 // encode, allow more general folding, or match specialized instruction
2898 // patterns (eg, 'blsr' on x86). Don't bother changing 1 to -1 because that
2899 // is probably not useful (and could be detrimental).
2900 ConstantSDNode *C = isConstOrConstSplat(Op1);
2901 APInt HighMask = APInt::getHighBitsSet(BitWidth, DemandedBitsLZ);
2902 if (C && !C->isAllOnes() && !C->isOne() &&
2903 (C->getAPIntValue() | HighMask).isAllOnes()) {
2904 SDValue Neg1 = TLO.DAG.getAllOnesConstant(dl, VT);
2905 // Disable the nsw and nuw flags. We can no longer guarantee that we
2906 // won't wrap after simplification.
2907 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Neg1,
2908 Flags & ~SDNodeFlags::NoWrap);
2909 return TLO.CombineTo(Op, NewOp);
2910 }
2911
2912 // Match a multiply with a disguised negated-power-of-2 and convert to
2913 // an equivalent shift-left amount.
2914 // Example: (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
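// Worked example (illustrative): with MulC = -8, X * MulC = -(X << 3), so
// (X * -8) + Op1 folds to Op1 - (X << 3); getShiftLeftAmt below recovers
// the 3 as log2(-MulC) after the undemanded high bits are masked in.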
2915 auto getShiftLeftAmt = [&HighMask](SDValue Mul) -> unsigned {
2916 if (Mul.getOpcode() != ISD::MUL || !Mul.hasOneUse())
2917 return 0;
2918
2919 // Don't touch opaque constants. Also, ignore zero and power-of-2
2920 // multiplies. Those will get folded later.
2921 ConstantSDNode *MulC = isConstOrConstSplat(Mul.getOperand(1));
2922 if (MulC && !MulC->isOpaque() && !MulC->isZero() &&
2923 !MulC->getAPIntValue().isPowerOf2()) {
2924 APInt UnmaskedC = MulC->getAPIntValue() | HighMask;
2925 if (UnmaskedC.isNegatedPowerOf2())
2926 return (-UnmaskedC).logBase2();
2927 }
2928 return 0;
2929 };
2930
2931 auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y,
2932 unsigned ShlAmt) {
2933 SDValue ShlAmtC = TLO.DAG.getShiftAmountConstant(ShlAmt, VT, dl);
2934 SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
2935 SDValue Res = TLO.DAG.getNode(NT, dl, VT, Y, Shl);
2936 return TLO.CombineTo(Op, Res);
2937 };
2938
2939 if (isOperationLegalOrCustom(ISD::SHL, VT)) {
2940 if (Op.getOpcode() == ISD::ADD) {
2941 // (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
2942 if (unsigned ShAmt = getShiftLeftAmt(Op0))
2943 return foldMul(ISD::SUB, Op0.getOperand(0), Op1, ShAmt);
2944 // Op0 + (X * MulC) --> Op0 - (X << log2(-MulC))
2945 if (unsigned ShAmt = getShiftLeftAmt(Op1))
2946 return foldMul(ISD::SUB, Op1.getOperand(0), Op0, ShAmt);
2947 }
2948 if (Op.getOpcode() == ISD::SUB) {
2949 // Op0 - (X * MulC) --> Op0 + (X << log2(-MulC))
2950 if (unsigned ShAmt = getShiftLeftAmt(Op1))
2951 return foldMul(ISD::ADD, Op1.getOperand(0), Op0, ShAmt);
2952 }
2953 }
2954
2955 if (Op.getOpcode() == ISD::MUL) {
2956 Known = KnownBits::mul(KnownOp0, KnownOp1);
2957 } else { // Op.getOpcode() is either ISD::ADD or ISD::SUB.
2958 Known = KnownBits::computeForAddSub(
2959 Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
2960 Flags.hasNoUnsignedWrap(), KnownOp0, KnownOp1);
2961 }
2962 break;
2963 }
2964 default:
2965 // We also ask the target about intrinsics (which could be specific to it).
2966 if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
2967 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
2968 // TODO: Probably okay to remove after audit; here to reduce change size
2969 // in initial enablement patch for scalable vectors
2970 if (Op.getValueType().isScalableVector())
2971 break;
2972 if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
2973 Known, TLO, Depth))
2974 return true;
2975 break;
2976 }
2977
2978 // Just use computeKnownBits to compute output bits.
2979 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2980 break;
2981 }
2982
2983 // If we know the value of all of the demanded bits, return this as a
2984 // constant.
2985 if (!isTargetCanonicalConstantNode(Op) &&
2986 DemandedBits.isSubsetOf(Known.Zero | Known.One)) {
2987 // Avoid folding to a constant if any OpaqueConstant is involved.
2988 const SDNode *N = Op.getNode();
2989 for (SDNode *Op :
2990 llvm::make_pointer_range(make_range(N->op_begin(), N->op_end()))) {
2991 if (auto *C = dyn_cast<ConstantSDNode>(Op))
2992 if (C->isOpaque())
2993 return false;
2994 }
2995 if (VT.isInteger())
2996 return TLO.CombineTo(Op, TLO.DAG.getConstant(Known.One, dl, VT));
2997 if (VT.isFloatingPoint())
2998 return TLO.CombineTo(
2999 Op, TLO.DAG.getConstantFP(APFloat(VT.getFltSemantics(), Known.One),
3000 dl, VT));
3001 }
3002
3003 // A multi use 'all demanded elts' simplify failed to find any knownbits.
3004 // Try again just for the original demanded elts.
3005 // Ensure we do this AFTER constant folding above.
3006 if (HasMultiUse && Known.isUnknown() && !OriginalDemandedElts.isAllOnes())
3007 Known = TLO.DAG.computeKnownBits(Op, OriginalDemandedElts, Depth);
3008
3009 return false;
3010}
3011
3012 bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
3013 const APInt &DemandedElts,
3014 DAGCombinerInfo &DCI) const {
3015 SelectionDAG &DAG = DCI.DAG;
3016 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
3017 !DCI.isBeforeLegalizeOps());
3018
3019 APInt KnownUndef, KnownZero;
3020 bool Simplified =
3021 SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
3022 if (Simplified) {
3023 DCI.AddToWorklist(Op.getNode());
3024 DCI.CommitTargetLoweringOpt(TLO);
3025 }
3026
3027 return Simplified;
3028}
3029
3030/// Given a vector binary operation and known undefined elements for each input
3031 /// operand, compute whether each element of the output is undefined.
3032 static APInt getKnownUndefForVectorBinop(SDValue BO, SelectionDAG &DAG,
3033 const APInt &UndefOp0,
3034 const APInt &UndefOp1) {
3035 EVT VT = BO.getValueType();
3036 assert(DAG.getTargetLoweringInfo().isBinOp(BO.getOpcode()) && VT.isVector() &&
3037 "Vector binop only");
3038
3039 EVT EltVT = VT.getVectorElementType();
3040 unsigned NumElts = VT.isFixedLengthVector() ? VT.getVectorNumElements() : 1;
3041 assert(UndefOp0.getBitWidth() == NumElts &&
3042 UndefOp1.getBitWidth() == NumElts && "Bad type for undef analysis");
3043
3044 auto getUndefOrConstantElt = [&](SDValue V, unsigned Index,
3045 const APInt &UndefVals) {
3046 if (UndefVals[Index])
3047 return DAG.getUNDEF(EltVT);
3048
3049 if (auto *BV = dyn_cast<BuildVectorSDNode>(V)) {
3050 // Try hard to make sure that the getNode() call is not creating temporary
3051 // nodes. Ignore opaque integers because they do not constant fold.
3052 SDValue Elt = BV->getOperand(Index);
3053 auto *C = dyn_cast<ConstantSDNode>(Elt);
3054 if (isa<ConstantFPSDNode>(Elt) || Elt.isUndef() || (C && !C->isOpaque()))
3055 return Elt;
3056 }
3057
3058 return SDValue();
3059 };
3060
3061 APInt KnownUndef = APInt::getZero(NumElts);
3062 for (unsigned i = 0; i != NumElts; ++i) {
3063 // If both inputs for this element are either constant or undef and match
3064 // the element type, compute the constant/undef result for this element of
3065 // the vector.
3066 // TODO: Ideally we would use FoldConstantArithmetic() here, but that does
3067 // not handle FP constants. The code within getNode() should be refactored
3068 // to avoid the danger of creating a bogus temporary node here.
3069 SDValue C0 = getUndefOrConstantElt(BO.getOperand(0), i, UndefOp0);
3070 SDValue C1 = getUndefOrConstantElt(BO.getOperand(1), i, UndefOp1);
3071 if (C0 && C1 && C0.getValueType() == EltVT && C1.getValueType() == EltVT)
3072 if (DAG.getNode(BO.getOpcode(), SDLoc(BO), EltVT, C0, C1).isUndef())
3073 KnownUndef.setBit(i);
3074 }
3075 return KnownUndef;
3076}
3077
3078 bool TargetLowering::SimplifyDemandedVectorElts(
3079 SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
3080 APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
3081 bool AssumeSingleUse) const {
3082 EVT VT = Op.getValueType();
3083 unsigned Opcode = Op.getOpcode();
3084 APInt DemandedElts = OriginalDemandedElts;
3085 unsigned NumElts = DemandedElts.getBitWidth();
3086 assert(VT.isVector() && "Expected vector op");
3087
3088 KnownUndef = KnownZero = APInt::getZero(NumElts);
3089
3090 if (!shouldSimplifyDemandedVectorElts(Op, TLO))
3091 return false;
3092
3093 // TODO: For now we assume we know nothing about scalable vectors.
3094 if (VT.isScalableVector())
3095 return false;
3096
3097 assert(VT.getVectorNumElements() == NumElts &&
3098 "Mask size mismatches value type element count!");
3099
3100 // Undef operand.
3101 if (Op.isUndef()) {
3102 KnownUndef.setAllBits();
3103 return false;
3104 }
3105
3106 // If Op has other users, assume that all elements are needed.
3107 if (!AssumeSingleUse && !Op.getNode()->hasOneUse())
3108 DemandedElts.setAllBits();
3109
3110 // Not demanding any elements from Op.
3111 if (DemandedElts == 0) {
3112 KnownUndef.setAllBits();
3113 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3114 }
3115
3116 // Limit search depth.
3117  if (Depth >= SelectionDAG::MaxRecursionDepth)
3118 return false;
3119
3120 SDLoc DL(Op);
3121 unsigned EltSizeInBits = VT.getScalarSizeInBits();
3122 bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
3123
3124 // Helper for demanding the specified elements and all the bits of both binary
3125 // operands.
3126 auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
3127 SDValue NewOp0 = SimplifyMultipleUseDemandedVectorElts(Op0, DemandedElts,
3128 TLO.DAG, Depth + 1);
3129 SDValue NewOp1 = SimplifyMultipleUseDemandedVectorElts(Op1, DemandedElts,
3130 TLO.DAG, Depth + 1);
3131 if (NewOp0 || NewOp1) {
3132 SDValue NewOp =
3133 TLO.DAG.getNode(Opcode, SDLoc(Op), VT, NewOp0 ? NewOp0 : Op0,
3134 NewOp1 ? NewOp1 : Op1, Op->getFlags());
3135 return TLO.CombineTo(Op, NewOp);
3136 }
3137 return false;
3138 };
3139
3140 switch (Opcode) {
3141 case ISD::SCALAR_TO_VECTOR: {
3142 if (!DemandedElts[0]) {
3143 KnownUndef.setAllBits();
3144 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3145 }
3146 SDValue ScalarSrc = Op.getOperand(0);
3147 if (ScalarSrc.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
3148 SDValue Src = ScalarSrc.getOperand(0);
3149 SDValue Idx = ScalarSrc.getOperand(1);
3150 EVT SrcVT = Src.getValueType();
3151
3152 ElementCount SrcEltCnt = SrcVT.getVectorElementCount();
3153
3154 if (SrcEltCnt.isScalable())
3155 return false;
3156
3157 unsigned NumSrcElts = SrcEltCnt.getFixedValue();
3158 if (isNullConstant(Idx)) {
3159 APInt SrcDemandedElts = APInt::getOneBitSet(NumSrcElts, 0);
3160 APInt SrcUndef = KnownUndef.zextOrTrunc(NumSrcElts);
3161 APInt SrcZero = KnownZero.zextOrTrunc(NumSrcElts);
3162 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3163 TLO, Depth + 1))
3164 return true;
3165 }
3166 }
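    // Only element 0 of a SCALAR_TO_VECTOR is defined; the remaining elements
    // are undef.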
3167 KnownUndef.setHighBits(NumElts - 1);
3168 break;
3169 }
3170 case ISD::BITCAST: {
3171 SDValue Src = Op.getOperand(0);
3172 EVT SrcVT = Src.getValueType();
3173
3174 // We only handle vectors here.
3175 // TODO - investigate calling SimplifyDemandedBits/ComputeKnownBits?
3176 if (!SrcVT.isVector())
3177 break;
3178
3179 // Fast handling of 'identity' bitcasts.
3180 unsigned NumSrcElts = SrcVT.getVectorNumElements();
3181 if (NumSrcElts == NumElts)
3182 return SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef,
3183 KnownZero, TLO, Depth + 1);
3184
3185 APInt SrcDemandedElts, SrcZero, SrcUndef;
3186
3187    // Bitcast from a 'large element' src vector to a 'small element' vector:
3188    // we must demand a source element if any DemandedElt maps to it.
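    // For example, bitcasting v2i64 -> v4i32 gives Scale = 2: demanding either
    // of result elements 2 or 3 demands element 1 of the v2i64 source.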
3189 if ((NumElts % NumSrcElts) == 0) {
3190 unsigned Scale = NumElts / NumSrcElts;
3191 SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
3192 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3193 TLO, Depth + 1))
3194 return true;
3195
3196 // Try calling SimplifyDemandedBits, converting demanded elts to the bits
3197 // of the large element.
3198 // TODO - bigendian once we have test coverage.
3199 if (IsLE) {
3200 unsigned SrcEltSizeInBits = SrcVT.getScalarSizeInBits();
3201 APInt SrcDemandedBits = APInt::getZero(SrcEltSizeInBits);
3202 for (unsigned i = 0; i != NumElts; ++i)
3203 if (DemandedElts[i]) {
3204 unsigned Ofs = (i % Scale) * EltSizeInBits;
3205 SrcDemandedBits.setBits(Ofs, Ofs + EltSizeInBits);
3206 }
3207
3208 KnownBits Known;
3209 if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcDemandedElts, Known,
3210 TLO, Depth + 1))
3211 return true;
3212
3213 // The bitcast has split each wide element into a number of
3214 // narrow subelements. We have just computed the Known bits
3215 // for wide elements. See if element splitting results in
3216 // some subelements being zero. Only for demanded elements!
3217 for (unsigned SubElt = 0; SubElt != Scale; ++SubElt) {
3218 if (!Known.Zero.extractBits(EltSizeInBits, SubElt * EltSizeInBits)
3219 .isAllOnes())
3220 continue;
3221 for (unsigned SrcElt = 0; SrcElt != NumSrcElts; ++SrcElt) {
3222 unsigned Elt = Scale * SrcElt + SubElt;
3223 if (DemandedElts[Elt])
3224 KnownZero.setBit(Elt);
3225 }
3226 }
3227 }
3228
3229      // If the src element is zero/undef then the output elements it covers will
3230      // be as well - only demanded elements are guaranteed to be correct.
3231 for (unsigned i = 0; i != NumSrcElts; ++i) {
3232 if (SrcDemandedElts[i]) {
3233 if (SrcZero[i])
3234 KnownZero.setBits(i * Scale, (i + 1) * Scale);
3235 if (SrcUndef[i])
3236 KnownUndef.setBits(i * Scale, (i + 1) * Scale);
3237 }
3238 }
3239 }
3240
3241    // Bitcast from a 'small element' src vector to a 'large element' vector:
3242    // we demand all the smaller source elements covered by the larger demanded
3243    // element of this vector.
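    // For example, bitcasting v4i32 -> v2i64 gives Scale = 2: demanding element
    // 0 of the v2i64 result demands elements 0 and 1 of the v4i32 source.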
3244 if ((NumSrcElts % NumElts) == 0) {
3245 unsigned Scale = NumSrcElts / NumElts;
3246 SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
3247 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3248 TLO, Depth + 1))
3249 return true;
3250
3251 // If all the src elements covering an output element are zero/undef, then
3252 // the output element will be as well, assuming it was demanded.
3253 for (unsigned i = 0; i != NumElts; ++i) {
3254 if (DemandedElts[i]) {
3255 if (SrcZero.extractBits(Scale, i * Scale).isAllOnes())
3256 KnownZero.setBit(i);
3257 if (SrcUndef.extractBits(Scale, i * Scale).isAllOnes())
3258 KnownUndef.setBit(i);
3259 }
3260 }
3261 }
3262 break;
3263 }
3264 case ISD::FREEZE: {
3265 SDValue N0 = Op.getOperand(0);
3266 if (TLO.DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
3267 /*PoisonOnly=*/false))
3268 return TLO.CombineTo(Op, N0);
3269
3270 // TODO: Replace this with the general fold from DAGCombiner::visitFREEZE
3271 // freeze(op(x, ...)) -> op(freeze(x), ...).
3272 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && DemandedElts == 1)
3273 return TLO.CombineTo(
3274          Op, TLO.DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT,
3275 TLO.DAG.getFreeze(N0.getOperand(0))));
3276 break;
3277 }
3278 case ISD::BUILD_VECTOR: {
3279 // Check all elements and simplify any unused elements with UNDEF.
3280 if (!DemandedElts.isAllOnes()) {
3281 // Don't simplify BROADCASTS.
3282 if (llvm::any_of(Op->op_values(),
3283 [&](SDValue Elt) { return Op.getOperand(0) != Elt; })) {
3284 SmallVector<SDValue, 32> Ops(Op->ops());
3285 bool Updated = false;
3286 for (unsigned i = 0; i != NumElts; ++i) {
3287 if (!DemandedElts[i] && !Ops[i].isUndef()) {
3288 Ops[i] = TLO.DAG.getUNDEF(Ops[0].getValueType());
3289 KnownUndef.setBit(i);
3290 Updated = true;
3291 }
3292 }
3293 if (Updated)
3294 return TLO.CombineTo(Op, TLO.DAG.getBuildVector(VT, DL, Ops));
3295 }
3296 }
3297 for (unsigned i = 0; i != NumElts; ++i) {
3298 SDValue SrcOp = Op.getOperand(i);
3299 if (SrcOp.isUndef()) {
3300 KnownUndef.setBit(i);
3301 } else if (EltSizeInBits == SrcOp.getScalarValueSizeInBits() &&
3302                (isNullConstant(SrcOp) || isNullFPConstant(SrcOp))) {
3303 KnownZero.setBit(i);
3304 }
3305 }
3306 break;
3307 }
3308 case ISD::CONCAT_VECTORS: {
3309 EVT SubVT = Op.getOperand(0).getValueType();
3310 unsigned NumSubVecs = Op.getNumOperands();
3311 unsigned NumSubElts = SubVT.getVectorNumElements();
3312 for (unsigned i = 0; i != NumSubVecs; ++i) {
3313 SDValue SubOp = Op.getOperand(i);
3314 APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
3315 APInt SubUndef, SubZero;
3316 if (SimplifyDemandedVectorElts(SubOp, SubElts, SubUndef, SubZero, TLO,
3317 Depth + 1))
3318 return true;
3319 KnownUndef.insertBits(SubUndef, i * NumSubElts);
3320 KnownZero.insertBits(SubZero, i * NumSubElts);
3321 }
3322
3323 // Attempt to avoid multi-use ops if we don't need anything from them.
3324 if (!DemandedElts.isAllOnes()) {
3325 bool FoundNewSub = false;
3326 SmallVector<SDValue, 2> DemandedSubOps;
3327 for (unsigned i = 0; i != NumSubVecs; ++i) {
3328 SDValue SubOp = Op.getOperand(i);
3329 APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
3330 SDValue NewSubOp = SimplifyMultipleUseDemandedVectorElts(
3331 SubOp, SubElts, TLO.DAG, Depth + 1);
3332 DemandedSubOps.push_back(NewSubOp ? NewSubOp : SubOp);
3333 FoundNewSub = NewSubOp ? true : FoundNewSub;
3334 }
3335 if (FoundNewSub) {
3336 SDValue NewOp =
3337 TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, DemandedSubOps);
3338 return TLO.CombineTo(Op, NewOp);
3339 }
3340 }
3341 break;
3342 }
3343 case ISD::INSERT_SUBVECTOR: {
3344    // Demand any elements from the subvector and the remainder from the src it's
3345 // inserted into.
3346 SDValue Src = Op.getOperand(0);
3347 SDValue Sub = Op.getOperand(1);
3348 uint64_t Idx = Op.getConstantOperandVal(2);
3349 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
3350 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
3351 APInt DemandedSrcElts = DemandedElts;
3352 DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
3353
3354 APInt SubUndef, SubZero;
3355 if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,
3356 Depth + 1))
3357 return true;
3358
3359 // If none of the src operand elements are demanded, replace it with undef.
3360 if (!DemandedSrcElts && !Src.isUndef())
3361 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
3362 TLO.DAG.getUNDEF(VT), Sub,
3363 Op.getOperand(2)));
3364
3365 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownUndef, KnownZero,
3366 TLO, Depth + 1))
3367 return true;
3368 KnownUndef.insertBits(SubUndef, Idx);
3369 KnownZero.insertBits(SubZero, Idx);
3370
3371 // Attempt to avoid multi-use ops if we don't need anything from them.
3372 if (!DemandedSrcElts.isAllOnes() || !DemandedSubElts.isAllOnes()) {
3373 SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
3374 Src, DemandedSrcElts, TLO.DAG, Depth + 1);
3375 SDValue NewSub = SimplifyMultipleUseDemandedVectorElts(
3376 Sub, DemandedSubElts, TLO.DAG, Depth + 1);
3377 if (NewSrc || NewSub) {
3378 NewSrc = NewSrc ? NewSrc : Src;
3379 NewSub = NewSub ? NewSub : Sub;
3380 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
3381 NewSub, Op.getOperand(2));
3382 return TLO.CombineTo(Op, NewOp);
3383 }
3384 }
3385 break;
3386 }
3387  case ISD::EXTRACT_SUBVECTOR: {
3388 // Offset the demanded elts by the subvector index.
3389 SDValue Src = Op.getOperand(0);
3390 if (Src.getValueType().isScalableVector())
3391 break;
3392 uint64_t Idx = Op.getConstantOperandVal(1);
3393 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
3394 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
3395
3396 APInt SrcUndef, SrcZero;
3397 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
3398 Depth + 1))
3399 return true;
3400 KnownUndef = SrcUndef.extractBits(NumElts, Idx);
3401 KnownZero = SrcZero.extractBits(NumElts, Idx);
3402
3403 // Attempt to avoid multi-use ops if we don't need anything from them.
3404 if (!DemandedElts.isAllOnes()) {
3405 SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
3406 Src, DemandedSrcElts, TLO.DAG, Depth + 1);
3407 if (NewSrc) {
3408 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
3409 Op.getOperand(1));
3410 return TLO.CombineTo(Op, NewOp);
3411 }
3412 }
3413 break;
3414 }
3415  case ISD::INSERT_VECTOR_ELT: {
3416 SDValue Vec = Op.getOperand(0);
3417 SDValue Scl = Op.getOperand(1);
3418 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
3419
3420 // For a legal, constant insertion index, if we don't need this insertion
3421 // then strip it, else remove it from the demanded elts.
3422 if (CIdx && CIdx->getAPIntValue().ult(NumElts)) {
3423 unsigned Idx = CIdx->getZExtValue();
3424 if (!DemandedElts[Idx])
3425 return TLO.CombineTo(Op, Vec);
3426
3427 APInt DemandedVecElts(DemandedElts);
3428 DemandedVecElts.clearBit(Idx);
3429 if (SimplifyDemandedVectorElts(Vec, DemandedVecElts, KnownUndef,
3430 KnownZero, TLO, Depth + 1))
3431 return true;
3432
3433 KnownUndef.setBitVal(Idx, Scl.isUndef());
3434
3435 KnownZero.setBitVal(Idx, isNullConstant(Scl) || isNullFPConstant(Scl));
3436 break;
3437 }
3438
3439 APInt VecUndef, VecZero;
3440 if (SimplifyDemandedVectorElts(Vec, DemandedElts, VecUndef, VecZero, TLO,
3441 Depth + 1))
3442 return true;
3443 // Without knowing the insertion index we can't set KnownUndef/KnownZero.
3444 break;
3445 }
3446 case ISD::VSELECT: {
3447 SDValue Sel = Op.getOperand(0);
3448 SDValue LHS = Op.getOperand(1);
3449 SDValue RHS = Op.getOperand(2);
3450
3451 // Try to transform the select condition based on the current demanded
3452 // elements.
3453 APInt UndefSel, ZeroSel;
3454 if (SimplifyDemandedVectorElts(Sel, DemandedElts, UndefSel, ZeroSel, TLO,
3455 Depth + 1))
3456 return true;
3457
3458 // See if we can simplify either vselect operand.
3459 APInt DemandedLHS(DemandedElts);
3460 APInt DemandedRHS(DemandedElts);
3461 APInt UndefLHS, ZeroLHS;
3462 APInt UndefRHS, ZeroRHS;
3463 if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
3464 Depth + 1))
3465 return true;
3466 if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
3467 Depth + 1))
3468 return true;
3469
3470 KnownUndef = UndefLHS & UndefRHS;
3471 KnownZero = ZeroLHS & ZeroRHS;
3472
3473 // If we know that the selected element is always zero, we don't need the
3474 // select value element.
3475 APInt DemandedSel = DemandedElts & ~KnownZero;
3476 if (DemandedSel != DemandedElts)
3477 if (SimplifyDemandedVectorElts(Sel, DemandedSel, UndefSel, ZeroSel, TLO,
3478 Depth + 1))
3479 return true;
3480
3481 break;
3482 }
3483 case ISD::VECTOR_SHUFFLE: {
3484 SDValue LHS = Op.getOperand(0);
3485 SDValue RHS = Op.getOperand(1);
3486 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
3487
3488    // Collect demanded elements from shuffle operands.
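    // For example, with NumElts = 4 and mask <0, 5, 2, 7>, demanding only result
    // elements 1 and 3 demands nothing from LHS and elements 1 and 3 of RHS.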
3489 APInt DemandedLHS(NumElts, 0);
3490 APInt DemandedRHS(NumElts, 0);
3491 for (unsigned i = 0; i != NumElts; ++i) {
3492 int M = ShuffleMask[i];
3493 if (M < 0 || !DemandedElts[i])
3494 continue;
3495 assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range");
3496 if (M < (int)NumElts)
3497 DemandedLHS.setBit(M);
3498 else
3499 DemandedRHS.setBit(M - NumElts);
3500 }
3501
3502 // See if we can simplify either shuffle operand.
3503 APInt UndefLHS, ZeroLHS;
3504 APInt UndefRHS, ZeroRHS;
3505 if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
3506 Depth + 1))
3507 return true;
3508 if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
3509 Depth + 1))
3510 return true;
3511
3512 // Simplify mask using undef elements from LHS/RHS.
3513 bool Updated = false;
3514 bool IdentityLHS = true, IdentityRHS = true;
3515 SmallVector<int, 32> NewMask(ShuffleMask);
3516 for (unsigned i = 0; i != NumElts; ++i) {
3517 int &M = NewMask[i];
3518 if (M < 0)
3519 continue;
3520 if (!DemandedElts[i] || (M < (int)NumElts && UndefLHS[M]) ||
3521 (M >= (int)NumElts && UndefRHS[M - NumElts])) {
3522 Updated = true;
3523 M = -1;
3524 }
3525 IdentityLHS &= (M < 0) || (M == (int)i);
3526 IdentityRHS &= (M < 0) || ((M - NumElts) == i);
3527 }
3528
3529 // Update legal shuffle masks based on demanded elements if it won't reduce
3530 // to Identity which can cause premature removal of the shuffle mask.
3531 if (Updated && !IdentityLHS && !IdentityRHS && !TLO.LegalOps) {
3532 SDValue LegalShuffle =
3533 buildLegalVectorShuffle(VT, DL, LHS, RHS, NewMask, TLO.DAG);
3534 if (LegalShuffle)
3535 return TLO.CombineTo(Op, LegalShuffle);
3536 }
3537
3538 // Propagate undef/zero elements from LHS/RHS.
3539 for (unsigned i = 0; i != NumElts; ++i) {
3540 int M = ShuffleMask[i];
3541 if (M < 0) {
3542 KnownUndef.setBit(i);
3543 } else if (M < (int)NumElts) {
3544 if (UndefLHS[M])
3545 KnownUndef.setBit(i);
3546 if (ZeroLHS[M])
3547 KnownZero.setBit(i);
3548 } else {
3549 if (UndefRHS[M - NumElts])
3550 KnownUndef.setBit(i);
3551 if (ZeroRHS[M - NumElts])
3552 KnownZero.setBit(i);
3553 }
3554 }
3555 break;
3556 }
3557  case ISD::ANY_EXTEND_VECTOR_INREG:
3558  case ISD::SIGN_EXTEND_VECTOR_INREG:
3559  case ISD::ZERO_EXTEND_VECTOR_INREG: {
3560 APInt SrcUndef, SrcZero;
3561 SDValue Src = Op.getOperand(0);
3562 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
3563 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts);
3564 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
3565 Depth + 1))
3566 return true;
3567 KnownZero = SrcZero.zextOrTrunc(NumElts);
3568 KnownUndef = SrcUndef.zextOrTrunc(NumElts);
3569
3570 if (IsLE && Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG &&
3571 Op.getValueSizeInBits() == Src.getValueSizeInBits() &&
3572 DemandedSrcElts == 1) {
3573 // aext - if we just need the bottom element then we can bitcast.
3574 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
3575 }
3576
3577 if (Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
3578 // zext(undef) upper bits are guaranteed to be zero.
3579 if (DemandedElts.isSubsetOf(KnownUndef))
3580 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3581 KnownUndef.clearAllBits();
3582
3583 // zext - if we just need the bottom element then we can mask:
3584 // zext(and(x,c)) -> and(x,c') iff the zext is the only user of the and.
3585 if (IsLE && DemandedSrcElts == 1 && Src.getOpcode() == ISD::AND &&
3586 Op->isOnlyUserOf(Src.getNode()) &&
3587 Op.getValueSizeInBits() == Src.getValueSizeInBits()) {
3588 SDLoc DL(Op);
3589 EVT SrcVT = Src.getValueType();
3590 EVT SrcSVT = SrcVT.getScalarType();
3591 SmallVector<SDValue> MaskElts;
3592 MaskElts.push_back(TLO.DAG.getAllOnesConstant(DL, SrcSVT));
3593 MaskElts.append(NumSrcElts - 1, TLO.DAG.getConstant(0, DL, SrcSVT));
3594 SDValue Mask = TLO.DAG.getBuildVector(SrcVT, DL, MaskElts);
3595 if (SDValue Fold = TLO.DAG.FoldConstantArithmetic(
3596 ISD::AND, DL, SrcVT, {Src.getOperand(1), Mask})) {
3597 Fold = TLO.DAG.getNode(ISD::AND, DL, SrcVT, Src.getOperand(0), Fold);
3598 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Fold));
3599 }
3600 }
3601 }
3602 break;
3603 }
3604
3605 // TODO: There are more binop opcodes that could be handled here - MIN,
3606 // MAX, saturated math, etc.
3607 case ISD::ADD: {
3608 SDValue Op0 = Op.getOperand(0);
3609 SDValue Op1 = Op.getOperand(1);
3610 if (Op0 == Op1 && Op->isOnlyUserOf(Op0.getNode())) {
3611 APInt UndefLHS, ZeroLHS;
3612 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3613 Depth + 1, /*AssumeSingleUse*/ true))
3614 return true;
3615 }
3616 [[fallthrough]];
3617 }
3618 case ISD::AVGCEILS:
3619 case ISD::AVGCEILU:
3620 case ISD::AVGFLOORS:
3621 case ISD::AVGFLOORU:
3622 case ISD::OR:
3623 case ISD::XOR:
3624 case ISD::SUB:
3625 case ISD::FADD:
3626 case ISD::FSUB:
3627 case ISD::FMUL:
3628 case ISD::FDIV:
3629 case ISD::FREM: {
3630 SDValue Op0 = Op.getOperand(0);
3631 SDValue Op1 = Op.getOperand(1);
3632
3633 APInt UndefRHS, ZeroRHS;
3634 if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
3635 Depth + 1))
3636 return true;
3637 APInt UndefLHS, ZeroLHS;
3638 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3639 Depth + 1))
3640 return true;
3641
3642 KnownZero = ZeroLHS & ZeroRHS;
3643 KnownUndef = getKnownUndefForVectorBinop(Op, TLO.DAG, UndefLHS, UndefRHS);
3644
3645 // Attempt to avoid multi-use ops if we don't need anything from them.
3646 // TODO - use KnownUndef to relax the demandedelts?
3647 if (!DemandedElts.isAllOnes())
3648 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3649 return true;
3650 break;
3651 }
3652 case ISD::SHL:
3653 case ISD::SRL:
3654 case ISD::SRA:
3655 case ISD::ROTL:
3656 case ISD::ROTR: {
3657 SDValue Op0 = Op.getOperand(0);
3658 SDValue Op1 = Op.getOperand(1);
3659
3660 APInt UndefRHS, ZeroRHS;
3661 if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
3662 Depth + 1))
3663 return true;
3664 APInt UndefLHS, ZeroLHS;
3665 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3666 Depth + 1))
3667 return true;
3668
3669 KnownZero = ZeroLHS;
3670 KnownUndef = UndefLHS & UndefRHS; // TODO: use getKnownUndefForVectorBinop?
3671
3672 // Attempt to avoid multi-use ops if we don't need anything from them.
3673 // TODO - use KnownUndef to relax the demandedelts?
3674 if (!DemandedElts.isAllOnes())
3675 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3676 return true;
3677 break;
3678 }
3679 case ISD::MUL:
3680 case ISD::MULHU:
3681 case ISD::MULHS:
3682 case ISD::AND: {
3683 SDValue Op0 = Op.getOperand(0);
3684 SDValue Op1 = Op.getOperand(1);
3685
3686 APInt SrcUndef, SrcZero;
3687 if (SimplifyDemandedVectorElts(Op1, DemandedElts, SrcUndef, SrcZero, TLO,
3688 Depth + 1))
3689 return true;
3690 // If we know that a demanded element was zero in Op1 we don't need to
3691    // demand it in Op0 - it's guaranteed to be zero.
3692 APInt DemandedElts0 = DemandedElts & ~SrcZero;
3693 if (SimplifyDemandedVectorElts(Op0, DemandedElts0, KnownUndef, KnownZero,
3694 TLO, Depth + 1))
3695 return true;
3696
3697 KnownUndef &= DemandedElts0;
3698 KnownZero &= DemandedElts0;
3699
3700 // If every element pair has a zero/undef then just fold to zero.
3701 // fold (and x, undef) -> 0 / (and x, 0) -> 0
3702 // fold (mul x, undef) -> 0 / (mul x, 0) -> 0
3703 if (DemandedElts.isSubsetOf(SrcZero | KnownZero | SrcUndef | KnownUndef))
3704 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3705
3706 // If either side has a zero element, then the result element is zero, even
3707 // if the other is an UNDEF.
3708 // TODO: Extend getKnownUndefForVectorBinop to also deal with known zeros
3709 // and then handle 'and' nodes with the rest of the binop opcodes.
3710 KnownZero |= SrcZero;
3711 KnownUndef &= SrcUndef;
3712 KnownUndef &= ~KnownZero;
3713
3714 // Attempt to avoid multi-use ops if we don't need anything from them.
3715 if (!DemandedElts.isAllOnes())
3716 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3717 return true;
3718 break;
3719 }
3720 case ISD::TRUNCATE:
3721 case ISD::SIGN_EXTEND:
3722 case ISD::ZERO_EXTEND:
3723 if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef,
3724 KnownZero, TLO, Depth + 1))
3725 return true;
3726
3727 if (!DemandedElts.isAllOnes())
3728 if (SDValue NewOp = SimplifyMultipleUseDemandedVectorElts(
3729 Op.getOperand(0), DemandedElts, TLO.DAG, Depth + 1))
3730 return TLO.CombineTo(Op, TLO.DAG.getNode(Opcode, SDLoc(Op), VT, NewOp));
3731
3732 if (Op.getOpcode() == ISD::ZERO_EXTEND) {
3733 // zext(undef) upper bits are guaranteed to be zero.
3734 if (DemandedElts.isSubsetOf(KnownUndef))
3735 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3736 KnownUndef.clearAllBits();
3737 }
3738 break;
3739 case ISD::SINT_TO_FP:
3740 case ISD::UINT_TO_FP:
3741 case ISD::FP_TO_SINT:
3742 case ISD::FP_TO_UINT:
3743 if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef,
3744 KnownZero, TLO, Depth + 1))
3745 return true;
3746 // Don't fall through to generic undef -> undef handling.
3747 return false;
3748 default: {
3749 if (Op.getOpcode() >= ISD::BUILTIN_OP_END) {
3750 if (SimplifyDemandedVectorEltsForTargetNode(Op, DemandedElts, KnownUndef,
3751 KnownZero, TLO, Depth))
3752 return true;
3753 } else {
3754 KnownBits Known;
3755 APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
3756 if (SimplifyDemandedBits(Op, DemandedBits, OriginalDemandedElts, Known,
3757 TLO, Depth, AssumeSingleUse))
3758 return true;
3759 }
3760 break;
3761 }
3762 }
3763 assert((KnownUndef & KnownZero) == 0 && "Elements flagged as undef AND zero");
3764
3765 // Constant fold all undef cases.
3766 // TODO: Handle zero cases as well.
3767 if (DemandedElts.isSubsetOf(KnownUndef))
3768 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3769
3770 return false;
3771}
3772
3773/// Determine which of the bits specified in Mask are known to be either zero or
3774/// one and return them in the Known.
3775void TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3776 KnownBits &Known,
3777 const APInt &DemandedElts,
3778 const SelectionDAG &DAG,
3779 unsigned Depth) const {
3780 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3781 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3782 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3783 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3784 "Should use MaskedValueIsZero if you don't know whether Op"
3785 " is a target node!");
3786 Known.resetAll();
3787}
3788
3789void TargetLowering::computeKnownBitsForTargetInstr(
3790    GISelKnownBits &Analysis, Register R, KnownBits &Known,
3791 const APInt &DemandedElts, const MachineRegisterInfo &MRI,
3792 unsigned Depth) const {
3793 Known.resetAll();
3794}
3795
3796void TargetLowering::computeKnownBitsForFrameIndex(
3797 const int FrameIdx, KnownBits &Known, const MachineFunction &MF) const {
3798 // The low bits are known zero if the pointer is aligned.
3799 Known.Zero.setLowBits(Log2(MF.getFrameInfo().getObjectAlign(FrameIdx)));
3800}
3801
3802Align TargetLowering::computeKnownAlignForTargetInstr(
3803    GISelKnownBits &Analysis, Register R, const MachineRegisterInfo &MRI,
3804 unsigned Depth) const {
3805 return Align(1);
3806}
3807
3808/// This method can be implemented by targets that want to expose additional
3809/// information about sign bits to the DAG Combiner.
3810unsigned TargetLowering::ComputeNumSignBitsForTargetNode(SDValue Op,
3811 const APInt &,
3812 const SelectionDAG &,
3813 unsigned Depth) const {
3814 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3815 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3816 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3817 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3818 "Should use ComputeNumSignBits if you don't know whether Op"
3819 " is a target node!");
3820 return 1;
3821}
3822
3823unsigned TargetLowering::computeNumSignBitsForTargetInstr(
3824 GISelKnownBits &Analysis, Register R, const APInt &DemandedElts,
3825 const MachineRegisterInfo &MRI, unsigned Depth) const {
3826 return 1;
3827}
3828
3829bool TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
3830 SDValue Op, const APInt &DemandedElts, APInt &KnownUndef, APInt &KnownZero,
3831 TargetLoweringOpt &TLO, unsigned Depth) const {
3832 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3833 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3834 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3835 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3836 "Should use SimplifyDemandedVectorElts if you don't know whether Op"
3837 " is a target node!");
3838 return false;
3839}
3840
3841bool TargetLowering::SimplifyDemandedBitsForTargetNode(
3842 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3843 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
3844 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3845 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3846 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3847 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3848 "Should use SimplifyDemandedBits if you don't know whether Op"
3849 " is a target node!");
3850 computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
3851 return false;
3852}
3853
3854SDValue TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
3855 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3856 SelectionDAG &DAG, unsigned Depth) const {
3857 assert(
3858 (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3859 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3860 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3861 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3862 "Should use SimplifyMultipleUseDemandedBits if you don't know whether Op"
3863 " is a target node!");
3864 return SDValue();
3865}
3866
3867SDValue
3868TargetLowering::buildLegalVectorShuffle(EVT VT, const SDLoc &DL, SDValue N0,
3869                                        SDValue N1, MutableArrayRef<int> Mask,
3870 SelectionDAG &DAG) const {
3871 bool LegalMask = isShuffleMaskLegal(Mask, VT);
3872 if (!LegalMask) {
3873 std::swap(N0, N1);
3874    ShuffleVectorSDNode::commuteMask(Mask);
3875 LegalMask = isShuffleMaskLegal(Mask, VT);
3876 }
3877
3878 if (!LegalMask)
3879 return SDValue();
3880
3881 return DAG.getVectorShuffle(VT, DL, N0, N1, Mask);
3882}
3883
3884const Constant *TargetLowering::getTargetConstantFromLoad(LoadSDNode*) const {
3885 return nullptr;
3886}
3887
3888bool TargetLowering::isGuaranteedNotToBeUndefOrPoisonForTargetNode(
3889 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
3890 bool PoisonOnly, unsigned Depth) const {
3891 assert(
3892 (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3893 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3894 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3895 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3896 "Should use isGuaranteedNotToBeUndefOrPoison if you don't know whether Op"
3897 " is a target node!");
3898
3899 // If Op can't create undef/poison and none of its operands are undef/poison
3900 // then Op is never undef/poison.
3901 return !canCreateUndefOrPoisonForTargetNode(Op, DemandedElts, DAG, PoisonOnly,
3902 /*ConsiderFlags*/ true, Depth) &&
3903 all_of(Op->ops(), [&](SDValue V) {
3904 return DAG.isGuaranteedNotToBeUndefOrPoison(V, PoisonOnly,
3905 Depth + 1);
3906 });
3907}
3908
3909bool TargetLowering::canCreateUndefOrPoisonForTargetNode(
3910 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
3911 bool PoisonOnly, bool ConsiderFlags, unsigned Depth) const {
3912 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3913 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3914 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3915 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3916 "Should use canCreateUndefOrPoison if you don't know whether Op"
3917 " is a target node!");
3918 // Be conservative and return true.
3919 return true;
3920}
3921
3922bool TargetLowering::isKnownNeverNaNForTargetNode(SDValue Op,
3923 const SelectionDAG &DAG,
3924 bool SNaN,
3925 unsigned Depth) const {
3926 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3927 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3928 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3929 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3930 "Should use isKnownNeverNaN if you don't know whether Op"
3931 " is a target node!");
3932 return false;
3933}
3934
3935bool TargetLowering::isSplatValueForTargetNode(SDValue Op,
3936 const APInt &DemandedElts,
3937 APInt &UndefElts,
3938 const SelectionDAG &DAG,
3939 unsigned Depth) const {
3940 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3941 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3942 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3943 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3944 "Should use isSplatValue if you don't know whether Op"
3945 " is a target node!");
3946 return false;
3947}
3948
3949// FIXME: Ideally, this would use ISD::isConstantSplatVector(), but that must
3950// work with truncating build vectors and vectors with elements of less than
3951// 8 bits.
3952bool TargetLowering::isConstTrueVal(SDValue N) const {
3953 if (!N)
3954 return false;
3955
3956 unsigned EltWidth;
3957 APInt CVal;
3958 if (ConstantSDNode *CN = isConstOrConstSplat(N, /*AllowUndefs=*/false,
3959 /*AllowTruncation=*/true)) {
3960 CVal = CN->getAPIntValue();
3961 EltWidth = N.getValueType().getScalarSizeInBits();
3962 } else
3963 return false;
3964
3965 // If this is a truncating splat, truncate the splat value.
3966 // Otherwise, we may fail to match the expected values below.
3967 if (EltWidth < CVal.getBitWidth())
3968 CVal = CVal.trunc(EltWidth);
3969
3970 switch (getBooleanContents(N.getValueType())) {
3971  case UndefinedBooleanContent:
3972 return CVal[0];
3973  case ZeroOrOneBooleanContent:
3974 return CVal.isOne();
3975  case ZeroOrNegativeOneBooleanContent:
3976 return CVal.isAllOnes();
3977 }
3978
3979 llvm_unreachable("Invalid boolean contents");
3980}
3981
3982bool TargetLowering::isConstFalseVal(SDValue N) const {
3983 if (!N)
3984 return false;
3985
3986 const ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N);
3987 if (!CN) {
3988 const BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N);
3989 if (!BV)
3990 return false;
3991
3992    // Only interested in constant splats; we don't care about undef
3993    // elements when identifying boolean constants, and getConstantSplatNode
3994    // returns null if all ops are undef.
3995 CN = BV->getConstantSplatNode();
3996 if (!CN)
3997 return false;
3998 }
3999
4000 if (getBooleanContents(N->getValueType(0)) == UndefinedBooleanContent)
4001 return !CN->getAPIntValue()[0];
4002
4003 return CN->isZero();
4004}
4005
4006bool TargetLowering::isExtendedTrueVal(const ConstantSDNode *N, EVT VT,
4007 bool SExt) const {
4008 if (VT == MVT::i1)
4009 return N->isOne();
4010
4011  TargetLowering::BooleanContent Cnt = getBooleanContents(VT);
4012 switch (Cnt) {
4013  case TargetLowering::ZeroOrOneBooleanContent:
4014 // An extended value of 1 is always true, unless its original type is i1,
4015 // in which case it will be sign extended to -1.
4016 return (N->isOne() && !SExt) || (SExt && (N->getValueType(0) != MVT::i1));
4017  case TargetLowering::UndefinedBooleanContent:
4018  case TargetLowering::ZeroOrNegativeOneBooleanContent:
4019 return N->isAllOnes() && SExt;
4020 }
4021 llvm_unreachable("Unexpected enumeration.");
4022}
4023
4024/// This helper function of SimplifySetCC tries to optimize the comparison when
4025/// either operand of the SetCC node is a bitwise-and instruction.
4026SDValue TargetLowering::foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1,
4027 ISD::CondCode Cond, const SDLoc &DL,
4028 DAGCombinerInfo &DCI) const {
4029 if (N1.getOpcode() == ISD::AND && N0.getOpcode() != ISD::AND)
4030 std::swap(N0, N1);
4031
4032 SelectionDAG &DAG = DCI.DAG;
4033 EVT OpVT = N0.getValueType();
4034 if (N0.getOpcode() != ISD::AND || !OpVT.isInteger() ||
4035 (Cond != ISD::SETEQ && Cond != ISD::SETNE))
4036 return SDValue();
4037
4038 // (X & Y) != 0 --> zextOrTrunc(X & Y)
4039 // iff everything but LSB is known zero:
4040 if (Cond == ISD::SETNE && isNullConstant(N1) &&
4043 unsigned NumEltBits = OpVT.getScalarSizeInBits();
4044 APInt UpperBits = APInt::getHighBitsSet(NumEltBits, NumEltBits - 1);
4045 if (DAG.MaskedValueIsZero(N0, UpperBits))
4046 return DAG.getBoolExtOrTrunc(N0, DL, VT, OpVT);
4047 }
4048
4049 // Try to eliminate a power-of-2 mask constant by converting to a signbit
4050 // test in a narrow type that we can truncate to with no cost. Examples:
4051 // (i32 X & 32768) == 0 --> (trunc X to i16) >= 0
4052 // (i32 X & 32768) != 0 --> (trunc X to i16) < 0
4053 // TODO: This conservatively checks for type legality on the source and
4054 // destination types. That may inhibit optimizations, but it also
4055 // allows setcc->shift transforms that may be more beneficial.
4056 auto *AndC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
4057 if (AndC && isNullConstant(N1) && AndC->getAPIntValue().isPowerOf2() &&
4058 isTypeLegal(OpVT) && N0.hasOneUse()) {
4059 EVT NarrowVT = EVT::getIntegerVT(*DAG.getContext(),
4060 AndC->getAPIntValue().getActiveBits());
4061 if (isTruncateFree(OpVT, NarrowVT) && isTypeLegal(NarrowVT)) {
4062 SDValue Trunc = DAG.getZExtOrTrunc(N0.getOperand(0), DL, NarrowVT);
4063 SDValue Zero = DAG.getConstant(0, DL, NarrowVT);
4064 return DAG.getSetCC(DL, VT, Trunc, Zero,
4065                        Cond == ISD::SETEQ ? ISD::SETGE : ISD::SETLT);
4066 }
4067 }
4068
4069 // Match these patterns in any of their permutations:
4070 // (X & Y) == Y
4071 // (X & Y) != Y
4072 SDValue X, Y;
4073 if (N0.getOperand(0) == N1) {
4074 X = N0.getOperand(1);
4075 Y = N0.getOperand(0);
4076 } else if (N0.getOperand(1) == N1) {
4077 X = N0.getOperand(0);
4078 Y = N0.getOperand(1);
4079 } else {
4080 return SDValue();
4081 }
4082
4083 // TODO: We should invert (X & Y) eq/ne 0 -> (X & Y) ne/eq Y if
4084 // `isXAndYEqZeroPreferableToXAndYEqY` is false. This is a bit difficult as
4085  // it's liable to create an infinite loop.
4086 SDValue Zero = DAG.getConstant(0, DL, OpVT);
4087 if (isXAndYEqZeroPreferableToXAndYEqY(Cond, OpVT) &&
4088      valueHasExactlyOneBitSet(Y, DAG)) {
4089 // Simplify X & Y == Y to X & Y != 0 if Y has exactly one bit set.
4090 // Note that where Y is variable and is known to have at most one bit set
4091 // (for example, if it is Z & 1) we cannot do this; the expressions are not
4092 // equivalent when Y == 0.
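    // For example, with a constant Y = 8: (X & 8) == 8 becomes (X & 8) != 0.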
4093 assert(OpVT.isInteger());
4094    Cond = ISD::getSetCCInverse(Cond, OpVT);
4095 if (DCI.isBeforeLegalizeOps() ||
4096        isCondCodeLegal(Cond, N0.getSimpleValueType()))
4097 return DAG.getSetCC(DL, VT, N0, Zero, Cond);
4098 } else if (N0.hasOneUse() && hasAndNotCompare(Y)) {
4099 // If the target supports an 'and-not' or 'and-complement' logic operation,
4100 // try to use that to make a comparison operation more efficient.
4101 // But don't do this transform if the mask is a single bit because there are
4102 // more efficient ways to deal with that case (for example, 'bt' on x86 or
4103 // 'rlwinm' on PPC).
4104
4105 // Bail out if the compare operand that we want to turn into a zero is
4106 // already a zero (otherwise, infinite loop).
4107 if (isNullConstant(Y))
4108 return SDValue();
4109
4110 // Transform this into: ~X & Y == 0.
4111 SDValue NotX = DAG.getNOT(SDLoc(X), X, OpVT);
4112 SDValue NewAnd = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, NotX, Y);
4113 return DAG.getSetCC(DL, VT, NewAnd, Zero, Cond);
4114 }
4115
4116 return SDValue();
4117}
4118
4119/// There are multiple IR patterns that could be checking whether certain
4120/// truncation of a signed number would be lossy or not. The pattern which is
4121/// best at IR level, may not lower optimally. Thus, we want to unfold it.
4122/// We are looking for the following pattern: (KeptBits is a constant)
4123/// (add %x, (1 << (KeptBits-1))) srccond (1 << KeptBits)
4124/// KeptBits won't be bitwidth(x); that will be constant-folded to true/false.
4125/// KeptBits also can't be 1; that would have been folded to %x dstcond 0.
4126/// We will unfold it into the natural trunc+sext pattern:
4127/// ((%x << C) a>> C) dstcond %x
4128/// Where C = bitwidth(x) - KeptBits and C u< bitwidth(x)
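/// For example, with i16 %x and KeptBits == 8:
///   (add %x, 128) ult 256  -->  ((%x << 8) a>> 8) == %x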
4129SDValue TargetLowering::optimizeSetCCOfSignedTruncationCheck(
4130 EVT SCCVT, SDValue N0, SDValue N1, ISD::CondCode Cond, DAGCombinerInfo &DCI,
4131 const SDLoc &DL) const {
4132 // We must be comparing with a constant.
4133 ConstantSDNode *C1;
4134 if (!(C1 = dyn_cast<ConstantSDNode>(N1)))
4135 return SDValue();
4136
4137 // N0 should be: add %x, (1 << (KeptBits-1))
4138 if (N0->getOpcode() != ISD::ADD)
4139 return SDValue();
4140
4141 // And we must be 'add'ing a constant.
4142 ConstantSDNode *C01;
4143 if (!(C01 = dyn_cast<ConstantSDNode>(N0->getOperand(1))))
4144 return SDValue();
4145
4146 SDValue X = N0->getOperand(0);
4147 EVT XVT = X.getValueType();
4148
4149 // Validate constants ...
4150
4151 APInt I1 = C1->getAPIntValue();
4152
4153 ISD::CondCode NewCond;
4154 if (Cond == ISD::CondCode::SETULT) {
4155 NewCond = ISD::CondCode::SETEQ;
4156 } else if (Cond == ISD::CondCode::SETULE) {
4157 NewCond = ISD::CondCode::SETEQ;
4158 // But need to 'canonicalize' the constant.
4159 I1 += 1;
4160 } else if (Cond == ISD::CondCode::SETUGT) {
4161 NewCond = ISD::CondCode::SETNE;
4162 // But need to 'canonicalize' the constant.
4163 I1 += 1;
4164 } else if (Cond == ISD::CondCode::SETUGE) {
4165 NewCond = ISD::CondCode::SETNE;
4166 } else
4167 return SDValue();
4168
4169 APInt I01 = C01->getAPIntValue();
4170
4171 auto checkConstants = [&I1, &I01]() -> bool {
4172 // Both of them must be power-of-two, and the constant from setcc is bigger.
4173 return I1.ugt(I01) && I1.isPowerOf2() && I01.isPowerOf2();
4174 };
4175
4176 if (checkConstants()) {
4177 // Great, e.g. got icmp ult i16 (add i16 %x, 128), 256
4178 } else {
4179 // What if we invert constants? (and the target predicate)
4180 I1.negate();
4181 I01.negate();
4182 assert(XVT.isInteger());
4183 NewCond = getSetCCInverse(NewCond, XVT);
4184 if (!checkConstants())
4185 return SDValue();
4186 // Great, e.g. got icmp uge i16 (add i16 %x, -128), -256
4187 }
4188
4189 // They are power-of-two, so which bit is set?
4190 const unsigned KeptBits = I1.logBase2();
4191 const unsigned KeptBitsMinusOne = I01.logBase2();
4192
4193 // Magic!
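  // Since I1 and I01 are both powers of two, this holds iff I1 == 2 * I01.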
4194 if (KeptBits != (KeptBitsMinusOne + 1))
4195 return SDValue();
4196 assert(KeptBits > 0 && KeptBits < XVT.getSizeInBits() && "unreachable");
4197
4198 // We don't want to do this in every single case.
4199 SelectionDAG &DAG = DCI.DAG;
4200 if (!shouldTransformSignedTruncationCheck(XVT, KeptBits))
4201 return SDValue();
4202
4203 // Unfold into: sext_inreg(%x) cond %x
4204 // Where 'cond' will be either 'eq' or 'ne'.
4205 SDValue SExtInReg = DAG.getNode(
4206      ISD::SIGN_EXTEND_INREG, DL, XVT, X,
4207 DAG.getValueType(EVT::getIntegerVT(*DAG.getContext(), KeptBits)));
4208 return DAG.getSetCC(DL, SCCVT, SExtInReg, X, NewCond);
4209}
4210
4211// (X & (C l>>/<< Y)) ==/!= 0 --> ((X <</l>> Y) & C) ==/!= 0
4212SDValue TargetLowering::optimizeSetCCByHoistingAndByConstFromLogicalShift(
4213 EVT SCCVT, SDValue N0, SDValue N1C, ISD::CondCode Cond,
4214 DAGCombinerInfo &DCI, const SDLoc &DL) const {
4216 "Should be a comparison with 0.");
4217 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4218 "Valid only for [in]equality comparisons.");
4219
4220 unsigned NewShiftOpcode;
4221 SDValue X, C, Y;
4222
4223 SelectionDAG &DAG = DCI.DAG;
4224
4225 // Look for '(C l>>/<< Y)'.
4226 auto Match = [&NewShiftOpcode, &X, &C, &Y, &DAG, this](SDValue V) {
4227 // The shift should be one-use.
4228 if (!V.hasOneUse())
4229 return false;
4230 unsigned OldShiftOpcode = V.getOpcode();
4231 switch (OldShiftOpcode) {
4232 case ISD::SHL:
4233 NewShiftOpcode = ISD::SRL;
4234 break;
4235 case ISD::SRL:
4236 NewShiftOpcode = ISD::SHL;
4237 break;
4238 default:
4239 return false; // must be a logical shift.
4240 }
4241 // We should be shifting a constant.
4242 // FIXME: best to use isConstantOrConstantVector().
4243 C = V.getOperand(0);
4244    ConstantSDNode *CC =
4245 isConstOrConstSplat(C, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
4246 if (!CC)
4247 return false;
4248 Y = V.getOperand(1);
4249
4250    ConstantSDNode *XC =
4251 isConstOrConstSplat(X, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
4252    return shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
4253 X, XC, CC, Y, OldShiftOpcode, NewShiftOpcode, DAG);
4254 };
4255
4256  // LHS of comparison should be a one-use 'and'.
4257 if (N0.getOpcode() != ISD::AND || !N0.hasOneUse())
4258 return SDValue();
4259
4260 X = N0.getOperand(0);
4261 SDValue Mask = N0.getOperand(1);
4262
4263 // 'and' is commutative!
4264 if (!Match(Mask)) {
4265 std::swap(X, Mask);
4266 if (!Match(Mask))
4267 return SDValue();
4268 }
4269
4270 EVT VT = X.getValueType();
4271
4272 // Produce:
4273 // ((X 'OppositeShiftOpcode' Y) & C) Cond 0
4274 SDValue T0 = DAG.getNode(NewShiftOpcode, DL, VT, X, Y);
4275 SDValue T1 = DAG.getNode(ISD::AND, DL, VT, T0, C);
4276 SDValue T2 = DAG.getSetCC(DL, SCCVT, T1, N1C, Cond);
4277 return T2;
4278}
4279
4280/// Try to fold an equality comparison with a {add/sub/xor} binary operation as
4281/// the 1st operand (N0). Callers are expected to swap the N0/N1 parameters to
4282/// handle the commuted versions of these patterns.
4283SDValue TargetLowering::foldSetCCWithBinOp(EVT VT, SDValue N0, SDValue N1,
4284 ISD::CondCode Cond, const SDLoc &DL,
4285 DAGCombinerInfo &DCI) const {
4286 unsigned BOpcode = N0.getOpcode();
4287 assert((BOpcode == ISD::ADD || BOpcode == ISD::SUB || BOpcode == ISD::XOR) &&
4288 "Unexpected binop");
4289 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) && "Unexpected condcode");
4290
4291 // (X + Y) == X --> Y == 0
4292 // (X - Y) == X --> Y == 0
4293 // (X ^ Y) == X --> Y == 0
4294 SelectionDAG &DAG = DCI.DAG;
4295 EVT OpVT = N0.getValueType();
4296 SDValue X = N0.getOperand(0);
4297 SDValue Y = N0.getOperand(1);
4298 if (X == N1)
4299 return DAG.getSetCC(DL, VT, Y, DAG.getConstant(0, DL, OpVT), Cond);
4300
4301 if (Y != N1)
4302 return SDValue();
4303
4304 // (X + Y) == Y --> X == 0
4305 // (X ^ Y) == Y --> X == 0
4306 if (BOpcode == ISD::ADD || BOpcode == ISD::XOR)
4307 return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, OpVT), Cond);
4308
4309 // The shift would not be valid if the operands are boolean (i1).
4310 if (!N0.hasOneUse() || OpVT.getScalarSizeInBits() == 1)
4311 return SDValue();
4312
4313 // (X - Y) == Y --> X == Y << 1
4314 SDValue One = DAG.getShiftAmountConstant(1, OpVT, DL);
4315 SDValue YShl1 = DAG.getNode(ISD::SHL, DL, N1.getValueType(), Y, One);
4316 if (!DCI.isCalledByLegalizer())
4317 DCI.AddToWorklist(YShl1.getNode());
4318 return DAG.getSetCC(DL, VT, X, YShl1, Cond);
4319}
4320
4321static SDValue simplifySetCCWithCTPOP(const TargetLowering &TLI, EVT VT,
4322 SDValue N0, const APInt &C1,
4323 ISD::CondCode Cond, const SDLoc &dl,
4324 SelectionDAG &DAG) {
4325 // Look through truncs that don't change the value of a ctpop.
4326 // FIXME: Add vector support? Need to be careful with setcc result type below.
4327 SDValue CTPOP = N0;
4328 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() && !VT.isVector() &&
4329      N0.getScalarValueSizeInBits() > Log2_32(N0.getOperand(0).getScalarValueSizeInBits()))
4330 CTPOP = N0.getOperand(0);
4331
4332 if (CTPOP.getOpcode() != ISD::CTPOP || !CTPOP.hasOneUse())
4333 return SDValue();
4334
4335 EVT CTVT = CTPOP.getValueType();
4336 SDValue CTOp = CTPOP.getOperand(0);
4337
4338 // Expand a power-of-2-or-zero comparison based on ctpop:
4339 // (ctpop x) u< 2 -> (x & x-1) == 0
4340 // (ctpop x) u> 1 -> (x & x-1) != 0
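  // More generally, (ctpop x) u< C becomes C-1 iterations of x &= (x - 1)
  // followed by an equality compare of the result against zero.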
4341 if (Cond == ISD::SETULT || Cond == ISD::SETUGT) {
4342 // Keep the CTPOP if it is a cheap vector op.
4343 if (CTVT.isVector() && TLI.isCtpopFast(CTVT))
4344 return SDValue();
4345
4346 unsigned CostLimit = TLI.getCustomCtpopCost(CTVT, Cond);
4347 if (C1.ugt(CostLimit + (Cond == ISD::SETULT)))
4348 return SDValue();
4349 if (C1 == 0 && (Cond == ISD::SETULT))
4350 return SDValue(); // This is handled elsewhere.
4351
4352 unsigned Passes = C1.getLimitedValue() - (Cond == ISD::SETULT);
4353
4354 SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT);
4355 SDValue Result = CTOp;
4356 for (unsigned i = 0; i < Passes; i++) {
4357 SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, Result, NegOne);
4358 Result = DAG.getNode(ISD::AND, dl, CTVT, Result, Add);
4359 }
4360 ISD::CondCode CC = Cond ==