TargetLowering.cpp
1//===-- TargetLowering.cpp - Implement the TargetLowering class -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This implements the TargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
14#include "llvm/ADT/STLExtras.h"
25#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/LLVMContext.h"
29#include "llvm/MC/MCAsmInfo.h"
30#include "llvm/MC/MCExpr.h"
36#include <cctype>
37using namespace llvm;
38
39/// NOTE: The TargetMachine owns TLOF.
40TargetLowering::TargetLowering(const TargetMachine &tm)
41 : TargetLoweringBase(tm) {}
42
43const char *TargetLowering::getTargetNodeName(unsigned Opcode) const {
44 return nullptr;
45}
46
47bool TargetLowering::isPositionIndependent() const {
48 return getTargetMachine().isPositionIndependent();
49}
50
51/// Check whether a given call node is in tail position within its function. If
52/// so, it sets Chain to the input chain of the tail call.
53bool TargetLowering::isInTailCallPosition(SelectionDAG &DAG, SDNode *Node,
54 SDValue &Chain) const {
55 const Function &F = DAG.getMachineFunction().getFunction();
56
57 // First, check if tail calls have been disabled in this function.
58 if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
59 return false;
60
61 // Conservatively require the attributes of the call to match those of
62 // the return. Ignore following attributes because they don't affect the
63 // call sequence.
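 // For example, a caller whose return value is marked 'nonnull' can still be
 // in tail call position, since 'nonnull' does not change how the call or the
 // return is actually emitted.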
64 AttrBuilder CallerAttrs(F.getContext(), F.getAttributes().getRetAttrs());
65 for (const auto &Attr :
66 {Attribute::Alignment, Attribute::Dereferenceable,
67 Attribute::DereferenceableOrNull, Attribute::NoAlias,
68 Attribute::NonNull, Attribute::NoUndef, Attribute::Range})
69 CallerAttrs.removeAttribute(Attr);
70
71 if (CallerAttrs.hasAttributes())
72 return false;
73
74 // It's not safe to eliminate the sign / zero extension of the return value.
75 if (CallerAttrs.contains(Attribute::ZExt) ||
76 CallerAttrs.contains(Attribute::SExt))
77 return false;
78
79 // Check if the only use is a function return node.
80 return isUsedByReturnOnly(Node, Chain);
81}
82
83bool TargetLowering::parametersInCSRMatch(const MachineRegisterInfo &MRI,
84 const uint32_t *CallerPreservedMask,
85 const SmallVectorImpl<CCValAssign> &ArgLocs,
86 const SmallVectorImpl<SDValue> &OutVals) const {
87 for (unsigned I = 0, E = ArgLocs.size(); I != E; ++I) {
88 const CCValAssign &ArgLoc = ArgLocs[I];
89 if (!ArgLoc.isRegLoc())
90 continue;
91 MCRegister Reg = ArgLoc.getLocReg();
92 // Only look at callee saved registers.
93 if (MachineOperand::clobbersPhysReg(CallerPreservedMask, Reg))
94 continue;
95 // Check that we pass the value used for the caller.
96 // (We look for a CopyFromReg reading a virtual register that is used
97 // for the function live-in value of register Reg)
98 SDValue Value = OutVals[I];
99 if (Value->getOpcode() == ISD::AssertZext)
100 Value = Value.getOperand(0);
101 if (Value->getOpcode() != ISD::CopyFromReg)
102 return false;
103 Register ArgReg = cast<RegisterSDNode>(Value->getOperand(1))->getReg();
104 if (MRI.getLiveInPhysReg(ArgReg) != Reg)
105 return false;
106 }
107 return true;
108}
109
110/// Set CallLoweringInfo attribute flags based on a call instruction
111/// and called function attributes.
112void TargetLoweringBase::ArgListEntry::setAttributes(const CallBase *Call,
113 unsigned ArgIdx) {
114 IsSExt = Call->paramHasAttr(ArgIdx, Attribute::SExt);
115 IsZExt = Call->paramHasAttr(ArgIdx, Attribute::ZExt);
116 IsInReg = Call->paramHasAttr(ArgIdx, Attribute::InReg);
117 IsSRet = Call->paramHasAttr(ArgIdx, Attribute::StructRet);
118 IsNest = Call->paramHasAttr(ArgIdx, Attribute::Nest);
119 IsByVal = Call->paramHasAttr(ArgIdx, Attribute::ByVal);
120 IsPreallocated = Call->paramHasAttr(ArgIdx, Attribute::Preallocated);
121 IsInAlloca = Call->paramHasAttr(ArgIdx, Attribute::InAlloca);
122 IsReturned = Call->paramHasAttr(ArgIdx, Attribute::Returned);
123 IsSwiftSelf = Call->paramHasAttr(ArgIdx, Attribute::SwiftSelf);
124 IsSwiftAsync = Call->paramHasAttr(ArgIdx, Attribute::SwiftAsync);
125 IsSwiftError = Call->paramHasAttr(ArgIdx, Attribute::SwiftError);
126 Alignment = Call->getParamStackAlign(ArgIdx);
127 IndirectType = nullptr;
129 "multiple ABI attributes?");
130 if (IsByVal) {
131 IndirectType = Call->getParamByValType(ArgIdx);
132 if (!Alignment)
133 Alignment = Call->getParamAlign(ArgIdx);
134 }
135 if (IsPreallocated)
136 IndirectType = Call->getParamPreallocatedType(ArgIdx);
137 if (IsInAlloca)
138 IndirectType = Call->getParamInAllocaType(ArgIdx);
139 if (IsSRet)
140 IndirectType = Call->getParamStructRetType(ArgIdx);
141}
142
143/// Generate a libcall taking the given operands as arguments and returning a
144/// result of type RetVT.
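/// A minimal usage sketch (illustrative only, not taken from this file):
/// lowering a soft-float addition might call
///   makeLibCall(DAG, RTLIB::ADD_F32, MVT::f32, {LHS, RHS}, CallOptions, dl)
/// and then use the returned {result, chain} pair in place of the original node.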
145std::pair<SDValue, SDValue>
146TargetLowering::makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC, EVT RetVT,
147 ArrayRef<SDValue> Ops,
148 MakeLibCallOptions CallOptions,
149 const SDLoc &dl,
150 SDValue InChain) const {
151 if (!InChain)
152 InChain = DAG.getEntryNode();
153
154 TargetLowering::ArgListTy Args;
155 Args.reserve(Ops.size());
156
157 TargetLowering::ArgListEntry Entry;
158 for (unsigned i = 0; i < Ops.size(); ++i) {
159 SDValue NewOp = Ops[i];
160 Entry.Node = NewOp;
161 Entry.Ty = Entry.Node.getValueType().getTypeForEVT(*DAG.getContext());
162 Entry.IsSExt = shouldSignExtendTypeInLibCall(NewOp.getValueType(),
163 CallOptions.IsSExt);
164 Entry.IsZExt = !Entry.IsSExt;
165
166 if (CallOptions.IsSoften &&
167 !shouldExtendTypeInLibCall(CallOptions.OpsVTBeforeSoften[i])) {
168 Entry.IsSExt = Entry.IsZExt = false;
169 }
170 Args.push_back(Entry);
171 }
172
173 if (LC == RTLIB::UNKNOWN_LIBCALL)
174 report_fatal_error("Unsupported library call operation!");
175 SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC),
176 getPointerTy(DAG.getDataLayout()));
177
178 Type *RetTy = RetVT.getTypeForEVT(*DAG.getContext());
179 TargetLowering::CallLoweringInfo CLI(DAG);
180 bool signExtend = shouldSignExtendTypeInLibCall(RetVT, CallOptions.IsSExt);
181 bool zeroExtend = !signExtend;
182
183 if (CallOptions.IsSoften &&
184 !shouldExtendTypeInLibCall(CallOptions.RetVTBeforeSoften)) {
185 signExtend = zeroExtend = false;
186 }
187
188 CLI.setDebugLoc(dl)
189 .setChain(InChain)
190 .setLibCallee(getLibcallCallingConv(LC), RetTy, Callee, std::move(Args))
191 .setNoReturn(CallOptions.DoesNotReturn)
192 .setDiscardResult(!CallOptions.IsReturnValueUsed)
193 .setIsPostTypeLegalization(CallOptions.IsPostTypeLegalization)
194 .setSExtResult(signExtend)
195 .setZExtResult(zeroExtend);
196 return LowerCallTo(CLI);
197}
198
199bool TargetLowering::findOptimalMemOpLowering(
200 std::vector<EVT> &MemOps, unsigned Limit, const MemOp &Op, unsigned DstAS,
201 unsigned SrcAS, const AttributeList &FuncAttributes) const {
202 if (Limit != ~unsigned(0) && Op.isMemcpyWithFixedDstAlign() &&
203 Op.getSrcAlign() < Op.getDstAlign())
204 return false;
205
206 EVT VT = getOptimalMemOpType(Op, FuncAttributes);
207
208 if (VT == MVT::Other) {
209 // Use the largest integer type whose alignment constraints are satisfied.
210 // We only need to check DstAlign here as SrcAlign is always greater or
211 // equal to DstAlign (or zero).
212 VT = MVT::i64;
213 if (Op.isFixedDstAlign())
214 while (Op.getDstAlign() < (VT.getSizeInBits() / 8) &&
215 !allowsMisalignedMemoryAccesses(VT, DstAS, Op.getDstAlign()))
216 VT = (MVT::SimpleValueType)(VT.getSimpleVT().SimpleTy - 1);
217 assert(VT.isInteger());
218
219 // Find the largest legal integer type.
220 MVT LVT = MVT::i64;
221 while (!isTypeLegal(LVT))
222 LVT = (MVT::SimpleValueType)(LVT.SimpleTy - 1);
223 assert(LVT.isInteger());
224
225 // If the type we've chosen is larger than the largest legal integer type
226 // then use that instead.
227 if (VT.bitsGT(LVT))
228 VT = LVT;
229 }
230
231 unsigned NumMemOps = 0;
232 uint64_t Size = Op.size();
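 // Illustrative example: for a 15-byte fixed-size copy where i64 was chosen
 // above and the smaller integer types are safe for memops, this loop emits
 // i64 + i32 + i16 + i8 pieces (8 + 4 + 2 + 1 bytes), or fewer, wider
 // overlapping ops when unaligned accesses are fast and overlap is allowed.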
233 while (Size) {
234 unsigned VTSize = VT.getSizeInBits() / 8;
235 while (VTSize > Size) {
236 // For now, only use non-vector load / store's for the left-over pieces.
237 EVT NewVT = VT;
238 unsigned NewVTSize;
239
240 bool Found = false;
241 if (VT.isVector() || VT.isFloatingPoint()) {
242 NewVT = (VT.getSizeInBits() > 64) ? MVT::i64 : MVT::i32;
243 if (isOperationLegalOrCustom(ISD::STORE, NewVT) &&
244 isSafeMemOpType(NewVT.getSimpleVT()))
245 Found = true;
246 else if (NewVT == MVT::i64 &&
247 isOperationLegalOrCustom(ISD::STORE, MVT::f64) &&
248 isSafeMemOpType(MVT::f64)) {
249 // i64 is usually not legal on 32-bit targets, but f64 may be.
250 NewVT = MVT::f64;
251 Found = true;
252 }
253 }
254
255 if (!Found) {
256 do {
257 NewVT = (MVT::SimpleValueType)(NewVT.getSimpleVT().SimpleTy - 1);
258 if (NewVT == MVT::i8)
259 break;
260 } while (!isSafeMemOpType(NewVT.getSimpleVT()));
261 }
262 NewVTSize = NewVT.getSizeInBits() / 8;
263
264 // If the new VT cannot cover all of the remaining bits, then consider
265 // issuing a (or a pair of) unaligned and overlapping load / store.
266 unsigned Fast;
267 if (NumMemOps && Op.allowOverlap() && NewVTSize < Size &&
268 allowsMisalignedMemoryAccesses(
269 VT, DstAS, Op.isFixedDstAlign() ? Op.getDstAlign() : Align(1),
270 MachineMemOperand::MONone, &Fast) &&
271 Fast)
272 VTSize = Size;
273 else {
274 VT = NewVT;
275 VTSize = NewVTSize;
276 }
277 }
278
279 if (++NumMemOps > Limit)
280 return false;
281
282 MemOps.push_back(VT);
283 Size -= VTSize;
284 }
285
286 return true;
287}
288
289/// Soften the operands of a comparison. This code is shared among BR_CC,
290/// SELECT_CC, and SETCC handlers.
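/// For example (assuming the usual libgcc/compiler-rt routine names), an f32
/// SETOLT is typically softened into a call to __ltsf2 whose integer result is
/// then compared against zero using the condition code from getCmpLibcallCC.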
291void TargetLowering::softenSetCCOperands(SelectionDAG &DAG, EVT VT,
292 SDValue &NewLHS, SDValue &NewRHS,
293 ISD::CondCode &CCCode,
294 const SDLoc &dl, const SDValue OldLHS,
295 const SDValue OldRHS) const {
296 SDValue Chain;
297 return softenSetCCOperands(DAG, VT, NewLHS, NewRHS, CCCode, dl, OldLHS,
298 OldRHS, Chain);
299}
300
301void TargetLowering::softenSetCCOperands(SelectionDAG &DAG, EVT VT,
302 SDValue &NewLHS, SDValue &NewRHS,
303 ISD::CondCode &CCCode,
304 const SDLoc &dl, const SDValue OldLHS,
305 const SDValue OldRHS,
306 SDValue &Chain,
307 bool IsSignaling) const {
308 // FIXME: Currently we cannot really respect all IEEE predicates due to libgcc
309 // not supporting it. We can update this code when libgcc provides such
310 // functions.
311
312 assert((VT == MVT::f32 || VT == MVT::f64 || VT == MVT::f128 || VT == MVT::ppcf128)
313 && "Unsupported setcc type!");
314
315 // Expand into one or more soft-fp libcall(s).
316 RTLIB::Libcall LC1 = RTLIB::UNKNOWN_LIBCALL, LC2 = RTLIB::UNKNOWN_LIBCALL;
317 bool ShouldInvertCC = false;
318 switch (CCCode) {
319 case ISD::SETEQ:
320 case ISD::SETOEQ:
321 LC1 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
322 (VT == MVT::f64) ? RTLIB::OEQ_F64 :
323 (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
324 break;
325 case ISD::SETNE:
326 case ISD::SETUNE:
327 LC1 = (VT == MVT::f32) ? RTLIB::UNE_F32 :
328 (VT == MVT::f64) ? RTLIB::UNE_F64 :
329 (VT == MVT::f128) ? RTLIB::UNE_F128 : RTLIB::UNE_PPCF128;
330 break;
331 case ISD::SETGE:
332 case ISD::SETOGE:
333 LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
334 (VT == MVT::f64) ? RTLIB::OGE_F64 :
335 (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
336 break;
337 case ISD::SETLT:
338 case ISD::SETOLT:
339 LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
340 (VT == MVT::f64) ? RTLIB::OLT_F64 :
341 (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
342 break;
343 case ISD::SETLE:
344 case ISD::SETOLE:
345 LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
346 (VT == MVT::f64) ? RTLIB::OLE_F64 :
347 (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
348 break;
349 case ISD::SETGT:
350 case ISD::SETOGT:
351 LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
352 (VT == MVT::f64) ? RTLIB::OGT_F64 :
353 (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
354 break;
355 case ISD::SETO:
356 ShouldInvertCC = true;
357 [[fallthrough]];
358 case ISD::SETUO:
359 LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
360 (VT == MVT::f64) ? RTLIB::UO_F64 :
361 (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
362 break;
363 case ISD::SETONE:
364 // SETONE = O && UNE
365 ShouldInvertCC = true;
366 [[fallthrough]];
367 case ISD::SETUEQ:
368 LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
369 (VT == MVT::f64) ? RTLIB::UO_F64 :
370 (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
371 LC2 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
372 (VT == MVT::f64) ? RTLIB::OEQ_F64 :
373 (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
374 break;
375 default:
376 // Invert CC for unordered comparisons
377 ShouldInvertCC = true;
378 switch (CCCode) {
379 case ISD::SETULT:
380 LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
381 (VT == MVT::f64) ? RTLIB::OGE_F64 :
382 (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
383 break;
384 case ISD::SETULE:
385 LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
386 (VT == MVT::f64) ? RTLIB::OGT_F64 :
387 (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
388 break;
389 case ISD::SETUGT:
390 LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
391 (VT == MVT::f64) ? RTLIB::OLE_F64 :
392 (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
393 break;
394 case ISD::SETUGE:
395 LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
396 (VT == MVT::f64) ? RTLIB::OLT_F64 :
397 (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
398 break;
399 default: llvm_unreachable("Do not know how to soften this setcc!");
400 }
401 }
402
403 // Use the target specific return value for comparison lib calls.
404 EVT RetVT = getCmpLibcallReturnType();
405 SDValue Ops[2] = {NewLHS, NewRHS};
406 TargetLowering::MakeLibCallOptions CallOptions;
407 EVT OpsVT[2] = { OldLHS.getValueType(),
408 OldRHS.getValueType() };
409 CallOptions.setTypeListBeforeSoften(OpsVT, RetVT, true);
410 auto Call = makeLibCall(DAG, LC1, RetVT, Ops, CallOptions, dl, Chain);
411 NewLHS = Call.first;
412 NewRHS = DAG.getConstant(0, dl, RetVT);
413
414 CCCode = getCmpLibcallCC(LC1);
415 if (ShouldInvertCC) {
416 assert(RetVT.isInteger());
417 CCCode = getSetCCInverse(CCCode, RetVT);
418 }
419
420 if (LC2 == RTLIB::UNKNOWN_LIBCALL) {
421 // Update Chain.
422 Chain = Call.second;
423 } else {
424 EVT SetCCVT =
425 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), RetVT);
426 SDValue Tmp = DAG.getSetCC(dl, SetCCVT, NewLHS, NewRHS, CCCode);
427 auto Call2 = makeLibCall(DAG, LC2, RetVT, Ops, CallOptions, dl, Chain);
428 CCCode = getCmpLibcallCC(LC2);
429 if (ShouldInvertCC)
430 CCCode = getSetCCInverse(CCCode, RetVT);
431 NewLHS = DAG.getSetCC(dl, SetCCVT, Call2.first, NewRHS, CCCode);
432 if (Chain)
433 Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Call.second,
434 Call2.second);
435 NewLHS = DAG.getNode(ShouldInvertCC ? ISD::AND : ISD::OR, dl,
436 Tmp.getValueType(), Tmp, NewLHS);
437 NewRHS = SDValue();
438 }
439}
440
441/// Return the entry encoding for a jump table in the current function. The
442/// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
443unsigned TargetLowering::getJumpTableEncoding() const {
444 // In non-pic modes, just use the address of a block.
445 if (!isPositionIndependent())
446 return MachineJumpTableInfo::EK_BlockAddress;
447
448 // In PIC mode, if the target supports a GPRel32 directive, use it.
449 if (getTargetMachine().getMCAsmInfo()->getGPRel32Directive() != nullptr)
450 return MachineJumpTableInfo::EK_GPRel32BlockAddress;
451
452 // Otherwise, use a label difference.
453 return MachineJumpTableInfo::EK_LabelDifference32;
454}
455
456SDValue TargetLowering::getPICJumpTableRelocBase(SDValue Table,
457 SelectionDAG &DAG) const {
458 // If our PIC model is GP relative, use the global offset table as the base.
459 unsigned JTEncoding = getJumpTableEncoding();
460
461 if ((JTEncoding == MachineJumpTableInfo::EK_GPRel64BlockAddress) ||
462 (JTEncoding == MachineJumpTableInfo::EK_GPRel32BlockAddress))
463 return DAG.getGLOBAL_OFFSET_TABLE(getPointerTy(DAG.getDataLayout()));
464
465 return Table;
466}
467
468/// This returns the relocation base for the given PIC jumptable, the same as
469/// getPICJumpTableRelocBase, but as an MCExpr.
470const MCExpr *
471TargetLowering::getPICJumpTableRelocBaseExpr(const MachineFunction *MF,
472 unsigned JTI, MCContext &Ctx) const {
473 // The normal PIC reloc base is the label at the start of the jump table.
474 return MCSymbolRefExpr::create(MF->getJTISymbol(JTI, Ctx), Ctx);
475}
476
477SDValue TargetLowering::expandIndirectJTBranch(const SDLoc &dl, SDValue Value,
478 SDValue Addr, int JTI,
479 SelectionDAG &DAG) const {
480 SDValue Chain = Value;
481 // Jump table debug info is only needed if CodeView is enabled.
482 if (DAG.getTarget().getTargetTriple().isOSBinFormatCOFF()) {
483 Chain = DAG.getJumpTableDebugInfo(JTI, Chain, dl);
484 }
485 return DAG.getNode(ISD::BRIND, dl, MVT::Other, Chain, Addr);
486}
487
488bool
489TargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const {
490 const TargetMachine &TM = getTargetMachine();
491 const GlobalValue *GV = GA->getGlobal();
492
493 // If the address is not even local to this DSO we will have to load it from
494 // a got and then add the offset.
495 if (!TM.shouldAssumeDSOLocal(GV))
496 return false;
497
498 // If the code is position independent we will have to add a base register.
499 if (isPositionIndependent())
500 return false;
501
502 // Otherwise we can do it.
503 return true;
504}
505
506//===----------------------------------------------------------------------===//
507// Optimization Methods
508//===----------------------------------------------------------------------===//
509
510/// If the specified instruction has a constant integer operand and there are
511/// bits set in that constant that are not demanded, then clear those bits and
512/// return true.
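/// For example, if only the low 8 bits of (and X, 0xFFF) are demanded, the
/// constant can be shrunk to 0xFF without changing the demanded result.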
513bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
514 const APInt &DemandedBits,
515 const APInt &DemandedElts,
516 TargetLoweringOpt &TLO) const {
517 SDLoc DL(Op);
518 unsigned Opcode = Op.getOpcode();
519
520 // Early-out if we've ended up calling an undemanded node, leave this to
521 // constant folding.
522 if (DemandedBits.isZero() || DemandedElts.isZero())
523 return false;
524
525 // Do target-specific constant optimization.
526 if (targetShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
527 return TLO.New.getNode();
528
529 // FIXME: ISD::SELECT, ISD::SELECT_CC
530 switch (Opcode) {
531 default:
532 break;
533 case ISD::XOR:
534 case ISD::AND:
535 case ISD::OR: {
536 auto *Op1C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
537 if (!Op1C || Op1C->isOpaque())
538 return false;
539
540 // If this is a 'not' op, don't touch it because that's a canonical form.
541 const APInt &C = Op1C->getAPIntValue();
542 if (Opcode == ISD::XOR && DemandedBits.isSubsetOf(C))
543 return false;
544
545 if (!C.isSubsetOf(DemandedBits)) {
546 EVT VT = Op.getValueType();
547 SDValue NewC = TLO.DAG.getConstant(DemandedBits & C, DL, VT);
548 SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC,
549 Op->getFlags());
550 return TLO.CombineTo(Op, NewOp);
551 }
552
553 break;
554 }
555 }
556
557 return false;
558}
559
560bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
561 const APInt &DemandedBits,
562 TargetLoweringOpt &TLO) const {
563 EVT VT = Op.getValueType();
564 APInt DemandedElts = VT.isVector()
565 ? APInt::getAllOnes(VT.getVectorNumElements())
566 : APInt(1, 1);
567 return ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO);
568}
569
570/// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free.
571/// This uses isTruncateFree/isZExtFree and ANY_EXTEND for the widening cast,
572/// but it could be generalized for targets with other types of implicit
573/// widening casts.
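/// For example, if only the low 8 bits of a 32-bit 'add' are demanded and the
/// target has free i32<->i8 casts, (i32 add X, Y) can become
/// (any_extend (i8 add (trunc X), (trunc Y))).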
574bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth,
575 const APInt &DemandedBits,
576 TargetLoweringOpt &TLO) const {
577 assert(Op.getNumOperands() == 2 &&
578 "ShrinkDemandedOp only supports binary operators!");
579 assert(Op.getNode()->getNumValues() == 1 &&
580 "ShrinkDemandedOp only supports nodes with one result!");
581
582 EVT VT = Op.getValueType();
583 SelectionDAG &DAG = TLO.DAG;
584 SDLoc dl(Op);
585
586 // Early return, as this function cannot handle vector types.
587 if (VT.isVector())
588 return false;
589
590 assert(Op.getOperand(0).getValueType().getScalarSizeInBits() == BitWidth &&
591 Op.getOperand(1).getValueType().getScalarSizeInBits() == BitWidth &&
592 "ShrinkDemandedOp only supports operands that have the same size!");
593
594 // Don't do this if the node has another user, which may require the
595 // full value.
596 if (!Op.getNode()->hasOneUse())
597 return false;
598
599 // Search for the smallest integer type with free casts to and from
600 // Op's type. For expedience, just check power-of-2 integer types.
601 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
602 unsigned DemandedSize = DemandedBits.getActiveBits();
603 for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
604 SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
605 EVT SmallVT = EVT::getIntegerVT(*DAG.getContext(), SmallVTBits);
606 if (TLI.isTruncateFree(VT, SmallVT) && TLI.isZExtFree(SmallVT, VT)) {
607 // We found a type with free casts.
608 SDValue X = DAG.getNode(
609 Op.getOpcode(), dl, SmallVT,
610 DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
611 DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)));
612 assert(DemandedSize <= SmallVTBits && "Narrowed below demanded bits?");
613 SDValue Z = DAG.getNode(ISD::ANY_EXTEND, dl, VT, X);
614 return TLO.CombineTo(Op, Z);
615 }
616 }
617 return false;
618}
619
620bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
621 DAGCombinerInfo &DCI) const {
622 SelectionDAG &DAG = DCI.DAG;
623 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
624 !DCI.isBeforeLegalizeOps());
625 KnownBits Known;
626
627 bool Simplified = SimplifyDemandedBits(Op, DemandedBits, Known, TLO);
628 if (Simplified) {
629 DCI.AddToWorklist(Op.getNode());
630 DCI.CommitTargetLoweringOpt(TLO);
631 }
632 return Simplified;
633}
634
635bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
636 const APInt &DemandedElts,
637 DAGCombinerInfo &DCI) const {
638 SelectionDAG &DAG = DCI.DAG;
639 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
640 !DCI.isBeforeLegalizeOps());
641 KnownBits Known;
642
643 bool Simplified =
644 SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO);
645 if (Simplified) {
646 DCI.AddToWorklist(Op.getNode());
647 DCI.CommitTargetLoweringOpt(TLO);
648 }
649 return Simplified;
650}
651
652bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
653 KnownBits &Known,
654 TargetLoweringOpt &TLO,
655 unsigned Depth,
656 bool AssumeSingleUse) const {
657 EVT VT = Op.getValueType();
658
659 // Since the number of lanes in a scalable vector is unknown at compile time,
660 // we track one bit which is implicitly broadcast to all lanes. This means
661 // that all lanes in a scalable vector are considered demanded.
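 // For example, a <vscale x 4 x i32> operand is handled with a single
 // demanded-elements bit rather than one bit per (unknown) lane.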
662 APInt DemandedElts = VT.isFixedLengthVector()
663 ? APInt::getAllOnes(VT.getVectorNumElements())
664 : APInt(1, 1);
665 return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth,
666 AssumeSingleUse);
667}
668
669// TODO: Under what circumstances can we create nodes? Constant folding?
670SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
671 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
672 SelectionDAG &DAG, unsigned Depth) const {
673 EVT VT = Op.getValueType();
674
675 // Limit search depth.
676 if (Depth >= SelectionDAG::MaxRecursionDepth)
677 return SDValue();
678
679 // Ignore UNDEFs.
680 if (Op.isUndef())
681 return SDValue();
682
683 // Not demanding any bits/elts from Op.
684 if (DemandedBits == 0 || DemandedElts == 0)
685 return DAG.getUNDEF(VT);
686
687 bool IsLE = DAG.getDataLayout().isLittleEndian();
688 unsigned NumElts = DemandedElts.getBitWidth();
689 unsigned BitWidth = DemandedBits.getBitWidth();
690 KnownBits LHSKnown, RHSKnown;
691 switch (Op.getOpcode()) {
692 case ISD::BITCAST: {
693 if (VT.isScalableVector())
694 return SDValue();
695
696 SDValue Src = peekThroughBitcasts(Op.getOperand(0));
697 EVT SrcVT = Src.getValueType();
698 EVT DstVT = Op.getValueType();
699 if (SrcVT == DstVT)
700 return Src;
701
702 unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
703 unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
704 if (NumSrcEltBits == NumDstEltBits)
705 if (SDValue V = SimplifyMultipleUseDemandedBits(
706 Src, DemandedBits, DemandedElts, DAG, Depth + 1))
707 return DAG.getBitcast(DstVT, V);
708
709 if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
710 unsigned Scale = NumDstEltBits / NumSrcEltBits;
711 unsigned NumSrcElts = SrcVT.getVectorNumElements();
712 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
713 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
714 for (unsigned i = 0; i != Scale; ++i) {
715 unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
716 unsigned BitOffset = EltOffset * NumSrcEltBits;
717 APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
718 if (!Sub.isZero()) {
719 DemandedSrcBits |= Sub;
720 for (unsigned j = 0; j != NumElts; ++j)
721 if (DemandedElts[j])
722 DemandedSrcElts.setBit((j * Scale) + i);
723 }
724 }
725
726 if (SDValue V = SimplifyMultipleUseDemandedBits(
727 Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
728 return DAG.getBitcast(DstVT, V);
729 }
730
731 // TODO - bigendian once we have test coverage.
732 if (IsLE && (NumSrcEltBits % NumDstEltBits) == 0) {
733 unsigned Scale = NumSrcEltBits / NumDstEltBits;
734 unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
735 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
736 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
737 for (unsigned i = 0; i != NumElts; ++i)
738 if (DemandedElts[i]) {
739 unsigned Offset = (i % Scale) * NumDstEltBits;
740 DemandedSrcBits.insertBits(DemandedBits, Offset);
741 DemandedSrcElts.setBit(i / Scale);
742 }
743
744 if (SDValue V = SimplifyMultipleUseDemandedBits(
745 Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
746 return DAG.getBitcast(DstVT, V);
747 }
748
749 break;
750 }
751 case ISD::FREEZE: {
752 SDValue N0 = Op.getOperand(0);
753 if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
754 /*PoisonOnly=*/false))
755 return N0;
756 break;
757 }
758 case ISD::AND: {
759 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
760 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
761
762 // If all of the demanded bits are known 1 on one side, return the other.
763 // These bits cannot contribute to the result of the 'and' in this
764 // context.
765 if (DemandedBits.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
766 return Op.getOperand(0);
767 if (DemandedBits.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
768 return Op.getOperand(1);
769 break;
770 }
771 case ISD::OR: {
772 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
773 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
774
775 // If all of the demanded bits are known zero on one side, return the
776 // other. These bits cannot contribute to the result of the 'or' in this
777 // context.
778 if (DemandedBits.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
779 return Op.getOperand(0);
780 if (DemandedBits.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
781 return Op.getOperand(1);
782 break;
783 }
784 case ISD::XOR: {
785 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
786 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
787
788 // If all of the demanded bits are known zero on one side, return the
789 // other.
790 if (DemandedBits.isSubsetOf(RHSKnown.Zero))
791 return Op.getOperand(0);
792 if (DemandedBits.isSubsetOf(LHSKnown.Zero))
793 return Op.getOperand(1);
794 break;
795 }
796 case ISD::SHL: {
797 // If we are only demanding sign bits then we can use the shift source
798 // directly.
799 if (std::optional<uint64_t> MaxSA =
800 DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
801 SDValue Op0 = Op.getOperand(0);
802 unsigned ShAmt = *MaxSA;
803 unsigned NumSignBits =
804 DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
805 unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
806 if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
807 return Op0;
808 }
809 break;
810 }
811 case ISD::SETCC: {
812 SDValue Op0 = Op.getOperand(0);
813 SDValue Op1 = Op.getOperand(1);
814 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
815 // If (1) we only need the sign-bit, (2) the setcc operands are the same
816 // width as the setcc result, and (3) the result of a setcc conforms to 0 or
817 // -1, we may be able to bypass the setcc.
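 // For example, (setcc X, 0, setlt) with only the sign bit demanded can be
 // replaced by X itself when booleans are represented as 0 / -1.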
818 if (DemandedBits.isSignMask() &&
819 Op0.getScalarValueSizeInBits() == BitWidth &&
820 getBooleanContents(Op0.getValueType()) ==
821 BooleanContent::ZeroOrNegativeOneBooleanContent) {
822 // If we're testing X < 0, then this compare isn't needed - just use X!
823 // FIXME: We're limiting to integer types here, but this should also work
824 // if we don't care about FP signed-zero. The use of SETLT with FP means
825 // that we don't care about NaNs.
826 if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
827 (isNullConstant(Op1) || ISD::isBuildVectorAllZeros(Op1.getNode())))
828 return Op0;
829 }
830 break;
831 }
832 case ISD::SIGN_EXTEND_INREG: {
833 // If none of the extended bits are demanded, eliminate the sextinreg.
834 SDValue Op0 = Op.getOperand(0);
835 EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
836 unsigned ExBits = ExVT.getScalarSizeInBits();
837 if (DemandedBits.getActiveBits() <= ExBits &&
838 shouldRemoveRedundantExtend(Op))
839 return Op0;
840 // If the input is already sign extended, just drop the extension.
841 unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
842 if (NumSignBits >= (BitWidth - ExBits + 1))
843 return Op0;
844 break;
845 }
846 case ISD::ANY_EXTEND_VECTOR_INREG:
847 case ISD::SIGN_EXTEND_VECTOR_INREG:
848 case ISD::ZERO_EXTEND_VECTOR_INREG: {
849 if (VT.isScalableVector())
850 return SDValue();
851
852 // If we only want the lowest element and none of extended bits, then we can
853 // return the bitcasted source vector.
854 SDValue Src = Op.getOperand(0);
855 EVT SrcVT = Src.getValueType();
856 EVT DstVT = Op.getValueType();
857 if (IsLE && DemandedElts == 1 &&
858 DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
859 DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
860 return DAG.getBitcast(DstVT, Src);
861 }
862 break;
863 }
864 case ISD::INSERT_VECTOR_ELT: {
865 if (VT.isScalableVector())
866 return SDValue();
867
868 // If we don't demand the inserted element, return the base vector.
869 SDValue Vec = Op.getOperand(0);
870 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
871 EVT VecVT = Vec.getValueType();
872 if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
873 !DemandedElts[CIdx->getZExtValue()])
874 return Vec;
875 break;
876 }
877 case ISD::INSERT_SUBVECTOR: {
878 if (VT.isScalableVector())
879 return SDValue();
880
881 SDValue Vec = Op.getOperand(0);
882 SDValue Sub = Op.getOperand(1);
883 uint64_t Idx = Op.getConstantOperandVal(2);
884 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
885 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
886 // If we don't demand the inserted subvector, return the base vector.
887 if (DemandedSubElts == 0)
888 return Vec;
889 break;
890 }
891 case ISD::VECTOR_SHUFFLE: {
892 assert(!VT.isScalableVector());
893 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
894
895 // If all the demanded elts are from one operand and are inline,
896 // then we can use the operand directly.
897 bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
898 for (unsigned i = 0; i != NumElts; ++i) {
899 int M = ShuffleMask[i];
900 if (M < 0 || !DemandedElts[i])
901 continue;
902 AllUndef = false;
903 IdentityLHS &= (M == (int)i);
904 IdentityRHS &= ((M - NumElts) == i);
905 }
906
907 if (AllUndef)
908 return DAG.getUNDEF(Op.getValueType());
909 if (IdentityLHS)
910 return Op.getOperand(0);
911 if (IdentityRHS)
912 return Op.getOperand(1);
913 break;
914 }
915 default:
916 // TODO: Probably okay to remove after audit; here to reduce change size
917 // in initial enablement patch for scalable vectors
918 if (VT.isScalableVector())
919 return SDValue();
920
921 if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
922 if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
923 Op, DemandedBits, DemandedElts, DAG, Depth))
924 return V;
925 break;
926 }
927 return SDValue();
928}
929
930SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
931 SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
932 unsigned Depth) const {
933 EVT VT = Op.getValueType();
934 // Since the number of lanes in a scalable vector is unknown at compile time,
935 // we track one bit which is implicitly broadcast to all lanes. This means
936 // that all lanes in a scalable vector are considered demanded.
937 APInt DemandedElts = VT.isFixedLengthVector()
938 ? APInt::getAllOnes(VT.getVectorNumElements())
939 : APInt(1, 1);
940 return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
941 Depth);
942}
943
944SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
945 SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
946 unsigned Depth) const {
947 APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
948 return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
949 Depth);
950}
951
952// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
953// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
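// For example, (srl (add (zext i8 A to i16), (zext i8 B to i16)), 1) can
// become (zext (avgflooru A, B) to i16), since the zero-extended operands
// guarantee the 16-bit add never overflows.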
954static SDValue combineShiftToAVG(SDValue Op,
955 TargetLowering::TargetLoweringOpt &TLO,
956 const TargetLowering &TLI,
957 const APInt &DemandedBits,
958 const APInt &DemandedElts, unsigned Depth) {
959 assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
960 "SRL or SRA node is required here!");
961 // Is the right shift using an immediate value of 1?
962 ConstantSDNode *N1C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
963 if (!N1C || !N1C->isOne())
964 return SDValue();
965
966 // We are looking for an avgfloor
967 // add(ext, ext)
968 // or one of these as a avgceil
969 // add(add(ext, ext), 1)
970 // add(add(ext, 1), ext)
971 // add(ext, add(ext, 1))
972 SDValue Add = Op.getOperand(0);
973 if (Add.getOpcode() != ISD::ADD)
974 return SDValue();
975
976 SDValue ExtOpA = Add.getOperand(0);
977 SDValue ExtOpB = Add.getOperand(1);
978 SDValue Add2;
979 auto MatchOperands = [&](SDValue Op1, SDValue Op2, SDValue Op3, SDValue A) {
980 ConstantSDNode *ConstOp;
981 if ((ConstOp = isConstOrConstSplat(Op2, DemandedElts)) &&
982 ConstOp->isOne()) {
983 ExtOpA = Op1;
984 ExtOpB = Op3;
985 Add2 = A;
986 return true;
987 }
988 if ((ConstOp = isConstOrConstSplat(Op3, DemandedElts)) &&
989 ConstOp->isOne()) {
990 ExtOpA = Op1;
991 ExtOpB = Op2;
992 Add2 = A;
993 return true;
994 }
995 return false;
996 };
997 bool IsCeil =
998 (ExtOpA.getOpcode() == ISD::ADD &&
999 MatchOperands(ExtOpA.getOperand(0), ExtOpA.getOperand(1), ExtOpB, ExtOpA)) ||
1000 (ExtOpB.getOpcode() == ISD::ADD &&
1001 MatchOperands(ExtOpB.getOperand(0), ExtOpB.getOperand(1), ExtOpA, ExtOpB));
1002
1003 // If the shift is signed (sra):
1004 // - Needs >= 2 sign bit for both operands.
1005 // - Needs >= 2 zero bits.
1006 // If the shift is unsigned (srl):
1007 // - Needs >= 1 zero bit for both operands.
1008 // - Needs 1 demanded bit zero and >= 2 sign bits.
1009 SelectionDAG &DAG = TLO.DAG;
1010 unsigned ShiftOpc = Op.getOpcode();
1011 bool IsSigned = false;
1012 unsigned KnownBits;
1013 unsigned NumSignedA = DAG.ComputeNumSignBits(ExtOpA, DemandedElts, Depth);
1014 unsigned NumSignedB = DAG.ComputeNumSignBits(ExtOpB, DemandedElts, Depth);
1015 unsigned NumSigned = std::min(NumSignedA, NumSignedB) - 1;
1016 unsigned NumZeroA =
1017 DAG.computeKnownBits(ExtOpA, DemandedElts, Depth).countMinLeadingZeros();
1018 unsigned NumZeroB =
1019 DAG.computeKnownBits(ExtOpB, DemandedElts, Depth).countMinLeadingZeros();
1020 unsigned NumZero = std::min(NumZeroA, NumZeroB);
1021
1022 switch (ShiftOpc) {
1023 default:
1024 llvm_unreachable("Unexpected ShiftOpc in combineShiftToAVG");
1025 case ISD::SRA: {
1026 if (NumZero >= 2 && NumSigned < NumZero) {
1027 IsSigned = false;
1028 KnownBits = NumZero;
1029 break;
1030 }
1031 if (NumSigned >= 1) {
1032 IsSigned = true;
1033 KnownBits = NumSigned;
1034 break;
1035 }
1036 return SDValue();
1037 }
1038 case ISD::SRL: {
1039 if (NumZero >= 1 && NumSigned < NumZero) {
1040 IsSigned = false;
1041 KnownBits = NumZero;
1042 break;
1043 }
1044 if (NumSigned >= 1 && DemandedBits.isSignBitClear()) {
1045 IsSigned = true;
1046 KnownBits = NumSigned;
1047 break;
1048 }
1049 return SDValue();
1050 }
1051 }
1052
1053 unsigned AVGOpc = IsCeil ? (IsSigned ? ISD::AVGCEILS : ISD::AVGCEILU)
1054 : (IsSigned ? ISD::AVGFLOORS : ISD::AVGFLOORU);
1055
1056 // Find the smallest power-2 type that is legal for this vector size and
1057 // operation, given the original type size and the number of known sign/zero
1058 // bits.
1059 EVT VT = Op.getValueType();
1060 unsigned MinWidth =
1061 std::max<unsigned>(VT.getScalarSizeInBits() - KnownBits, 8);
1062 EVT NVT = EVT::getIntegerVT(*DAG.getContext(), llvm::bit_ceil(MinWidth));
1063 if (NVT.getScalarSizeInBits() > VT.getScalarSizeInBits())
1064 return SDValue();
1065 if (VT.isVector())
1066 NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
1067 if (TLO.LegalTypes() && !TLI.isOperationLegal(AVGOpc, NVT)) {
1068 // If we could not transform, and (both) adds are nuw/nsw, we can use the
1069 // larger type size to do the transform.
1070 if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, VT))
1071 return SDValue();
1072 if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0),
1073 Add.getOperand(1)) &&
1074 (!Add2 || DAG.willNotOverflowAdd(IsSigned, Add2.getOperand(0),
1075 Add2.getOperand(1))))
1076 NVT = VT;
1077 else
1078 return SDValue();
1079 }
1080
1081 // Don't create a AVGFLOOR node with a scalar constant unless its legal as
1082 // this is likely to stop other folds (reassociation, value tracking etc.)
1083 if (!IsCeil && !TLI.isOperationLegal(AVGOpc, NVT) &&
1084 (isa<ConstantSDNode>(ExtOpA) || isa<ConstantSDNode>(ExtOpB)))
1085 return SDValue();
1086
1087 SDLoc DL(Op);
1088 SDValue ResultAVG =
1089 DAG.getNode(AVGOpc, DL, NVT, DAG.getExtOrTrunc(IsSigned, ExtOpA, DL, NVT),
1090 DAG.getExtOrTrunc(IsSigned, ExtOpB, DL, NVT));
1091 return DAG.getExtOrTrunc(IsSigned, ResultAVG, DL, VT);
1092}
1093
1094/// Look at Op. At this point, we know that only the OriginalDemandedBits of the
1095/// result of Op are ever used downstream. If we can use this information to
1096/// simplify Op, create a new simplified DAG node and return true, returning the
1097/// original and new nodes in Old and New. Otherwise, analyze the expression and
1098/// return a mask of Known bits for the expression (used to simplify the
1099/// caller). The Known bits may only be accurate for those bits in the
1100/// OriginalDemandedBits and OriginalDemandedElts.
1101bool TargetLowering::SimplifyDemandedBits(
1102 SDValue Op, const APInt &OriginalDemandedBits,
1103 const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
1104 unsigned Depth, bool AssumeSingleUse) const {
1105 unsigned BitWidth = OriginalDemandedBits.getBitWidth();
1106 assert(Op.getScalarValueSizeInBits() == BitWidth &&
1107 "Mask size mismatches value type size!");
1108
1109 // Don't know anything.
1110 Known = KnownBits(BitWidth);
1111
1112 EVT VT = Op.getValueType();
1113 bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
1114 unsigned NumElts = OriginalDemandedElts.getBitWidth();
1115 assert((!VT.isFixedLengthVector() || NumElts == VT.getVectorNumElements()) &&
1116 "Unexpected vector size");
1117
1118 APInt DemandedBits = OriginalDemandedBits;
1119 APInt DemandedElts = OriginalDemandedElts;
1120 SDLoc dl(Op);
1121
1122 // Undef operand.
1123 if (Op.isUndef())
1124 return false;
1125
1126 // We can't simplify target constants.
1127 if (Op.getOpcode() == ISD::TargetConstant)
1128 return false;
1129
1130 if (Op.getOpcode() == ISD::Constant) {
1131 // We know all of the bits for a constant!
1132 Known = KnownBits::makeConstant(Op->getAsAPIntVal());
1133 return false;
1134 }
1135
1136 if (Op.getOpcode() == ISD::ConstantFP) {
1137 // We know all of the bits for a floating point constant!
1138 Known = KnownBits::makeConstant(
1139 cast<ConstantFPSDNode>(Op)->getValueAPF().bitcastToAPInt());
1140 return false;
1141 }
1142
1143 // Other users may use these bits.
1144 bool HasMultiUse = false;
1145 if (!AssumeSingleUse && !Op.getNode()->hasOneUse()) {
1146 if (Depth >= SelectionDAG::MaxRecursionDepth) {
1147 // Limit search depth.
1148 return false;
1149 }
1150 // Allow multiple uses, just set the DemandedBits/Elts to all bits.
1151 DemandedBits = APInt::getAllOnes(BitWidth);
1152 DemandedElts = APInt::getAllOnes(NumElts);
1153 HasMultiUse = true;
1154 } else if (OriginalDemandedBits == 0 || OriginalDemandedElts == 0) {
1155 // Not demanding any bits/elts from Op.
1156 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1157 } else if (Depth >= SelectionDAG::MaxRecursionDepth) {
1158 // Limit search depth.
1159 return false;
1160 }
1161
1162 KnownBits Known2;
1163 switch (Op.getOpcode()) {
1164 case ISD::SCALAR_TO_VECTOR: {
1165 if (VT.isScalableVector())
1166 return false;
1167 if (!DemandedElts[0])
1168 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1169
1170 KnownBits SrcKnown;
1171 SDValue Src = Op.getOperand(0);
1172 unsigned SrcBitWidth = Src.getScalarValueSizeInBits();
1173 APInt SrcDemandedBits = DemandedBits.zext(SrcBitWidth);
1174 if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcKnown, TLO, Depth + 1))
1175 return true;
1176
1177 // Upper elements are undef, so only get the knownbits if we just demand
1178 // the bottom element.
1179 if (DemandedElts == 1)
1180 Known = SrcKnown.anyextOrTrunc(BitWidth);
1181 break;
1182 }
1183 case ISD::BUILD_VECTOR:
1184 // Collect the known bits that are shared by every demanded element.
1185 // TODO: Call SimplifyDemandedBits for non-constant demanded elements.
1186 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1187 return false; // Don't fall through, will infinitely loop.
1188 case ISD::SPLAT_VECTOR: {
1189 SDValue Scl = Op.getOperand(0);
1190 APInt DemandedSclBits = DemandedBits.zextOrTrunc(Scl.getValueSizeInBits());
1191 KnownBits KnownScl;
1192 if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
1193 return true;
1194
1195 // Implicitly truncate the bits to match the official semantics of
1196 // SPLAT_VECTOR.
1197 Known = KnownScl.trunc(BitWidth);
1198 break;
1199 }
1200 case ISD::LOAD: {
1201 auto *LD = cast<LoadSDNode>(Op);
1202 if (getTargetConstantFromLoad(LD)) {
1203 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1204 return false; // Don't fall through, will infinitely loop.
1205 }
1206 if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) {
1207 // If this is a ZEXTLoad and we are looking at the loaded value.
1208 EVT MemVT = LD->getMemoryVT();
1209 unsigned MemBits = MemVT.getScalarSizeInBits();
1210 Known.Zero.setBitsFrom(MemBits);
1211 return false; // Don't fall through, will infinitely loop.
1212 }
1213 break;
1214 }
1215 case ISD::INSERT_VECTOR_ELT: {
1216 if (VT.isScalableVector())
1217 return false;
1218 SDValue Vec = Op.getOperand(0);
1219 SDValue Scl = Op.getOperand(1);
1220 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
1221 EVT VecVT = Vec.getValueType();
1222
1223 // If index isn't constant, assume we need all vector elements AND the
1224 // inserted element.
1225 APInt DemandedVecElts(DemandedElts);
1226 if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements())) {
1227 unsigned Idx = CIdx->getZExtValue();
1228 DemandedVecElts.clearBit(Idx);
1229
1230 // Inserted element is not required.
1231 if (!DemandedElts[Idx])
1232 return TLO.CombineTo(Op, Vec);
1233 }
1234
1235 KnownBits KnownScl;
1236 unsigned NumSclBits = Scl.getScalarValueSizeInBits();
1237 APInt DemandedSclBits = DemandedBits.zextOrTrunc(NumSclBits);
1238 if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
1239 return true;
1240
1241 Known = KnownScl.anyextOrTrunc(BitWidth);
1242
1243 KnownBits KnownVec;
1244 if (SimplifyDemandedBits(Vec, DemandedBits, DemandedVecElts, KnownVec, TLO,
1245 Depth + 1))
1246 return true;
1247
1248 if (!!DemandedVecElts)
1249 Known = Known.intersectWith(KnownVec);
1250
1251 return false;
1252 }
1253 case ISD::INSERT_SUBVECTOR: {
1254 if (VT.isScalableVector())
1255 return false;
1256 // Demand any elements from the subvector and the remainder from the src its
1257 // inserted into.
1258 SDValue Src = Op.getOperand(0);
1259 SDValue Sub = Op.getOperand(1);
1260 uint64_t Idx = Op.getConstantOperandVal(2);
1261 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
1262 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
1263 APInt DemandedSrcElts = DemandedElts;
1264 DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
1265
1266 KnownBits KnownSub, KnownSrc;
1267 if (SimplifyDemandedBits(Sub, DemandedBits, DemandedSubElts, KnownSub, TLO,
1268 Depth + 1))
1269 return true;
1270 if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, KnownSrc, TLO,
1271 Depth + 1))
1272 return true;
1273
1274 Known.Zero.setAllBits();
1275 Known.One.setAllBits();
1276 if (!!DemandedSubElts)
1277 Known = Known.intersectWith(KnownSub);
1278 if (!!DemandedSrcElts)
1279 Known = Known.intersectWith(KnownSrc);
1280
1281 // Attempt to avoid multi-use src if we don't need anything from it.
1282 if (!DemandedBits.isAllOnes() || !DemandedSubElts.isAllOnes() ||
1283 !DemandedSrcElts.isAllOnes()) {
1284 SDValue NewSub = SimplifyMultipleUseDemandedBits(
1285 Sub, DemandedBits, DemandedSubElts, TLO.DAG, Depth + 1);
1286 SDValue NewSrc = SimplifyMultipleUseDemandedBits(
1287 Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1288 if (NewSub || NewSrc) {
1289 NewSub = NewSub ? NewSub : Sub;
1290 NewSrc = NewSrc ? NewSrc : Src;
1291 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc, NewSub,
1292 Op.getOperand(2));
1293 return TLO.CombineTo(Op, NewOp);
1294 }
1295 }
1296 break;
1297 }
1298 case ISD::EXTRACT_SUBVECTOR: {
1299 if (VT.isScalableVector())
1300 return false;
1301 // Offset the demanded elts by the subvector index.
1302 SDValue Src = Op.getOperand(0);
1303 if (Src.getValueType().isScalableVector())
1304 break;
1305 uint64_t Idx = Op.getConstantOperandVal(1);
1306 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
1307 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
1308
1309 if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, Known, TLO,
1310 Depth + 1))
1311 return true;
1312
1313 // Attempt to avoid multi-use src if we don't need anything from it.
1314 if (!DemandedBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
1315 SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
1316 Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1317 if (DemandedSrc) {
1318 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc,
1319 Op.getOperand(1));
1320 return TLO.CombineTo(Op, NewOp);
1321 }
1322 }
1323 break;
1324 }
1325 case ISD::CONCAT_VECTORS: {
1326 if (VT.isScalableVector())
1327 return false;
1328 Known.Zero.setAllBits();
1329 Known.One.setAllBits();
1330 EVT SubVT = Op.getOperand(0).getValueType();
1331 unsigned NumSubVecs = Op.getNumOperands();
1332 unsigned NumSubElts = SubVT.getVectorNumElements();
1333 for (unsigned i = 0; i != NumSubVecs; ++i) {
1334 APInt DemandedSubElts =
1335 DemandedElts.extractBits(NumSubElts, i * NumSubElts);
1336 if (SimplifyDemandedBits(Op.getOperand(i), DemandedBits, DemandedSubElts,
1337 Known2, TLO, Depth + 1))
1338 return true;
1339 // Known bits are shared by every demanded subvector element.
1340 if (!!DemandedSubElts)
1341 Known = Known.intersectWith(Known2);
1342 }
1343 break;
1344 }
1345 case ISD::VECTOR_SHUFFLE: {
1346 assert(!VT.isScalableVector());
1347 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
1348
1349 // Collect demanded elements from shuffle operands..
1350 APInt DemandedLHS, DemandedRHS;
1351 if (!getShuffleDemandedElts(NumElts, ShuffleMask, DemandedElts, DemandedLHS,
1352 DemandedRHS))
1353 break;
1354
1355 if (!!DemandedLHS || !!DemandedRHS) {
1356 SDValue Op0 = Op.getOperand(0);
1357 SDValue Op1 = Op.getOperand(1);
1358
1359 Known.Zero.setAllBits();
1360 Known.One.setAllBits();
1361 if (!!DemandedLHS) {
1362 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedLHS, Known2, TLO,
1363 Depth + 1))
1364 return true;
1365 Known = Known.intersectWith(Known2);
1366 }
1367 if (!!DemandedRHS) {
1368 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedRHS, Known2, TLO,
1369 Depth + 1))
1370 return true;
1371 Known = Known.intersectWith(Known2);
1372 }
1373
1374 // Attempt to avoid multi-use ops if we don't need anything from them.
1375 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1376 Op0, DemandedBits, DemandedLHS, TLO.DAG, Depth + 1);
1377 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1378 Op1, DemandedBits, DemandedRHS, TLO.DAG, Depth + 1);
1379 if (DemandedOp0 || DemandedOp1) {
1380 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1381 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1382 SDValue NewOp = TLO.DAG.getVectorShuffle(VT, dl, Op0, Op1, ShuffleMask);
1383 return TLO.CombineTo(Op, NewOp);
1384 }
1385 }
1386 break;
1387 }
1388 case ISD::AND: {
1389 SDValue Op0 = Op.getOperand(0);
1390 SDValue Op1 = Op.getOperand(1);
1391
1392 // If the RHS is a constant, check to see if the LHS would be zero without
1393 // using the bits from the RHS. Below, we use knowledge about the RHS to
1394 // simplify the LHS, here we're using information from the LHS to simplify
1395 // the RHS.
1396 if (ConstantSDNode *RHSC = isConstOrConstSplat(Op1, DemandedElts)) {
1397 // Do not increment Depth here; that can cause an infinite loop.
1398 KnownBits LHSKnown = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth);
1399 // If the LHS already has zeros where RHSC does, this 'and' is dead.
1400 if ((LHSKnown.Zero & DemandedBits) ==
1401 (~RHSC->getAPIntValue() & DemandedBits))
1402 return TLO.CombineTo(Op, Op0);
1403
1404 // If any of the set bits in the RHS are known zero on the LHS, shrink
1405 // the constant.
1406 if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits,
1407 DemandedElts, TLO))
1408 return true;
1409
1410 // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its
1411 // constant, but if this 'and' is only clearing bits that were just set by
1412 // the xor, then this 'and' can be eliminated by shrinking the mask of
1413 // the xor. For example, for a 32-bit X:
1414 // and (xor (srl X, 31), -1), 1 --> xor (srl X, 31), 1
1415 if (isBitwiseNot(Op0) && Op0.hasOneUse() &&
1416 LHSKnown.One == ~RHSC->getAPIntValue()) {
1417 SDValue Xor = TLO.DAG.getNode(ISD::XOR, dl, VT, Op0.getOperand(0), Op1);
1418 return TLO.CombineTo(Op, Xor);
1419 }
1420 }
1421
1422 // AND(INSERT_SUBVECTOR(C,X,I),M) -> INSERT_SUBVECTOR(AND(C,M),X,I)
1423 // iff 'C' is Undef/Constant and AND(X,M) == X (for DemandedBits).
1424 if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR && !VT.isScalableVector() &&
1425 (Op0.getOperand(0).isUndef() ||
1426 ISD::isBuildVectorOfConstantSDNodes(Op0.getOperand(0).getNode())) &&
1427 Op0->hasOneUse()) {
1428 unsigned NumSubElts =
1429 Op0.getOperand(1).getValueType().getVectorNumElements();
1430 unsigned SubIdx = Op0.getConstantOperandVal(2);
1431 APInt DemandedSub =
1432 APInt::getBitsSet(NumElts, SubIdx, SubIdx + NumSubElts);
1433 KnownBits KnownSubMask =
1434 TLO.DAG.computeKnownBits(Op1, DemandedSub & DemandedElts, Depth + 1);
1435 if (DemandedBits.isSubsetOf(KnownSubMask.One)) {
1436 SDValue NewAnd =
1437 TLO.DAG.getNode(ISD::AND, dl, VT, Op0.getOperand(0), Op1);
1438 SDValue NewInsert =
1439 TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VT, NewAnd,
1440 Op0.getOperand(1), Op0.getOperand(2));
1441 return TLO.CombineTo(Op, NewInsert);
1442 }
1443 }
1444
1445 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1446 Depth + 1))
1447 return true;
1448 if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts,
1449 Known2, TLO, Depth + 1))
1450 return true;
1451
1452 // If all of the demanded bits are known one on one side, return the other.
1453 // These bits cannot contribute to the result of the 'and'.
1454 if (DemandedBits.isSubsetOf(Known2.Zero | Known.One))
1455 return TLO.CombineTo(Op, Op0);
1456 if (DemandedBits.isSubsetOf(Known.Zero | Known2.One))
1457 return TLO.CombineTo(Op, Op1);
1458 // If all of the demanded bits in the inputs are known zeros, return zero.
1459 if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
1460 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT));
1461 // If the RHS is a constant, see if we can simplify it.
1462 if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, DemandedElts,
1463 TLO))
1464 return true;
1465 // If the operation can be done in a smaller type, do so.
1466 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1467 return true;
1468
1469 // Attempt to avoid multi-use ops if we don't need anything from them.
1470 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1471 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1472 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1473 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1474 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1475 if (DemandedOp0 || DemandedOp1) {
1476 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1477 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1478 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1479 return TLO.CombineTo(Op, NewOp);
1480 }
1481 }
1482
1483 Known &= Known2;
1484 break;
1485 }
1486 case ISD::OR: {
1487 SDValue Op0 = Op.getOperand(0);
1488 SDValue Op1 = Op.getOperand(1);
1489 SDNodeFlags Flags = Op.getNode()->getFlags();
1490 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1491 Depth + 1)) {
1492 if (Flags.hasDisjoint()) {
1493 Flags.setDisjoint(false);
1494 Op->setFlags(Flags);
1495 }
1496 return true;
1497 }
1498
1499 if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts,
1500 Known2, TLO, Depth + 1)) {
1501 if (Flags.hasDisjoint()) {
1502 Flags.setDisjoint(false);
1503 Op->setFlags(Flags);
1504 }
1505 return true;
1506 }
1507
1508 // If all of the demanded bits are known zero on one side, return the other.
1509 // These bits cannot contribute to the result of the 'or'.
1510 if (DemandedBits.isSubsetOf(Known2.One | Known.Zero))
1511 return TLO.CombineTo(Op, Op0);
1512 if (DemandedBits.isSubsetOf(Known.One | Known2.Zero))
1513 return TLO.CombineTo(Op, Op1);
1514 // If the RHS is a constant, see if we can simplify it.
1515 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1516 return true;
1517 // If the operation can be done in a smaller type, do so.
1518 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1519 return true;
1520
1521 // Attempt to avoid multi-use ops if we don't need anything from them.
1522 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1523 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1524 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1525 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1526 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1527 if (DemandedOp0 || DemandedOp1) {
1528 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1529 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1530 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1531 return TLO.CombineTo(Op, NewOp);
1532 }
1533 }
1534
1535 // (or (and X, C1), (and (or X, Y), C2)) -> (or (and X, C1|C2), (and Y, C2))
1536 // TODO: Use SimplifyMultipleUseDemandedBits to peek through masks.
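 // For example, with C1 = 0x0F and C2 = 0xF0:
 // (or (and X, 0x0F), (and (or X, Y), 0xF0)) -> (or (and X, 0xFF), (and Y, 0xF0)).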
1537 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::AND &&
1538 Op0->hasOneUse() && Op1->hasOneUse()) {
1539 // Attempt to match all commutations - m_c_Or would've been useful!
1540 for (int I = 0; I != 2; ++I) {
1541 SDValue X = Op.getOperand(I).getOperand(0);
1542 SDValue C1 = Op.getOperand(I).getOperand(1);
1543 SDValue Alt = Op.getOperand(1 - I).getOperand(0);
1544 SDValue C2 = Op.getOperand(1 - I).getOperand(1);
1545 if (Alt.getOpcode() == ISD::OR) {
1546 for (int J = 0; J != 2; ++J) {
1547 if (X == Alt.getOperand(J)) {
1548 SDValue Y = Alt.getOperand(1 - J);
1549 if (SDValue C12 = TLO.DAG.FoldConstantArithmetic(ISD::OR, dl, VT,
1550 {C1, C2})) {
1551 SDValue MaskX = TLO.DAG.getNode(ISD::AND, dl, VT, X, C12);
1552 SDValue MaskY = TLO.DAG.getNode(ISD::AND, dl, VT, Y, C2);
1553 return TLO.CombineTo(
1554 Op, TLO.DAG.getNode(ISD::OR, dl, VT, MaskX, MaskY));
1555 }
1556 }
1557 }
1558 }
1559 }
1560 }
1561
1562 Known |= Known2;
1563 break;
1564 }
1565 case ISD::XOR: {
1566 SDValue Op0 = Op.getOperand(0);
1567 SDValue Op1 = Op.getOperand(1);
1568
1569 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1570 Depth + 1))
1571 return true;
1572 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO,
1573 Depth + 1))
1574 return true;
1575
1576 // If all of the demanded bits are known zero on one side, return the other.
1577 // These bits cannot contribute to the result of the 'xor'.
1578 if (DemandedBits.isSubsetOf(Known.Zero))
1579 return TLO.CombineTo(Op, Op0);
1580 if (DemandedBits.isSubsetOf(Known2.Zero))
1581 return TLO.CombineTo(Op, Op1);
1582 // If the operation can be done in a smaller type, do so.
1583 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1584 return true;
1585
1586 // If all of the unknown bits are known to be zero on one side or the other
1587 // turn this into an *inclusive* or.
1588 // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
1589 if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
1590 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
1591
1592 ConstantSDNode *C = isConstOrConstSplat(Op1, DemandedElts);
1593 if (C) {
1594 // If one side is a constant, and all of the set bits in the constant are
1595 // also known set on the other side, turn this into an AND, as we know
1596 // the bits will be cleared.
1597 // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
1598 // NB: it is okay if more bits are known than are requested
1599 if (C->getAPIntValue() == Known2.One) {
1600 SDValue ANDC =
1601 TLO.DAG.getConstant(~C->getAPIntValue() & DemandedBits, dl, VT);
1602 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, ANDC));
1603 }
1604
1605 // If the RHS is a constant, see if we can change it. Don't alter a -1
1606 // constant because that's a 'not' op, and that is better for combining
1607 // and codegen.
1608 if (!C->isAllOnes() && DemandedBits.isSubsetOf(C->getAPIntValue())) {
1609 // We're flipping all demanded bits. Flip the undemanded bits too.
1610 SDValue New = TLO.DAG.getNOT(dl, Op0, VT);
1611 return TLO.CombineTo(Op, New);
1612 }
1613
1614 unsigned Op0Opcode = Op0.getOpcode();
1615 if ((Op0Opcode == ISD::SRL || Op0Opcode == ISD::SHL) && Op0.hasOneUse()) {
1616 if (ConstantSDNode *ShiftC =
1617 isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
1618 // Don't crash on an oversized shift. We can not guarantee that a
1619 // bogus shift has been simplified to undef.
1620 if (ShiftC->getAPIntValue().ult(BitWidth)) {
1621 uint64_t ShiftAmt = ShiftC->getZExtValue();
1622 APInt Ones = APInt::getAllOnes(BitWidth);
1623 Ones = Op0Opcode == ISD::SHL ? Ones.shl(ShiftAmt)
1624 : Ones.lshr(ShiftAmt);
1625 const TargetLowering &TLI = TLO.DAG.getTargetLoweringInfo();
1626 if ((DemandedBits & C->getAPIntValue()) == (DemandedBits & Ones) &&
1627 TLI.isDesirableToCommuteXorWithShift(Op.getNode())) {
1628 // If the xor constant is a demanded mask, do a 'not' before the
1629 // shift:
1630 // xor (X << ShiftC), XorC --> (not X) << ShiftC
1631 // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
1632 SDValue Not = TLO.DAG.getNOT(dl, Op0.getOperand(0), VT);
1633 return TLO.CombineTo(Op, TLO.DAG.getNode(Op0Opcode, dl, VT, Not,
1634 Op0.getOperand(1)));
1635 }
1636 }
1637 }
1638 }
1639 }
1640
1641 // If we can't turn this into a 'not', try to shrink the constant.
1642 if (!C || !C->isAllOnes())
1643 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1644 return true;
1645
1646 // Attempt to avoid multi-use ops if we don't need anything from them.
1647 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1648 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1649 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1650 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1651 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1652 if (DemandedOp0 || DemandedOp1) {
1653 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1654 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1655 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1656 return TLO.CombineTo(Op, NewOp);
1657 }
1658 }
1659
1660 Known ^= Known2;
1661 break;
1662 }
1663 case ISD::SELECT:
1664 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1665 Known, TLO, Depth + 1))
1666 return true;
1667 if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedElts,
1668 Known2, TLO, Depth + 1))
1669 return true;
1670
1671 // If the operands are constants, see if we can simplify them.
1672 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1673 return true;
1674
1675 // Only known if known in both the LHS and RHS.
1676 Known = Known.intersectWith(Known2);
1677 break;
1678 case ISD::VSELECT:
1679 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1680 Known, TLO, Depth + 1))
1681 return true;
1682 if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedElts,
1683 Known2, TLO, Depth + 1))
1684 return true;
1685
1686 // Only known if known in both the LHS and RHS.
1687 Known = Known.intersectWith(Known2);
1688 break;
1689 case ISD::SELECT_CC:
1690 if (SimplifyDemandedBits(Op.getOperand(3), DemandedBits, DemandedElts,
1691 Known, TLO, Depth + 1))
1692 return true;
1693 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1694 Known2, TLO, Depth + 1))
1695 return true;
1696
1697 // If the operands are constants, see if we can simplify them.
1698 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1699 return true;
1700
1701 // Only known if known in both the LHS and RHS.
1702 Known = Known.intersectWith(Known2);
1703 break;
1704 case ISD::SETCC: {
1705 SDValue Op0 = Op.getOperand(0);
1706 SDValue Op1 = Op.getOperand(1);
1707 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
1708 // If (1) we only need the sign-bit, (2) the setcc operands are the same
1709 // width as the setcc result, and (3) the result of a setcc conforms to 0 or
1710 // -1, we may be able to bypass the setcc.
1711 if (DemandedBits.isSignMask() &&
1712 Op0.getScalarValueSizeInBits() == BitWidth &&
1713 getBooleanContents(Op0.getValueType()) ==
1714 BooleanContent::ZeroOrNegativeOneBooleanContent) {
1715 // If we're testing X < 0, then this compare isn't needed - just use X!
1716 // FIXME: We're limiting to integer types here, but this should also work
1717 // if we don't care about FP signed-zero. The use of SETLT with FP means
1718 // that we don't care about NaNs.
1719 if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
1720 (isNullConstant(Op1) || ISD::isBuildVectorAllZeros(Op1.getNode())))
1721 return TLO.CombineTo(Op, Op0);
1722
1723 // TODO: Should we check for other forms of sign-bit comparisons?
1724 // Examples: X <= -1, X >= 0
1725 }
1726 if (getBooleanContents(Op0.getValueType()) ==
1727 TargetLowering::ZeroOrNegativeOneBooleanContent &&
1728 BitWidth > 1)
1729 Known.Zero.setBitsFrom(1);
1730 break;
1731 }
1732 case ISD::SHL: {
1733 SDValue Op0 = Op.getOperand(0);
1734 SDValue Op1 = Op.getOperand(1);
1735 EVT ShiftVT = Op1.getValueType();
1736
1737 if (std::optional<uint64_t> KnownSA =
1738 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
1739 unsigned ShAmt = *KnownSA;
1740 if (ShAmt == 0)
1741 return TLO.CombineTo(Op, Op0);
1742
1743 // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
1744 // single shift. We can do this if the bottom bits (which are shifted
1745 // out) are never demanded.
1746 // TODO - support non-uniform vector amounts.
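// Illustrative example: for i32 ((x >>u 4) << 6) where the low 6 result bits
// are not demanded, Diff = 6 - 4 = 2 and the pair folds to (x << 2); with a
// larger inner shift, e.g. ((x >>u 8) << 6), it folds to (x >>u 2) instead.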
1747 if (Op0.getOpcode() == ISD::SRL) {
1748 if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
1749 if (std::optional<uint64_t> InnerSA =
1750 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
1751 unsigned C1 = *InnerSA;
1752 unsigned Opc = ISD::SHL;
1753 int Diff = ShAmt - C1;
1754 if (Diff < 0) {
1755 Diff = -Diff;
1756 Opc = ISD::SRL;
1757 }
1758 SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
1759 return TLO.CombineTo(
1760 Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
1761 }
1762 }
1763 }
1764
1765 // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits
1766 // are not demanded. This will likely allow the anyext to be folded away.
1767 // TODO - support non-uniform vector amounts.
1768 if (Op0.getOpcode() == ISD::ANY_EXTEND) {
1769 SDValue InnerOp = Op0.getOperand(0);
1770 EVT InnerVT = InnerOp.getValueType();
1771 unsigned InnerBits = InnerVT.getScalarSizeInBits();
1772 if (ShAmt < InnerBits && DemandedBits.getActiveBits() <= InnerBits &&
1773 isTypeDesirableForOp(ISD::SHL, InnerVT)) {
1774 SDValue NarrowShl = TLO.DAG.getNode(
1775 ISD::SHL, dl, InnerVT, InnerOp,
1776 TLO.DAG.getShiftAmountConstant(ShAmt, InnerVT, dl));
1777 return TLO.CombineTo(
1778 Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
1779 }
1780
1781 // Repeat the SHL optimization above in cases where an extension
1782 // intervenes: (shl (anyext (shr x, c1)), c2) to
1783 // (shl (anyext x), c2-c1). This requires that the bottom c1 bits
1784 // aren't demanded (as above) and that the shifted upper c1 bits of
1785 // x aren't demanded.
1786 // TODO - support non-uniform vector amounts.
1787 if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
1788 InnerOp.hasOneUse()) {
1789 if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
1790 InnerOp, DemandedElts, Depth + 2)) {
1791 unsigned InnerShAmt = *SA2;
1792 if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
1793 DemandedBits.getActiveBits() <=
1794 (InnerBits - InnerShAmt + ShAmt) &&
1795 DemandedBits.countr_zero() >= ShAmt) {
1796 SDValue NewSA =
1797 TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, ShiftVT);
1798 SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT,
1799 InnerOp.getOperand(0));
1800 return TLO.CombineTo(
1801 Op, TLO.DAG.getNode(ISD::SHL, dl, VT, NewExt, NewSA));
1802 }
1803 }
1804 }
1805 }
1806
1807 APInt InDemandedMask = DemandedBits.lshr(ShAmt);
1808 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
1809 Depth + 1)) {
1810 SDNodeFlags Flags = Op.getNode()->getFlags();
1811 if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) {
1812 // Disable the nsw and nuw flags. We can no longer guarantee that we
1813 // won't wrap after simplification.
1814 Flags.setNoSignedWrap(false);
1815 Flags.setNoUnsignedWrap(false);
1816 Op->setFlags(Flags);
1817 }
1818 return true;
1819 }
1820 Known.Zero <<= ShAmt;
1821 Known.One <<= ShAmt;
1822 // low bits known zero.
1823 Known.Zero.setLowBits(ShAmt);
1824
1825 // Attempt to avoid multi-use ops if we don't need anything from them.
1826 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
1827 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1828 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
1829 if (DemandedOp0) {
1830 SDValue NewOp = TLO.DAG.getNode(ISD::SHL, dl, VT, DemandedOp0, Op1);
1831 return TLO.CombineTo(Op, NewOp);
1832 }
1833 }
1834
1835 // TODO: Can we merge this fold with the one below?
1836 // Try shrinking the operation as long as the shift amount will still be
1837 // in range.
1838 if (ShAmt < DemandedBits.getActiveBits() && !VT.isVector() &&
1839 Op.getNode()->hasOneUse()) {
1840 // Search for the smallest integer type with free casts to and from
1841 // Op's type. For expedience, just check power-of-2 integer types.
1842 unsigned DemandedSize = DemandedBits.getActiveBits();
1843 for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
1844 SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
1845 EVT SmallVT = EVT::getIntegerVT(*TLO.DAG.getContext(), SmallVTBits);
1846 if (isNarrowingProfitable(VT, SmallVT) &&
1847 isTypeDesirableForOp(ISD::SHL, SmallVT) &&
1848 isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT) &&
1849 (!TLO.LegalOperations() || isOperationLegal(ISD::SHL, SmallVT))) {
1850 assert(DemandedSize <= SmallVTBits &&
1851 "Narrowed below demanded bits?");
1852 // We found a type with free casts.
1853 SDValue NarrowShl = TLO.DAG.getNode(
1854 ISD::SHL, dl, SmallVT,
1855 TLO.DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
1856 TLO.DAG.getShiftAmountConstant(ShAmt, SmallVT, dl));
1857 return TLO.CombineTo(
1858 Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
1859 }
1860 }
1861 }
1862
1863 // Narrow shift to lower half - similar to ShrinkDemandedOp.
1864 // (shl i64:x, K) -> (i64 zero_extend (shl (i32 (trunc i64:x)), K))
1865 // Only do this if we demand the upper half so the knownbits are correct.
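// Illustrative example: for (shl i64 x, 5) where all upper 32 result bits are
// demanded but are known to be zero (so the narrow shift would be NUW), the
// node can be rewritten as (zero_extend (shl (trunc x to i32), 5)).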
1866 unsigned HalfWidth = BitWidth / 2;
1867 if ((BitWidth % 2) == 0 && !VT.isVector() && ShAmt < HalfWidth &&
1868 DemandedBits.countLeadingOnes() >= HalfWidth) {
1869 EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), HalfWidth);
1870 if (isNarrowingProfitable(VT, HalfVT) &&
1871 isTypeDesirableForOp(ISD::SHL, HalfVT) &&
1872 isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
1873 (!TLO.LegalOperations() || isOperationLegal(ISD::SHL, HalfVT))) {
1874 // If we're demanding the upper bits at all, we must ensure
1875 // that the upper bits of the shift result are known to be zero,
1876 // which is equivalent to the narrow shift being NUW.
1877 if (bool IsNUW = (Known.countMinLeadingZeros() >= HalfWidth)) {
1878 bool IsNSW = Known.countMinSignBits() > HalfWidth;
1879 SDNodeFlags Flags;
1880 Flags.setNoSignedWrap(IsNSW);
1881 Flags.setNoUnsignedWrap(IsNUW);
1882 SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0);
1883 SDValue NewShiftAmt = TLO.DAG.getShiftAmountConstant(
1884 ShAmt, HalfVT, dl, TLO.LegalTypes());
1885 SDValue NewShift = TLO.DAG.getNode(ISD::SHL, dl, HalfVT, NewOp,
1886 NewShiftAmt, Flags);
1887 SDValue NewExt =
1888 TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, NewShift);
1889 return TLO.CombineTo(Op, NewExt);
1890 }
1891 }
1892 }
1893 } else {
1894 // This is a variable shift, so we can't shift the demand mask by a known
1895 // amount. But if we are not demanding high bits, then we are not
1896 // demanding those bits from the pre-shifted operand either.
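// e.g. for i32 (shl x, amt) with DemandedBits = 0x0000FFFF, only the low 16
// bits of x are demanded: any higher bit of x can only move further left,
// out of the demanded range, whatever the unknown shift amount is.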
1897 if (unsigned CTLZ = DemandedBits.countl_zero()) {
1898 APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
1899 if (SimplifyDemandedBits(Op0, DemandedFromOp, DemandedElts, Known, TLO,
1900 Depth + 1)) {
1901 SDNodeFlags Flags = Op.getNode()->getFlags();
1902 if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) {
1903 // Disable the nsw and nuw flags. We can no longer guarantee that we
1904 // won't wrap after simplification.
1905 Flags.setNoSignedWrap(false);
1906 Flags.setNoUnsignedWrap(false);
1907 Op->setFlags(Flags);
1908 }
1909 return true;
1910 }
1911 Known.resetAll();
1912 }
1913 }
1914
1915 // If we are only demanding sign bits then we can use the shift source
1916 // directly.
1917 if (std::optional<uint64_t> MaxSA =
1918 TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
1919 unsigned ShAmt = *MaxSA;
1920 unsigned NumSignBits =
1921 TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
1922 unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
1923 if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
1924 return TLO.CombineTo(Op, Op0);
1925 }
1926 break;
1927 }
1928 case ISD::SRL: {
1929 SDValue Op0 = Op.getOperand(0);
1930 SDValue Op1 = Op.getOperand(1);
1931 EVT ShiftVT = Op1.getValueType();
1932
1933 if (std::optional<uint64_t> KnownSA =
1934 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
1935 unsigned ShAmt = *KnownSA;
1936 if (ShAmt == 0)
1937 return TLO.CombineTo(Op, Op0);
1938
1939 // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
1940 // single shift. We can do this if the top bits (which are shifted out)
1941 // are never demanded.
1942 // TODO - support non-uniform vector amounts.
1943 if (Op0.getOpcode() == ISD::SHL) {
1944 if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
1945 if (std::optional<uint64_t> InnerSA =
1946 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
1947 unsigned C1 = *InnerSA;
1948 unsigned Opc = ISD::SRL;
1949 int Diff = ShAmt - C1;
1950 if (Diff < 0) {
1951 Diff = -Diff;
1952 Opc = ISD::SHL;
1953 }
1954 SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
1955 return TLO.CombineTo(
1956 Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
1957 }
1958 }
1959 }
1960
1961 APInt InDemandedMask = (DemandedBits << ShAmt);
1962
1963 // If the shift is exact, then it does demand the low bits (and knows that
1964 // they are zero).
1965 if (Op->getFlags().hasExact())
1966 InDemandedMask.setLowBits(ShAmt);
1967
1968 // Narrow shift to lower half - similar to ShrinkDemandedOp.
1969 // (srl i64:x, K) -> (i64 zero_extend (srl (i32 (trunc i64:x)), K))
1970 if ((BitWidth % 2) == 0 && !VT.isVector()) {
1971 APInt HiBits = APInt::getHighBitsSet(BitWidth, BitWidth / 2);
1972 EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), BitWidth / 2);
1973 if (isNarrowingProfitable(VT, HalfVT) &&
1974 isTypeDesirableForOp(ISD::SRL, HalfVT) &&
1975 isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
1976 (!TLO.LegalOperations() || isOperationLegal(ISD::SRL, HalfVT)) &&
1977 ((InDemandedMask.countLeadingZeros() >= (BitWidth / 2)) ||
1978 TLO.DAG.MaskedValueIsZero(Op0, HiBits))) {
1979 SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0);
1980 SDValue NewShiftAmt = TLO.DAG.getShiftAmountConstant(
1981 ShAmt, HalfVT, dl, TLO.LegalTypes());
1982 SDValue NewShift =
1983 TLO.DAG.getNode(ISD::SRL, dl, HalfVT, NewOp, NewShiftAmt);
1984 return TLO.CombineTo(
1985 Op, TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, NewShift));
1986 }
1987 }
1988
1989 // Compute the new bits that are at the top now.
1990 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
1991 Depth + 1))
1992 return true;
1993 Known.Zero.lshrInPlace(ShAmt);
1994 Known.One.lshrInPlace(ShAmt);
1995 // High bits known zero.
1996 Known.Zero.setHighBits(ShAmt);
1997
1998 // Attempt to avoid multi-use ops if we don't need anything from them.
1999 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2000 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2001 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
2002 if (DemandedOp0) {
2003 SDValue NewOp = TLO.DAG.getNode(ISD::SRL, dl, VT, DemandedOp0, Op1);
2004 return TLO.CombineTo(Op, NewOp);
2005 }
2006 }
2007 } else {
2008 // Use generic knownbits computation as it has support for non-uniform
2009 // shift amounts.
2010 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2011 }
2012
2013 // Try to match AVG patterns (after shift simplification).
2014 if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2015 DemandedElts, Depth + 1))
2016 return TLO.CombineTo(Op, AVG);
2017
2018 break;
2019 }
2020 case ISD::SRA: {
2021 SDValue Op0 = Op.getOperand(0);
2022 SDValue Op1 = Op.getOperand(1);
2023 EVT ShiftVT = Op1.getValueType();
2024
2025 // If we only want bits that already match the signbit then we don't need
2026 // to shift.
2027 unsigned NumHiDemandedBits = BitWidth - DemandedBits.countr_zero();
2028 if (TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1) >=
2029 NumHiDemandedBits)
2030 return TLO.CombineTo(Op, Op0);
2031
2032 // If this is an arithmetic shift right and only the low-bit is set, we can
2033 // always convert this into a logical shr, even if the shift amount is
2034 // variable. The low bit of the shift cannot be an input sign bit unless
2035 // the shift amount is >= the size of the datatype, which is undefined.
2036 if (DemandedBits.isOne())
2037 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
2038
2039 if (std::optional<uint64_t> KnownSA =
2040 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
2041 unsigned ShAmt = *KnownSA;
2042 if (ShAmt == 0)
2043 return TLO.CombineTo(Op, Op0);
2044
2045 // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
2046 // supports sext_inreg.
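// e.g. on i32, (sra (shl x, 24), 24) becomes (sign_extend_inreg x, i8); if x
// already has more than 24 sign bits, the shift pair is removed entirely.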
2047 if (Op0.getOpcode() == ISD::SHL) {
2048 if (std::optional<uint64_t> InnerSA =
2049 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
2050 unsigned LowBits = BitWidth - ShAmt;
2051 EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
2052 if (VT.isVector())
2053 ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtVT,
2054 VT.getVectorElementCount());
2055
2056 if (*InnerSA == ShAmt) {
2057 if (!TLO.LegalOperations() ||
2058 isOperationLegal(ISD::SIGN_EXTEND_INREG, VT))
2059 return TLO.CombineTo(
2060 Op, TLO.DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, VT,
2061 Op0.getOperand(0),
2062 TLO.DAG.getValueType(ExtVT)));
2063
2064 // Even if we can't convert to sext_inreg, we might be able to
2065 // remove this shift pair if the input is already sign extended.
2066 unsigned NumSignBits =
2067 TLO.DAG.ComputeNumSignBits(Op0.getOperand(0), DemandedElts);
2068 if (NumSignBits > ShAmt)
2069 return TLO.CombineTo(Op, Op0.getOperand(0));
2070 }
2071 }
2072 }
2073
2074 APInt InDemandedMask = (DemandedBits << ShAmt);
2075
2076 // If the shift is exact, then it does demand the low bits (and knows that
2077 // they are zero).
2078 if (Op->getFlags().hasExact())
2079 InDemandedMask.setLowBits(ShAmt);
2080
2081 // If any of the demanded bits are produced by the sign extension, we also
2082 // demand the input sign bit.
2083 if (DemandedBits.countl_zero() < ShAmt)
2084 InDemandedMask.setSignBit();
2085
2086 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
2087 Depth + 1))
2088 return true;
2089 Known.Zero.lshrInPlace(ShAmt);
2090 Known.One.lshrInPlace(ShAmt);
2091
2092 // If the input sign bit is known to be zero, or if none of the top bits
2093 // are demanded, turn this into an unsigned shift right.
2094 if (Known.Zero[BitWidth - ShAmt - 1] ||
2095 DemandedBits.countl_zero() >= ShAmt) {
2096 SDNodeFlags Flags;
2097 Flags.setExact(Op->getFlags().hasExact());
2098 return TLO.CombineTo(
2099 Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1, Flags));
2100 }
2101
2102 int Log2 = DemandedBits.exactLogBase2();
2103 if (Log2 >= 0) {
2104 // The bit must come from the sign.
2105 SDValue NewSA = TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, ShiftVT);
2106 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, NewSA));
2107 }
2108
2109 if (Known.One[BitWidth - ShAmt - 1])
2110 // New bits are known one.
2111 Known.One.setHighBits(ShAmt);
2112
2113 // Attempt to avoid multi-use ops if we don't need anything from them.
2114 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2115 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2116 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
2117 if (DemandedOp0) {
2118 SDValue NewOp = TLO.DAG.getNode(ISD::SRA, dl, VT, DemandedOp0, Op1);
2119 return TLO.CombineTo(Op, NewOp);
2120 }
2121 }
2122 }
2123
2124 // Try to match AVG patterns (after shift simplification).
2125 if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2126 DemandedElts, Depth + 1))
2127 return TLO.CombineTo(Op, AVG);
2128
2129 break;
2130 }
2131 case ISD::FSHL:
2132 case ISD::FSHR: {
2133 SDValue Op0 = Op.getOperand(0);
2134 SDValue Op1 = Op.getOperand(1);
2135 SDValue Op2 = Op.getOperand(2);
2136 bool IsFSHL = (Op.getOpcode() == ISD::FSHL);
2137
2138 if (ConstantSDNode *SA = isConstOrConstSplat(Op2, DemandedElts)) {
2139 unsigned Amt = SA->getAPIntValue().urem(BitWidth);
2140
2141 // For fshl, 0-shift returns the 1st arg.
2142 // For fshr, 0-shift returns the 2nd arg.
2143 if (Amt == 0) {
2144 if (SimplifyDemandedBits(IsFSHL ? Op0 : Op1, DemandedBits, DemandedElts,
2145 Known, TLO, Depth + 1))
2146 return true;
2147 break;
2148 }
2149
2150 // fshl: (Op0 << Amt) | (Op1 >> (BW - Amt))
2151 // fshr: (Op0 << (BW - Amt)) | (Op1 >> Amt)
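// e.g. for i8 fshl(X, Y, 3) with DemandedBits = 0xF0: result bits 4-7 come
// from bits 1-4 of X (Demanded0 = 0x1E), and nothing is demanded from Y
// since only result bits 0-2 are taken from Y's top bits (Demanded1 = 0).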
2152 APInt Demanded0 = DemandedBits.lshr(IsFSHL ? Amt : (BitWidth - Amt));
2153 APInt Demanded1 = DemandedBits << (IsFSHL ? (BitWidth - Amt) : Amt);
2154 if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
2155 Depth + 1))
2156 return true;
2157 if (SimplifyDemandedBits(Op1, Demanded1, DemandedElts, Known, TLO,
2158 Depth + 1))
2159 return true;
2160
2161 Known2.One <<= (IsFSHL ? Amt : (BitWidth - Amt));
2162 Known2.Zero <<= (IsFSHL ? Amt : (BitWidth - Amt));
2163 Known.One.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
2164 Known.Zero.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
2165 Known = Known.unionWith(Known2);
2166
2167 // Attempt to avoid multi-use ops if we don't need anything from them.
2168 if (!Demanded0.isAllOnes() || !Demanded1.isAllOnes() ||
2169 !DemandedElts.isAllOnes()) {
2170 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2171 Op0, Demanded0, DemandedElts, TLO.DAG, Depth + 1);
2172 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
2173 Op1, Demanded1, DemandedElts, TLO.DAG, Depth + 1);
2174 if (DemandedOp0 || DemandedOp1) {
2175 DemandedOp0 = DemandedOp0 ? DemandedOp0 : Op0;
2176 DemandedOp1 = DemandedOp1 ? DemandedOp1 : Op1;
2177 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedOp0,
2178 DemandedOp1, Op2);
2179 return TLO.CombineTo(Op, NewOp);
2180 }
2181 }
2182 }
2183
2184 // For pow-2 bitwidths we only demand the bottom modulo amt bits.
2185 if (isPowerOf2_32(BitWidth)) {
2186 APInt DemandedAmtBits(Op2.getScalarValueSizeInBits(), BitWidth - 1);
2187 if (SimplifyDemandedBits(Op2, DemandedAmtBits, DemandedElts,
2188 Known2, TLO, Depth + 1))
2189 return true;
2190 }
2191 break;
2192 }
2193 case ISD::ROTL:
2194 case ISD::ROTR: {
2195 SDValue Op0 = Op.getOperand(0);
2196 SDValue Op1 = Op.getOperand(1);
2197 bool IsROTL = (Op.getOpcode() == ISD::ROTL);
2198
2199 // If we're rotating a 0/-1 value, then it stays a 0/-1 value.
2200 if (BitWidth == TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1))
2201 return TLO.CombineTo(Op, Op0);
2202
2203 if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
2204 unsigned Amt = SA->getAPIntValue().urem(BitWidth);
2205 unsigned RevAmt = BitWidth - Amt;
2206
2207 // rotl: (Op0 << Amt) | (Op0 >> (BW - Amt))
2208 // rotr: (Op0 << (BW - Amt)) | (Op0 >> Amt)
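// e.g. for i8 rotl(x, 3) with DemandedBits = 0x38 (bits 3-5), Demanded0 =
// rotr(0x38, 3) = 0x07, i.e. only bits 0-2 of x feed the demanded bits.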
2209 APInt Demanded0 = DemandedBits.rotr(IsROTL ? Amt : RevAmt);
2210 if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
2211 Depth + 1))
2212 return true;
2213
2214 // rot*(x, 0) --> x
2215 if (Amt == 0)
2216 return TLO.CombineTo(Op, Op0);
2217
2218 // See if we don't demand either half of the rotated bits.
2219 if ((!TLO.LegalOperations() || isOperationLegal(ISD::SHL, VT)) &&
2220 DemandedBits.countr_zero() >= (IsROTL ? Amt : RevAmt)) {
2221 Op1 = TLO.DAG.getConstant(IsROTL ? Amt : RevAmt, dl, Op1.getValueType());
2222 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, Op1));
2223 }
2224 if ((!TLO.LegalOperations() || isOperationLegal(ISD::SRL, VT)) &&
2225 DemandedBits.countl_zero() >= (IsROTL ? RevAmt : Amt)) {
2226 Op1 = TLO.DAG.getConstant(IsROTL ? RevAmt : Amt, dl, Op1.getValueType());
2227 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
2228 }
2229 }
2230
2231 // For pow-2 bitwidths we only demand the bottom modulo amt bits.
2232 if (isPowerOf2_32(BitWidth)) {
2233 APInt DemandedAmtBits(Op1.getScalarValueSizeInBits(), BitWidth - 1);
2234 if (SimplifyDemandedBits(Op1, DemandedAmtBits, DemandedElts, Known2, TLO,
2235 Depth + 1))
2236 return true;
2237 }
2238 break;
2239 }
2240 case ISD::SMIN:
2241 case ISD::SMAX:
2242 case ISD::UMIN:
2243 case ISD::UMAX: {
2244 unsigned Opc = Op.getOpcode();
2245 SDValue Op0 = Op.getOperand(0);
2246 SDValue Op1 = Op.getOperand(1);
2247
2248 // If we're only demanding signbits, then we can simplify to OR/AND node.
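// e.g. if every demanded bit of both operands is a copy of that operand's
// sign bit, each operand is effectively all-0 or all-1 there, and
// smin(0,-1) = -1 = (0 | -1) while smax(0,-1) = 0 = (0 & -1), so
// SMIN/UMAX become OR and SMAX/UMIN become AND over those bits.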
2249 unsigned BitOp =
2250 (Opc == ISD::SMIN || Opc == ISD::UMAX) ? ISD::OR : ISD::AND;
2251 unsigned NumSignBits =
2252 std::min(TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1),
2253 TLO.DAG.ComputeNumSignBits(Op1, DemandedElts, Depth + 1));
2254 unsigned NumDemandedUpperBits = BitWidth - DemandedBits.countr_zero();
2255 if (NumSignBits >= NumDemandedUpperBits)
2256 return TLO.CombineTo(Op, TLO.DAG.getNode(BitOp, SDLoc(Op), VT, Op0, Op1));
2257
2258 // Check if one arg is always less/greater than (or equal) to the other arg.
2259 KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
2260 KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
2261 switch (Opc) {
2262 case ISD::SMIN:
2263 if (std::optional<bool> IsSLE = KnownBits::sle(Known0, Known1))
2264 return TLO.CombineTo(Op, *IsSLE ? Op0 : Op1);
2265 if (std::optional<bool> IsSLT = KnownBits::slt(Known0, Known1))
2266 return TLO.CombineTo(Op, *IsSLT ? Op0 : Op1);
2267 Known = KnownBits::smin(Known0, Known1);
2268 break;
2269 case ISD::SMAX:
2270 if (std::optional<bool> IsSGE = KnownBits::sge(Known0, Known1))
2271 return TLO.CombineTo(Op, *IsSGE ? Op0 : Op1);
2272 if (std::optional<bool> IsSGT = KnownBits::sgt(Known0, Known1))
2273 return TLO.CombineTo(Op, *IsSGT ? Op0 : Op1);
2274 Known = KnownBits::smax(Known0, Known1);
2275 break;
2276 case ISD::UMIN:
2277 if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
2278 return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
2279 if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
2280 return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
2281 Known = KnownBits::umin(Known0, Known1);
2282 break;
2283 case ISD::UMAX:
2284 if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
2285 return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
2286 if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
2287 return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
2288 Known = KnownBits::umax(Known0, Known1);
2289 break;
2290 }
2291 break;
2292 }
2293 case ISD::BITREVERSE: {
2294 SDValue Src = Op.getOperand(0);
2295 APInt DemandedSrcBits = DemandedBits.reverseBits();
2296 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2297 Depth + 1))
2298 return true;
2299 Known.One = Known2.One.reverseBits();
2300 Known.Zero = Known2.Zero.reverseBits();
2301 break;
2302 }
2303 case ISD::BSWAP: {
2304 SDValue Src = Op.getOperand(0);
2305
2306 // If the only bits demanded come from one byte of the bswap result,
2307 // just shift the input byte into position to eliminate the bswap.
2308 unsigned NLZ = DemandedBits.countl_zero();
2309 unsigned NTZ = DemandedBits.countr_zero();
2310
2311 // Round NTZ down to the next byte. If we have 11 trailing zeros, then
2312 // we need all the bits down to bit 8. Likewise, round NLZ. If we
2313 // have 14 leading zeros, round to 8.
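// e.g. for i32 bswap with DemandedBits = 0x0000FF00: NLZ rounds to 16 and
// NTZ rounds to 8, so exactly one byte is needed and the bswap becomes
// (srl Src, 8), moving byte 2 of the input into byte 1 of the result.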
2314 NLZ = alignDown(NLZ, 8);
2315 NTZ = alignDown(NTZ, 8);
2316 // If we need exactly one byte, we can do this transformation.
2317 if (BitWidth - NLZ - NTZ == 8) {
2318 // Replace this with either a left or right shift to get the byte into
2319 // the right place.
2320 unsigned ShiftOpcode = NLZ > NTZ ? ISD::SRL : ISD::SHL;
2321 if (!TLO.LegalOperations() || isOperationLegal(ShiftOpcode, VT)) {
2322 unsigned ShiftAmount = NLZ > NTZ ? NLZ - NTZ : NTZ - NLZ;
2323 SDValue ShAmt = TLO.DAG.getShiftAmountConstant(ShiftAmount, VT, dl);
2324 SDValue NewOp = TLO.DAG.getNode(ShiftOpcode, dl, VT, Src, ShAmt);
2325 return TLO.CombineTo(Op, NewOp);
2326 }
2327 }
2328
2329 APInt DemandedSrcBits = DemandedBits.byteSwap();
2330 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2331 Depth + 1))
2332 return true;
2333 Known.One = Known2.One.byteSwap();
2334 Known.Zero = Known2.Zero.byteSwap();
2335 break;
2336 }
2337 case ISD::CTPOP: {
2338 // If only 1 bit is demanded, replace with PARITY as long as we're before
2339 // op legalization.
2340 // FIXME: Limit to scalars for now.
2341 if (DemandedBits.isOne() && !TLO.LegalOps && !VT.isVector())
2342 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::PARITY, dl, VT,
2343 Op.getOperand(0)));
2344
2345 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2346 break;
2347 }
2348 case ISD::SIGN_EXTEND_INREG: {
2349 SDValue Op0 = Op.getOperand(0);
2350 EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2351 unsigned ExVTBits = ExVT.getScalarSizeInBits();
2352
2353 // If we only care about the highest bit, don't bother shifting right.
2354 if (DemandedBits.isSignMask()) {
2355 unsigned MinSignedBits =
2356 TLO.DAG.ComputeMaxSignificantBits(Op0, DemandedElts, Depth + 1);
2357 bool AlreadySignExtended = ExVTBits >= MinSignedBits;
2358 // However if the input is already sign extended we expect the sign
2359 // extension to be dropped altogether later and do not simplify.
2360 if (!AlreadySignExtended) {
2361 // Compute the correct shift amount type, which must be getShiftAmountTy
2362 // for scalar types after legalization.
2363 SDValue ShiftAmt =
2364 TLO.DAG.getShiftAmountConstant(BitWidth - ExVTBits, VT, dl);
2365 return TLO.CombineTo(Op,
2366 TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, ShiftAmt));
2367 }
2368 }
2369
2370 // If none of the extended bits are demanded, eliminate the sextinreg.
2371 if (DemandedBits.getActiveBits() <= ExVTBits)
2372 return TLO.CombineTo(Op, Op0);
2373
2374 APInt InputDemandedBits = DemandedBits.getLoBits(ExVTBits);
2375
2376 // Since the sign extended bits are demanded, we know that the sign
2377 // bit is demanded.
2378 InputDemandedBits.setBit(ExVTBits - 1);
2379
2380 if (SimplifyDemandedBits(Op0, InputDemandedBits, DemandedElts, Known, TLO,
2381 Depth + 1))
2382 return true;
2383
2384 // If the sign bit of the input is known set or clear, then we know the
2385 // top bits of the result.
2386
2387 // If the input sign bit is known zero, convert this into a zero extension.
2388 if (Known.Zero[ExVTBits - 1])
2389 return TLO.CombineTo(Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT));
2390
2391 APInt Mask = APInt::getLowBitsSet(BitWidth, ExVTBits);
2392 if (Known.One[ExVTBits - 1]) { // Input sign bit known set
2393 Known.One.setBitsFrom(ExVTBits);
2394 Known.Zero &= Mask;
2395 } else { // Input sign bit unknown
2396 Known.Zero &= Mask;
2397 Known.One &= Mask;
2398 }
2399 break;
2400 }
2401 case ISD::BUILD_PAIR: {
2402 EVT HalfVT = Op.getOperand(0).getValueType();
2403 unsigned HalfBitWidth = HalfVT.getScalarSizeInBits();
2404
2405 APInt MaskLo = DemandedBits.getLoBits(HalfBitWidth).trunc(HalfBitWidth);
2406 APInt MaskHi = DemandedBits.getHiBits(HalfBitWidth).trunc(HalfBitWidth);
2407
2408 KnownBits KnownLo, KnownHi;
2409
2410 if (SimplifyDemandedBits(Op.getOperand(0), MaskLo, KnownLo, TLO, Depth + 1))
2411 return true;
2412
2413 if (SimplifyDemandedBits(Op.getOperand(1), MaskHi, KnownHi, TLO, Depth + 1))
2414 return true;
2415
2416 Known = KnownHi.concat(KnownLo);
2417 break;
2418 }
2419 case ISD::ZERO_EXTEND_VECTOR_INREG:
2420 if (VT.isScalableVector())
2421 return false;
2422 [[fallthrough]];
2423 case ISD::ZERO_EXTEND: {
2424 SDValue Src = Op.getOperand(0);
2425 EVT SrcVT = Src.getValueType();
2426 unsigned InBits = SrcVT.getScalarSizeInBits();
2427 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2428 bool IsVecInReg = Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
2429
2430 // If none of the top bits are demanded, convert this into an any_extend.
2431 if (DemandedBits.getActiveBits() <= InBits) {
2432 // If we only need the non-extended bits of the bottom element
2433 // then we can just bitcast to the result.
2434 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2435 VT.getSizeInBits() == SrcVT.getSizeInBits())
2436 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2437
2438 unsigned Opc =
2439 IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2440 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2441 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2442 }
2443
2444 SDNodeFlags Flags = Op->getFlags();
2445 APInt InDemandedBits = DemandedBits.trunc(InBits);
2446 APInt InDemandedElts = DemandedElts.zext(InElts);
2447 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2448 Depth + 1)) {
2449 if (Flags.hasNonNeg()) {
2450 Flags.setNonNeg(false);
2451 Op->setFlags(Flags);
2452 }
2453 return true;
2454 }
2455 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2456 Known = Known.zext(BitWidth);
2457
2458 // Attempt to avoid multi-use ops if we don't need anything from them.
2459 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2460 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2461 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2462 break;
2463 }
2464 case ISD::SIGN_EXTEND_VECTOR_INREG:
2465 if (VT.isScalableVector())
2466 return false;
2467 [[fallthrough]];
2468 case ISD::SIGN_EXTEND: {
2469 SDValue Src = Op.getOperand(0);
2470 EVT SrcVT = Src.getValueType();
2471 unsigned InBits = SrcVT.getScalarSizeInBits();
2472 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2473 bool IsVecInReg = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
2474
2475 APInt InDemandedElts = DemandedElts.zext(InElts);
2476 APInt InDemandedBits = DemandedBits.trunc(InBits);
2477
2478 // Since some of the sign extended bits are demanded, we know that the sign
2479 // bit is demanded.
2480 InDemandedBits.setBit(InBits - 1);
2481
2482 // If none of the top bits are demanded, convert this into an any_extend.
2483 if (DemandedBits.getActiveBits() <= InBits) {
2484 // If we only need the non-extended bits of the bottom element
2485 // then we can just bitcast to the result.
2486 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2487 VT.getSizeInBits() == SrcVT.getSizeInBits())
2488 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2489
2490 // Don't lose an all signbits 0/-1 splat on targets with 0/-1 booleans.
2491 if (getBooleanContents(VT) != ZeroOrNegativeOneBooleanContent ||
2492 TLO.DAG.ComputeNumSignBits(Src, InDemandedElts, Depth + 1) !=
2493 InBits) {
2494 unsigned Opc =
2495 IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2496 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2497 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2498 }
2499 }
2500
2501 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2502 Depth + 1))
2503 return true;
2504 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2505
2506 // If the sign bit is known one, the top bits match.
2507 Known = Known.sext(BitWidth);
2508
2509 // If the sign bit is known zero, convert this to a zero extend.
2510 if (Known.isNonNegative()) {
2511 unsigned Opc =
2512 IsVecInReg ? ISD::ZERO_EXTEND_VECTOR_INREG : ISD::ZERO_EXTEND;
2513 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT)) {
2514 SDNodeFlags Flags;
2515 if (!IsVecInReg)
2516 Flags.setNonNeg(true);
2517 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src, Flags));
2518 }
2519 }
2520
2521 // Attempt to avoid multi-use ops if we don't need anything from them.
2522 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2523 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2524 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2525 break;
2526 }
2527 case ISD::ANY_EXTEND_VECTOR_INREG:
2528 if (VT.isScalableVector())
2529 return false;
2530 [[fallthrough]];
2531 case ISD::ANY_EXTEND: {
2532 SDValue Src = Op.getOperand(0);
2533 EVT SrcVT = Src.getValueType();
2534 unsigned InBits = SrcVT.getScalarSizeInBits();
2535 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2536 bool IsVecInReg = Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG;
2537
2538 // If we only need the bottom element then we can just bitcast.
2539 // TODO: Handle ANY_EXTEND?
2540 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2541 VT.getSizeInBits() == SrcVT.getSizeInBits())
2542 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2543
2544 APInt InDemandedBits = DemandedBits.trunc(InBits);
2545 APInt InDemandedElts = DemandedElts.zext(InElts);
2546 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2547 Depth + 1))
2548 return true;
2549 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2550 Known = Known.anyext(BitWidth);
2551
2552 // Attempt to avoid multi-use ops if we don't need anything from them.
2553 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2554 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2555 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2556 break;
2557 }
2558 case ISD::TRUNCATE: {
2559 SDValue Src = Op.getOperand(0);
2560
2561 // Simplify the input, using demanded bit information, and compute the known
2562 // zero/one bits live out.
2563 unsigned OperandBitWidth = Src.getScalarValueSizeInBits();
2564 APInt TruncMask = DemandedBits.zext(OperandBitWidth);
2565 if (SimplifyDemandedBits(Src, TruncMask, DemandedElts, Known, TLO,
2566 Depth + 1))
2567 return true;
2568 Known = Known.trunc(BitWidth);
2569
2570 // Attempt to avoid multi-use ops if we don't need anything from them.
2571 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2572 Src, TruncMask, DemandedElts, TLO.DAG, Depth + 1))
2573 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, NewSrc));
2574
2575 // If the input is only used by this truncate, see if we can shrink it based
2576 // on the known demanded bits.
2577 switch (Src.getOpcode()) {
2578 default:
2579 break;
2580 case ISD::SRL:
2581 // Shrink SRL by a constant if none of the high bits shifted in are
2582 // demanded.
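// e.g. (trunc i64 (srl x, 8) to i32) where result bits 24-31 are not
// demanded can become (srl (trunc x to i32), 8), since none of the bits
// shifted in from the upper half of x are actually needed.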
2583 if (TLO.LegalTypes() && !isTypeDesirableForOp(ISD::SRL, VT))
2584 // Do not turn (vt1 truncate (vt2 srl)) into (vt1 srl) if vt1 is
2585 // undesirable.
2586 break;
2587
2588 if (Src.getNode()->hasOneUse()) {
2589 std::optional<uint64_t> ShAmtC =
2590 TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
2591 if (!ShAmtC || *ShAmtC >= BitWidth)
2592 break;
2593 uint64_t ShVal = *ShAmtC;
2594
2595 APInt HighBits =
2596 APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
2597 HighBits.lshrInPlace(ShVal);
2598 HighBits = HighBits.trunc(BitWidth);
2599
2600 if (!(HighBits & DemandedBits)) {
2601 // None of the shifted in bits are needed. Add a truncate of the
2602 // shift input, then shift it.
2603 SDValue NewShAmt =
2604 TLO.DAG.getShiftAmountConstant(ShVal, VT, dl, TLO.LegalTypes());
2605 SDValue NewTrunc =
2606 TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0));
2607 return TLO.CombineTo(
2608 Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, NewShAmt));
2609 }
2610 }
2611 break;
2612 }
2613
2614 break;
2615 }
2616 case ISD::AssertZext: {
2617 // AssertZext demands all of the high bits, plus any of the low bits
2618 // demanded by its users.
2619 EVT ZVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2620 APInt InMask = APInt::getLowBitsSet(BitWidth, ZVT.getSizeInBits());
2621 if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | DemandedBits, Known,
2622 TLO, Depth + 1))
2623 return true;
2624
2625 Known.Zero |= ~InMask;
2626 Known.One &= (~Known.Zero);
2627 break;
2628 }
2629 case ISD::EXTRACT_VECTOR_ELT: {
2630 SDValue Src = Op.getOperand(0);
2631 SDValue Idx = Op.getOperand(1);
2632 ElementCount SrcEltCnt = Src.getValueType().getVectorElementCount();
2633 unsigned EltBitWidth = Src.getScalarValueSizeInBits();
2634
2635 if (SrcEltCnt.isScalable())
2636 return false;
2637
2638 // Demand the bits from every vector element without a constant index.
2639 unsigned NumSrcElts = SrcEltCnt.getFixedValue();
2640 APInt DemandedSrcElts = APInt::getAllOnes(NumSrcElts);
2641 if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx))
2642 if (CIdx->getAPIntValue().ult(NumSrcElts))
2643 DemandedSrcElts = APInt::getOneBitSet(NumSrcElts, CIdx->getZExtValue());
2644
2645 // If BitWidth > EltBitWidth the value is any-extended, so we do not know
2646 // anything about the extended bits.
2647 APInt DemandedSrcBits = DemandedBits;
2648 if (BitWidth > EltBitWidth)
2649 DemandedSrcBits = DemandedSrcBits.trunc(EltBitWidth);
2650
2651 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, Known2, TLO,
2652 Depth + 1))
2653 return true;
2654
2655 // Attempt to avoid multi-use ops if we don't need anything from them.
2656 if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
2657 if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
2658 Src, DemandedSrcBits, DemandedSrcElts, TLO.DAG, Depth + 1)) {
2659 SDValue NewOp =
2660 TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc, Idx);
2661 return TLO.CombineTo(Op, NewOp);
2662 }
2663 }
2664
2665 Known = Known2;
2666 if (BitWidth > EltBitWidth)
2667 Known = Known.anyext(BitWidth);
2668 break;
2669 }
2670 case ISD::BITCAST: {
2671 if (VT.isScalableVector())
2672 return false;
2673 SDValue Src = Op.getOperand(0);
2674 EVT SrcVT = Src.getValueType();
2675 unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
2676
2677 // If this is an FP->Int bitcast and if the sign bit is the only
2678 // thing demanded, turn this into a FGETSIGN.
2679 if (!TLO.LegalOperations() && !VT.isVector() && !SrcVT.isVector() &&
2680 DemandedBits == APInt::getSignMask(Op.getValueSizeInBits()) &&
2681 SrcVT.isFloatingPoint()) {
2682 bool OpVTLegal = isOperationLegalOrCustom(ISD::FGETSIGN, VT);
2683 bool i32Legal = isOperationLegalOrCustom(ISD::FGETSIGN, MVT::i32);
2684 if ((OpVTLegal || i32Legal) && VT.isSimple() && SrcVT != MVT::f16 &&
2685 SrcVT != MVT::f128) {
2686 // Cannot eliminate/lower SHL for f128 yet.
2687 EVT Ty = OpVTLegal ? VT : MVT::i32;
2688 // Make a FGETSIGN + SHL to move the sign bit into the appropriate
2689 // place. We expect the SHL to be eliminated by other optimizations.
2690 SDValue Sign = TLO.DAG.getNode(ISD::FGETSIGN, dl, Ty, Src);
2691 unsigned OpVTSizeInBits = Op.getValueSizeInBits();
2692 if (!OpVTLegal && OpVTSizeInBits > 32)
2693 Sign = TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Sign);
2694 unsigned ShVal = Op.getValueSizeInBits() - 1;
2695 SDValue ShAmt = TLO.DAG.getConstant(ShVal, dl, VT);
2696 return TLO.CombineTo(Op,
2697 TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
2698 }
2699 }
2700
2701 // Bitcast from a vector using SimplifyDemanded Bits/VectorElts.
2702 // Demand the elt/bit if any of the original elts/bits are demanded.
2703 if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
2704 unsigned Scale = BitWidth / NumSrcEltBits;
2705 unsigned NumSrcElts = SrcVT.getVectorNumElements();
2706 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2707 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2708 for (unsigned i = 0; i != Scale; ++i) {
2709 unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
2710 unsigned BitOffset = EltOffset * NumSrcEltBits;
2711 APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
2712 if (!Sub.isZero()) {
2713 DemandedSrcBits |= Sub;
2714 for (unsigned j = 0; j != NumElts; ++j)
2715 if (DemandedElts[j])
2716 DemandedSrcElts.setBit((j * Scale) + i);
2717 }
2718 }
2719
2720 APInt KnownSrcUndef, KnownSrcZero;
2721 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2722 KnownSrcZero, TLO, Depth + 1))
2723 return true;
2724
2725 KnownBits KnownSrcBits;
2726 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2727 KnownSrcBits, TLO, Depth + 1))
2728 return true;
2729 } else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
2730 // TODO - bigendian once we have test coverage.
2731 unsigned Scale = NumSrcEltBits / BitWidth;
2732 unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
2733 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2734 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2735 for (unsigned i = 0; i != NumElts; ++i)
2736 if (DemandedElts[i]) {
2737 unsigned Offset = (i % Scale) * BitWidth;
2738 DemandedSrcBits.insertBits(DemandedBits, Offset);
2739 DemandedSrcElts.setBit(i / Scale);
2740 }
2741
2742 if (SrcVT.isVector()) {
2743 APInt KnownSrcUndef, KnownSrcZero;
2744 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2745 KnownSrcZero, TLO, Depth + 1))
2746 return true;
2747 }
2748
2749 KnownBits KnownSrcBits;
2750 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2751 KnownSrcBits, TLO, Depth + 1))
2752 return true;
2753
2754 // Attempt to avoid multi-use ops if we don't need anything from them.
2755 if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
2756 if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
2757 Src, DemandedSrcBits, DemandedSrcElts, TLO.DAG, Depth + 1)) {
2758 SDValue NewOp = TLO.DAG.getBitcast(VT, DemandedSrc);
2759 return TLO.CombineTo(Op, NewOp);
2760 }
2761 }
2762 }
2763
2764 // If this is a bitcast, let computeKnownBits handle it. Only do this on a
2765 // recursive call where Known may be useful to the caller.
2766 if (Depth > 0) {
2767 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2768 return false;
2769 }
2770 break;
2771 }
2772 case ISD::MUL:
2773 if (DemandedBits.isPowerOf2()) {
2774 // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
2775 // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
2776 // odd (has LSB set), then the left-shifted low bit of X is the answer.
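// e.g. if only bit 3 is demanded and C = 24 (= 3 << 3, with 3 odd), then
// bit 3 of X * 24 equals bit 0 of X, so the multiply becomes X << 3.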
2777 unsigned CTZ = DemandedBits.countr_zero();
2778 ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
2779 if (C && C->getAPIntValue().countr_zero() == CTZ) {
2780 SDValue AmtC = TLO.DAG.getShiftAmountConstant(CTZ, VT, dl);
2781 SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, Op.getOperand(0), AmtC);
2782 return TLO.CombineTo(Op, Shl);
2783 }
2784 }
2785 // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
2786 // X * X is odd iff X is odd.
2787 // 'Quadratic Reciprocity': X * X -> 0 for bit[1]
2788 if (Op.getOperand(0) == Op.getOperand(1) && DemandedBits.ult(4)) {
2789 SDValue One = TLO.DAG.getConstant(1, dl, VT);
2790 SDValue And1 = TLO.DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0), One);
2791 return TLO.CombineTo(Op, And1);
2792 }
2793 [[fallthrough]];
2794 case ISD::ADD:
2795 case ISD::SUB: {
2796 // Add, Sub, and Mul don't demand any bits in positions beyond that
2797 // of the highest bit demanded of them.
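// Carries only propagate from low bits towards high bits, so e.g. for
// (add x, y) with DemandedBits = 0xFF only the low 8 bits of x and y
// (LoMask) can influence the demanded result bits.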
2798 SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1);
2799 SDNodeFlags Flags = Op.getNode()->getFlags();
2800 unsigned DemandedBitsLZ = DemandedBits.countl_zero();
2801 APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ);
2802 KnownBits KnownOp0, KnownOp1;
2803 auto GetDemandedBitsLHSMask = [&](APInt Demanded,
2804 const KnownBits &KnownRHS) {
2805 if (Op.getOpcode() == ISD::MUL)
2806 Demanded.clearHighBits(KnownRHS.countMinTrailingZeros());
2807 return Demanded;
2808 };
2809 if (SimplifyDemandedBits(Op1, LoMask, DemandedElts, KnownOp1, TLO,
2810 Depth + 1) ||
2811 SimplifyDemandedBits(Op0, GetDemandedBitsLHSMask(LoMask, KnownOp1),
2812 DemandedElts, KnownOp0, TLO, Depth + 1) ||
2813 // See if the operation should be performed at a smaller bit width.
2814 ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) {
2815 if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) {
2816 // Disable the nsw and nuw flags. We can no longer guarantee that we
2817 // won't wrap after simplification.
2818 Flags.setNoSignedWrap(false);
2819 Flags.setNoUnsignedWrap(false);
2820 Op->setFlags(Flags);
2821 }
2822 return true;
2823 }
2824
2825 // neg x with only low bit demanded is simply x.
2826 if (Op.getOpcode() == ISD::SUB && DemandedBits.isOne() &&
2827 isNullConstant(Op0))
2828 return TLO.CombineTo(Op, Op1);
2829
2830 // Attempt to avoid multi-use ops if we don't need anything from them.
2831 if (!LoMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2832 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2833 Op0, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2834 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
2835 Op1, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2836 if (DemandedOp0 || DemandedOp1) {
2837 Flags.setNoSignedWrap(false);
2838 Flags.setNoUnsignedWrap(false);
2839 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
2840 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
2841 SDValue NewOp =
2842 TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1, Flags);
2843 return TLO.CombineTo(Op, NewOp);
2844 }
2845 }
2846
2847 // If we have a constant operand, we may be able to turn it into -1 if we
2848 // do not demand the high bits. This can make the constant smaller to
2849 // encode, allow more general folding, or match specialized instruction
2850 // patterns (eg, 'blsr' on x86). Don't bother changing 1 to -1 because that
2851 // is probably not useful (and could be detrimental).
2852 ConstantSDNode *C = isConstOrConstSplat(Op1, DemandedElts);
2853 APInt HighMask = APInt::getHighBitsSet(BitWidth, DemandedBitsLZ);
2854 if (C && !C->isAllOnes() && !C->isOne() &&
2855 (C->getAPIntValue() | HighMask).isAllOnes()) {
2856 SDValue Neg1 = TLO.DAG.getAllOnesConstant(dl, VT);
2857 // Disable the nsw and nuw flags. We can no longer guarantee that we
2858 // won't wrap after simplification.
2859 Flags.setNoSignedWrap(false);
2860 Flags.setNoUnsignedWrap(false);
2861 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Neg1, Flags);
2862 return TLO.CombineTo(Op, NewOp);
2863 }
2864
2865 // Match a multiply with a disguised negated-power-of-2 and convert to
2866 // an equivalent shift-left amount.
2867 // Example: (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
2868 auto getShiftLeftAmt = [&HighMask](SDValue Mul) -> unsigned {
2869 if (Mul.getOpcode() != ISD::MUL || !Mul.hasOneUse())
2870 return 0;
2871
2872 // Don't touch opaque constants. Also, ignore zero and power-of-2
2873 // multiplies. Those will get folded later.
2874 ConstantSDNode *MulC = isConstOrConstSplat(Mul.getOperand(1));
2875 if (MulC && !MulC->isOpaque() && !MulC->isZero() &&
2876 !MulC->getAPIntValue().isPowerOf2()) {
2877 APInt UnmaskedC = MulC->getAPIntValue() | HighMask;
2878 if (UnmaskedC.isNegatedPowerOf2())
2879 return (-UnmaskedC).logBase2();
2880 }
2881 return 0;
2882 };
2883
2884 auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y,
2885 unsigned ShlAmt) {
2886 SDValue ShlAmtC = TLO.DAG.getShiftAmountConstant(ShlAmt, VT, dl);
2887 SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
2888 SDValue Res = TLO.DAG.getNode(NT, dl, VT, Y, Shl);
2889 return TLO.CombineTo(Op, Res);
2890 };
2891
2892 if (isOperationLegalOrCustom(ISD::SHL, VT)) {
2893 if (Op.getOpcode() == ISD::ADD) {
2894 // (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
2895 if (unsigned ShAmt = getShiftLeftAmt(Op0))
2896 return foldMul(ISD::SUB, Op0.getOperand(0), Op1, ShAmt);
2897 // Op0 + (X * MulC) --> Op0 - (X << log2(-MulC))
2898 if (unsigned ShAmt = getShiftLeftAmt(Op1))
2899 return foldMul(ISD::SUB, Op1.getOperand(0), Op0, ShAmt);
2900 }
2901 if (Op.getOpcode() == ISD::SUB) {
2902 // Op0 - (X * MulC) --> Op0 + (X << log2(-MulC))
2903 if (unsigned ShAmt = getShiftLeftAmt(Op1))
2904 return foldMul(ISD::ADD, Op1.getOperand(0), Op0, ShAmt);
2905 }
2906 }
2907
2908 if (Op.getOpcode() == ISD::MUL) {
2909 Known = KnownBits::mul(KnownOp0, KnownOp1);
2910 } else { // Op.getOpcode() is either ISD::ADD or ISD::SUB.
2911 Known = KnownBits::computeForAddSub(
2912 Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
2913 Flags.hasNoUnsignedWrap(), KnownOp0, KnownOp1);
2914 }
2915 break;
2916 }
2917 default:
2918 // We also ask the target about intrinsics (which could be specific to it).
2919 if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
2920 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
2921 // TODO: Probably okay to remove after audit; here to reduce change size
2922 // in initial enablement patch for scalable vectors
2923 if (Op.getValueType().isScalableVector())
2924 break;
2925 if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
2926 Known, TLO, Depth))
2927 return true;
2928 break;
2929 }
2930
2931 // Just use computeKnownBits to compute output bits.
2932 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2933 break;
2934 }
2935
2936 // If we know the value of all of the demanded bits, return this as a
2937 // constant.
2938 if (!isTargetCanonicalConstantNode(Op) &&
2939 DemandedBits.isSubsetOf(Known.Zero | Known.One)) {
2940 // Avoid folding to a constant if any OpaqueConstant is involved.
2941 const SDNode *N = Op.getNode();
2942 for (SDNode *Op :
2943 llvm::make_range(SDNodeIterator::begin(N), SDNodeIterator::end(N))) {
2944 if (auto *C = dyn_cast<ConstantSDNode>(Op))
2945 if (C->isOpaque())
2946 return false;
2947 }
2948 if (VT.isInteger())
2949 return TLO.CombineTo(Op, TLO.DAG.getConstant(Known.One, dl, VT));
2950 if (VT.isFloatingPoint())
2951 return TLO.CombineTo(
2952 Op,
2953 TLO.DAG.getConstantFP(
2954 APFloat(TLO.DAG.EVTToAPFloatSemantics(VT), Known.One), dl, VT));
2955 }
2956
2957 // A multi-use 'all demanded elts' simplify failed to find any known bits.
2958 // Try again just for the original demanded elts.
2959 // Ensure we do this AFTER constant folding above.
2960 if (HasMultiUse && Known.isUnknown() && !OriginalDemandedElts.isAllOnes())
2961 Known = TLO.DAG.computeKnownBits(Op, OriginalDemandedElts, Depth);
2962
2963 return false;
2964}
2965
2966 bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
2967 const APInt &DemandedElts,
2968 DAGCombinerInfo &DCI) const {
2969 SelectionDAG &DAG = DCI.DAG;
2970 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
2971 !DCI.isBeforeLegalizeOps());
2972
2973 APInt KnownUndef, KnownZero;
2974 bool Simplified =
2975 SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
2976 if (Simplified) {
2977 DCI.AddToWorklist(Op.getNode());
2978 DCI.CommitTargetLoweringOpt(TLO);
2979 }
2980
2981 return Simplified;
2982}
2983
2984/// Given a vector binary operation and known undefined elements for each input
2985/// operand, compute whether each element of the output is undefined.
2986 static APInt getKnownUndefForVectorBinop(SDValue BO, SelectionDAG &DAG,
2987 const APInt &UndefOp0,
2988 const APInt &UndefOp1) {
2989 EVT VT = BO.getValueType();
2991 "Vector binop only");
2992
2993 EVT EltVT = VT.getVectorElementType();
2994 unsigned NumElts = VT.isFixedLengthVector() ? VT.getVectorNumElements() : 1;
2995 assert(UndefOp0.getBitWidth() == NumElts &&
2996 UndefOp1.getBitWidth() == NumElts && "Bad type for undef analysis");
2997
2998 auto getUndefOrConstantElt = [&](SDValue V, unsigned Index,
2999 const APInt &UndefVals) {
3000 if (UndefVals[Index])
3001 return DAG.getUNDEF(EltVT);
3002
3003 if (auto *BV = dyn_cast<BuildVectorSDNode>(V)) {
3004 // Try hard to make sure that the getNode() call is not creating temporary
3005 // nodes. Ignore opaque integers because they do not constant fold.
3006 SDValue Elt = BV->getOperand(Index);
3007 auto *C = dyn_cast<ConstantSDNode>(Elt);
3008 if (isa<ConstantFPSDNode>(Elt) || Elt.isUndef() || (C && !C->isOpaque()))
3009 return Elt;
3010 }
3011
3012 return SDValue();
3013 };
3014
3015 APInt KnownUndef = APInt::getZero(NumElts);
3016 for (unsigned i = 0; i != NumElts; ++i) {
3017 // If both inputs for this element are either constant or undef and match
3018 // the element type, compute the constant/undef result for this element of
3019 // the vector.
3020 // TODO: Ideally we would use FoldConstantArithmetic() here, but that does
3021 // not handle FP constants. The code within getNode() should be refactored
3022 // to avoid the danger of creating a bogus temporary node here.
3023 SDValue C0 = getUndefOrConstantElt(BO.getOperand(0), i, UndefOp0);
3024 SDValue C1 = getUndefOrConstantElt(BO.getOperand(1), i, UndefOp1);
3025 if (C0 && C1 && C0.getValueType() == EltVT && C1.getValueType() == EltVT)
3026 if (DAG.getNode(BO.getOpcode(), SDLoc(BO), EltVT, C0, C1).isUndef())
3027 KnownUndef.setBit(i);
3028 }
3029 return KnownUndef;
3030}
3031
3032 bool TargetLowering::SimplifyDemandedVectorElts(
3033 SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
3034 APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
3035 bool AssumeSingleUse) const {
3036 EVT VT = Op.getValueType();
3037 unsigned Opcode = Op.getOpcode();
3038 APInt DemandedElts = OriginalDemandedElts;
3039 unsigned NumElts = DemandedElts.getBitWidth();
3040 assert(VT.isVector() && "Expected vector op");
3041
3042 KnownUndef = KnownZero = APInt::getZero(NumElts);
3043
3044 const TargetLowering &TLI = TLO.DAG.getTargetLoweringInfo();
3045 if (!TLI.shouldSimplifyDemandedVectorElts(Op, TLO))
3046 return false;
3047
3048 // TODO: For now we assume we know nothing about scalable vectors.
3049 if (VT.isScalableVector())
3050 return false;
3051
3052 assert(VT.getVectorNumElements() == NumElts &&
3053 "Mask size mismatches value type element count!");
3054
3055 // Undef operand.
3056 if (Op.isUndef()) {
3057 KnownUndef.setAllBits();
3058 return false;
3059 }
3060
3061 // If Op has other users, assume that all elements are needed.
3062 if (!AssumeSingleUse && !Op.getNode()->hasOneUse())
3063 DemandedElts.setAllBits();
3064
3065 // Not demanding any elements from Op.
3066 if (DemandedElts == 0) {
3067 KnownUndef.setAllBits();
3068 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3069 }
3070
3071 // Limit search depth.
3072 if (Depth >= SelectionDAG::MaxRecursionDepth)
3073 return false;
3074
3075 SDLoc DL(Op);
3076 unsigned EltSizeInBits = VT.getScalarSizeInBits();
3077 bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
3078
3079 // Helper for demanding the specified elements and all the bits of both binary
3080 // operands.
3081 auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
3082 SDValue NewOp0 = SimplifyMultipleUseDemandedVectorElts(Op0, DemandedElts,
3083 TLO.DAG, Depth + 1);
3084 SDValue NewOp1 = SimplifyMultipleUseDemandedVectorElts(Op1, DemandedElts,
3085 TLO.DAG, Depth + 1);
3086 if (NewOp0 || NewOp1) {
3087 SDValue NewOp =
3088 TLO.DAG.getNode(Opcode, SDLoc(Op), VT, NewOp0 ? NewOp0 : Op0,
3089 NewOp1 ? NewOp1 : Op1, Op->getFlags());
3090 return TLO.CombineTo(Op, NewOp);
3091 }
3092 return false;
3093 };
3094
3095 switch (Opcode) {
3096 case ISD::SCALAR_TO_VECTOR: {
3097 if (!DemandedElts[0]) {
3098 KnownUndef.setAllBits();
3099 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3100 }
3101 SDValue ScalarSrc = Op.getOperand(0);
3102 if (ScalarSrc.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
3103 SDValue Src = ScalarSrc.getOperand(0);
3104 SDValue Idx = ScalarSrc.getOperand(1);
3105 EVT SrcVT = Src.getValueType();
3106
3107 ElementCount SrcEltCnt = SrcVT.getVectorElementCount();
3108
3109 if (SrcEltCnt.isScalable())
3110 return false;
3111
3112 unsigned NumSrcElts = SrcEltCnt.getFixedValue();
3113 if (isNullConstant(Idx)) {
3114 APInt SrcDemandedElts = APInt::getOneBitSet(NumSrcElts, 0);
3115 APInt SrcUndef = KnownUndef.zextOrTrunc(NumSrcElts);
3116 APInt SrcZero = KnownZero.zextOrTrunc(NumSrcElts);
3117 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3118 TLO, Depth + 1))
3119 return true;
3120 }
3121 }
3122 KnownUndef.setHighBits(NumElts - 1);
3123 break;
3124 }
3125 case ISD::BITCAST: {
3126 SDValue Src = Op.getOperand(0);
3127 EVT SrcVT = Src.getValueType();
3128
3129 // We only handle vectors here.
3130 // TODO - investigate calling SimplifyDemandedBits/ComputeKnownBits?
3131 if (!SrcVT.isVector())
3132 break;
3133
3134 // Fast handling of 'identity' bitcasts.
3135 unsigned NumSrcElts = SrcVT.getVectorNumElements();
3136 if (NumSrcElts == NumElts)
3137 return SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef,
3138 KnownZero, TLO, Depth + 1);
3139
3140 APInt SrcDemandedElts, SrcZero, SrcUndef;
3141
3142 // Bitcast from 'large element' src vector to 'small element' vector, we
3143 // must demand a source element if any DemandedElt maps to it.
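 // For example, a v2i64 -> v4i32 bitcast has Scale = 2, so demanding result
 // elements {2,3} demands source element 1.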
3144 if ((NumElts % NumSrcElts) == 0) {
3145 unsigned Scale = NumElts / NumSrcElts;
3146 SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
3147 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3148 TLO, Depth + 1))
3149 return true;
3150
3151 // Try calling SimplifyDemandedBits, converting demanded elts to the bits
3152 // of the large element.
3153 // TODO - bigendian once we have test coverage.
3154 if (IsLE) {
3155 unsigned SrcEltSizeInBits = SrcVT.getScalarSizeInBits();
3156 APInt SrcDemandedBits = APInt::getZero(SrcEltSizeInBits);
3157 for (unsigned i = 0; i != NumElts; ++i)
3158 if (DemandedElts[i]) {
3159 unsigned Ofs = (i % Scale) * EltSizeInBits;
3160 SrcDemandedBits.setBits(Ofs, Ofs + EltSizeInBits);
3161 }
3162
3163 KnownBits Known;
3164 if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcDemandedElts, Known,
3165 TLO, Depth + 1))
3166 return true;
3167
3168 // The bitcast has split each wide element into a number of
3169 // narrow subelements. We have just computed the Known bits
3170 // for wide elements. See if element splitting results in
3171 // some subelements being zero. Only for demanded elements!
3172 for (unsigned SubElt = 0; SubElt != Scale; ++SubElt) {
3173 if (!Known.Zero.extractBits(EltSizeInBits, SubElt * EltSizeInBits)
3174 .isAllOnes())
3175 continue;
3176 for (unsigned SrcElt = 0; SrcElt != NumSrcElts; ++SrcElt) {
3177 unsigned Elt = Scale * SrcElt + SubElt;
3178 if (DemandedElts[Elt])
3179 KnownZero.setBit(Elt);
3180 }
3181 }
3182 }
3183
3184 // If the src element is zero/undef then all the output elements will be too -
3185 // only demanded elements are guaranteed to be correct.
3186 for (unsigned i = 0; i != NumSrcElts; ++i) {
3187 if (SrcDemandedElts[i]) {
3188 if (SrcZero[i])
3189 KnownZero.setBits(i * Scale, (i + 1) * Scale);
3190 if (SrcUndef[i])
3191 KnownUndef.setBits(i * Scale, (i + 1) * Scale);
3192 }
3193 }
3194 }
3195
3196 // Bitcast from 'small element' src vector to 'large element' vector, we
3197 // demand all smaller source elements covered by the larger demanded element
3198 // of this vector.
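 // For example, a v4i32 -> v2i64 bitcast has Scale = 2, so demanding result
 // element 1 demands source elements {2,3}.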
3199 if ((NumSrcElts % NumElts) == 0) {
3200 unsigned Scale = NumSrcElts / NumElts;
3201 SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
3202 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3203 TLO, Depth + 1))
3204 return true;
3205
3206 // If all the src elements covering an output element are zero/undef, then
3207 // the output element will be as well, assuming it was demanded.
3208 for (unsigned i = 0; i != NumElts; ++i) {
3209 if (DemandedElts[i]) {
3210 if (SrcZero.extractBits(Scale, i * Scale).isAllOnes())
3211 KnownZero.setBit(i);
3212 if (SrcUndef.extractBits(Scale, i * Scale).isAllOnes())
3213 KnownUndef.setBit(i);
3214 }
3215 }
3216 }
3217 break;
3218 }
3219 case ISD::FREEZE: {
3220 SDValue N0 = Op.getOperand(0);
3221 if (TLO.DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
3222 /*PoisonOnly=*/false))
3223 return TLO.CombineTo(Op, N0);
3224
3225 // TODO: Replace this with the general fold from DAGCombiner::visitFREEZE
3226 // freeze(op(x, ...)) -> op(freeze(x), ...).
3227 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && DemandedElts == 1)
3228 return TLO.CombineTo(
3229 Op, TLO.DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT,
3230 TLO.DAG.getFreeze(N0.getOperand(0))));
3231 break;
3232 }
3233 case ISD::BUILD_VECTOR: {
3234 // Check all elements and simplify any unused elements with UNDEF.
3235 if (!DemandedElts.isAllOnes()) {
3236 // Don't simplify BROADCASTS.
3237 if (llvm::any_of(Op->op_values(),
3238 [&](SDValue Elt) { return Op.getOperand(0) != Elt; })) {
3239 SmallVector<SDValue, 32> Ops(Op->op_begin(), Op->op_end());
3240 bool Updated = false;
3241 for (unsigned i = 0; i != NumElts; ++i) {
3242 if (!DemandedElts[i] && !Ops[i].isUndef()) {
3243 Ops[i] = TLO.DAG.getUNDEF(Ops[0].getValueType());
3244 KnownUndef.setBit(i);
3245 Updated = true;
3246 }
3247 }
3248 if (Updated)
3249 return TLO.CombineTo(Op, TLO.DAG.getBuildVector(VT, DL, Ops));
3250 }
3251 }
3252 for (unsigned i = 0; i != NumElts; ++i) {
3253 SDValue SrcOp = Op.getOperand(i);
3254 if (SrcOp.isUndef()) {
3255 KnownUndef.setBit(i);
3256 } else if (EltSizeInBits == SrcOp.getScalarValueSizeInBits() &&
3257 (isNullConstant(SrcOp) || isNullFPConstant(SrcOp))) {
3258 KnownZero.setBit(i);
3259 }
3260 }
3261 break;
3262 }
3263 case ISD::CONCAT_VECTORS: {
3264 EVT SubVT = Op.getOperand(0).getValueType();
3265 unsigned NumSubVecs = Op.getNumOperands();
3266 unsigned NumSubElts = SubVT.getVectorNumElements();
3267 for (unsigned i = 0; i != NumSubVecs; ++i) {
3268 SDValue SubOp = Op.getOperand(i);
3269 APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
3270 APInt SubUndef, SubZero;
3271 if (SimplifyDemandedVectorElts(SubOp, SubElts, SubUndef, SubZero, TLO,
3272 Depth + 1))
3273 return true;
3274 KnownUndef.insertBits(SubUndef, i * NumSubElts);
3275 KnownZero.insertBits(SubZero, i * NumSubElts);
3276 }
3277
3278 // Attempt to avoid multi-use ops if we don't need anything from them.
3279 if (!DemandedElts.isAllOnes()) {
3280 bool FoundNewSub = false;
3281 SmallVector<SDValue, 2> DemandedSubOps;
3282 for (unsigned i = 0; i != NumSubVecs; ++i) {
3283 SDValue SubOp = Op.getOperand(i);
3284 APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
3285 SDValue NewSubOp = SimplifyMultipleUseDemandedVectorElts(
3286 SubOp, SubElts, TLO.DAG, Depth + 1);
3287 DemandedSubOps.push_back(NewSubOp ? NewSubOp : SubOp);
3288 FoundNewSub = NewSubOp ? true : FoundNewSub;
3289 }
3290 if (FoundNewSub) {
3291 SDValue NewOp =
3292 TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, DemandedSubOps);
3293 return TLO.CombineTo(Op, NewOp);
3294 }
3295 }
3296 break;
3297 }
3298 case ISD::INSERT_SUBVECTOR: {
3299 // Demand any elements from the subvector and the remainder from the src it's
3300 // inserted into.
3301 SDValue Src = Op.getOperand(0);
3302 SDValue Sub = Op.getOperand(1);
3303 uint64_t Idx = Op.getConstantOperandVal(2);
3304 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
3305 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
3306 APInt DemandedSrcElts = DemandedElts;
3307 DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
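 // For example, inserting a v2i32 at Idx = 2 into a v8i32: the subvector sees
 // demanded lanes {2,3} of the original mask, and the source keeps the mask
 // with those two lanes cleared.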
3308
3309 APInt SubUndef, SubZero;
3310 if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,
3311 Depth + 1))
3312 return true;
3313
3314 // If none of the src operand elements are demanded, replace it with undef.
3315 if (!DemandedSrcElts && !Src.isUndef())
3316 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
3317 TLO.DAG.getUNDEF(VT), Sub,
3318 Op.getOperand(2)));
3319
3320 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownUndef, KnownZero,
3321 TLO, Depth + 1))
3322 return true;
3323 KnownUndef.insertBits(SubUndef, Idx);
3324 KnownZero.insertBits(SubZero, Idx);
3325
3326 // Attempt to avoid multi-use ops if we don't need anything from them.
3327 if (!DemandedSrcElts.isAllOnes() || !DemandedSubElts.isAllOnes()) {
3328 SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
3329 Src, DemandedSrcElts, TLO.DAG, Depth + 1);
3330 SDValue NewSub = SimplifyMultipleUseDemandedVectorElts(
3331 Sub, DemandedSubElts, TLO.DAG, Depth + 1);
3332 if (NewSrc || NewSub) {
3333 NewSrc = NewSrc ? NewSrc : Src;
3334 NewSub = NewSub ? NewSub : Sub;
3335 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
3336 NewSub, Op.getOperand(2));
3337 return TLO.CombineTo(Op, NewOp);
3338 }
3339 }
3340 break;
3341 }
3342 case ISD::EXTRACT_SUBVECTOR: {
3343 // Offset the demanded elts by the subvector index.
3344 SDValue Src = Op.getOperand(0);
3345 if (Src.getValueType().isScalableVector())
3346 break;
3347 uint64_t Idx = Op.getConstantOperandVal(1);
3348 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
3349 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
3350
3351 APInt SrcUndef, SrcZero;
3352 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
3353 Depth + 1))
3354 return true;
3355 KnownUndef = SrcUndef.extractBits(NumElts, Idx);
3356 KnownZero = SrcZero.extractBits(NumElts, Idx);
3357
3358 // Attempt to avoid multi-use ops if we don't need anything from them.
3359 if (!DemandedElts.isAllOnes()) {
3360 SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
3361 Src, DemandedSrcElts, TLO.DAG, Depth + 1);
3362 if (NewSrc) {
3363 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
3364 Op.getOperand(1));
3365 return TLO.CombineTo(Op, NewOp);
3366 }
3367 }
3368 break;
3369 }
3370 case ISD::INSERT_VECTOR_ELT: {
3371 SDValue Vec = Op.getOperand(0);
3372 SDValue Scl = Op.getOperand(1);
3373 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
3374
3375 // For a legal, constant insertion index, if we don't need this insertion
3376 // then strip it, else remove it from the demanded elts.
3377 if (CIdx && CIdx->getAPIntValue().ult(NumElts)) {
3378 unsigned Idx = CIdx->getZExtValue();
3379 if (!DemandedElts[Idx])
3380 return TLO.CombineTo(Op, Vec);
3381
3382 APInt DemandedVecElts(DemandedElts);
3383 DemandedVecElts.clearBit(Idx);
3384 if (SimplifyDemandedVectorElts(Vec, DemandedVecElts, KnownUndef,
3385 KnownZero, TLO, Depth + 1))
3386 return true;
3387
3388 KnownUndef.setBitVal(Idx, Scl.isUndef());
3389
3390 KnownZero.setBitVal(Idx, isNullConstant(Scl) || isNullFPConstant(Scl));
3391 break;
3392 }
3393
3394 APInt VecUndef, VecZero;
3395 if (SimplifyDemandedVectorElts(Vec, DemandedElts, VecUndef, VecZero, TLO,
3396 Depth + 1))
3397 return true;
3398 // Without knowing the insertion index we can't set KnownUndef/KnownZero.
3399 break;
3400 }
3401 case ISD::VSELECT: {
3402 SDValue Sel = Op.getOperand(0);
3403 SDValue LHS = Op.getOperand(1);
3404 SDValue RHS = Op.getOperand(2);
3405
3406 // Try to transform the select condition based on the current demanded
3407 // elements.
3408 APInt UndefSel, ZeroSel;
3409 if (SimplifyDemandedVectorElts(Sel, DemandedElts, UndefSel, ZeroSel, TLO,
3410 Depth + 1))
3411 return true;
3412
3413 // See if we can simplify either vselect operand.
3414 APInt DemandedLHS(DemandedElts);
3415 APInt DemandedRHS(DemandedElts);
3416 APInt UndefLHS, ZeroLHS;
3417 APInt UndefRHS, ZeroRHS;
3418 if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
3419 Depth + 1))
3420 return true;
3421 if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
3422 Depth + 1))
3423 return true;
3424
3425 KnownUndef = UndefLHS & UndefRHS;
3426 KnownZero = ZeroLHS & ZeroRHS;
3427
3428 // If we know that the selected element is always zero, we don't need the
3429 // select value element.
3430 APInt DemandedSel = DemandedElts & ~KnownZero;
3431 if (DemandedSel != DemandedElts)
3432 if (SimplifyDemandedVectorElts(Sel, DemandedSel, UndefSel, ZeroSel, TLO,
3433 Depth + 1))
3434 return true;
3435
3436 break;
3437 }
3438 case ISD::VECTOR_SHUFFLE: {
3439 SDValue LHS = Op.getOperand(0);
3440 SDValue RHS = Op.getOperand(1);
3441 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
3442
3443 // Collect demanded elements from shuffle operands.
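 // For example, for a v4 shuffle with mask <0,5,2,7>, demanding result
 // elements {0,1} demands LHS element 0 and RHS element 1.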
3444 APInt DemandedLHS(NumElts, 0);
3445 APInt DemandedRHS(NumElts, 0);
3446 for (unsigned i = 0; i != NumElts; ++i) {
3447 int M = ShuffleMask[i];
3448 if (M < 0 || !DemandedElts[i])
3449 continue;
3450 assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range");
3451 if (M < (int)NumElts)
3452 DemandedLHS.setBit(M);
3453 else
3454 DemandedRHS.setBit(M - NumElts);
3455 }
3456
3457 // See if we can simplify either shuffle operand.
3458 APInt UndefLHS, ZeroLHS;
3459 APInt UndefRHS, ZeroRHS;
3460 if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
3461 Depth + 1))
3462 return true;
3463 if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
3464 Depth + 1))
3465 return true;
3466
3467 // Simplify mask using undef elements from LHS/RHS.
3468 bool Updated = false;
3469 bool IdentityLHS = true, IdentityRHS = true;
3470 SmallVector<int, 32> NewMask(ShuffleMask);
3471 for (unsigned i = 0; i != NumElts; ++i) {
3472 int &M = NewMask[i];
3473 if (M < 0)
3474 continue;
3475 if (!DemandedElts[i] || (M < (int)NumElts && UndefLHS[M]) ||
3476 (M >= (int)NumElts && UndefRHS[M - NumElts])) {
3477 Updated = true;
3478 M = -1;
3479 }
3480 IdentityLHS &= (M < 0) || (M == (int)i);
3481 IdentityRHS &= (M < 0) || ((M - NumElts) == i);
3482 }
3483
3484 // Update legal shuffle masks based on demanded elements if it won't reduce
3485 // to Identity which can cause premature removal of the shuffle mask.
3486 if (Updated && !IdentityLHS && !IdentityRHS && !TLO.LegalOps) {
3487 SDValue LegalShuffle =
3488 buildLegalVectorShuffle(VT, DL, LHS, RHS, NewMask, TLO.DAG);
3489 if (LegalShuffle)
3490 return TLO.CombineTo(Op, LegalShuffle);
3491 }
3492
3493 // Propagate undef/zero elements from LHS/RHS.
3494 for (unsigned i = 0; i != NumElts; ++i) {
3495 int M = ShuffleMask[i];
3496 if (M < 0) {
3497 KnownUndef.setBit(i);
3498 } else if (M < (int)NumElts) {
3499 if (UndefLHS[M])
3500 KnownUndef.setBit(i);
3501 if (ZeroLHS[M])
3502 KnownZero.setBit(i);
3503 } else {
3504 if (UndefRHS[M - NumElts])
3505 KnownUndef.setBit(i);
3506 if (ZeroRHS[M - NumElts])
3507 KnownZero.setBit(i);
3508 }
3509 }
3510 break;
3511 }
3512 case ISD::ANY_EXTEND_VECTOR_INREG:
3513 case ISD::SIGN_EXTEND_VECTOR_INREG:
3514 case ISD::ZERO_EXTEND_VECTOR_INREG: {
3515 APInt SrcUndef, SrcZero;
3516 SDValue Src = Op.getOperand(0);
3517 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
3518 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts);
3519 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
3520 Depth + 1))
3521 return true;
3522 KnownZero = SrcZero.zextOrTrunc(NumElts);
3523 KnownUndef = SrcUndef.zextOrTrunc(NumElts);
3524
3525 if (IsLE && Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG &&
3526 Op.getValueSizeInBits() == Src.getValueSizeInBits() &&
3527 DemandedSrcElts == 1) {
3528 // aext - if we just need the bottom element then we can bitcast.
3529 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
3530 }
3531
3532 if (Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
3533 // zext(undef) upper bits are guaranteed to be zero.
3534 if (DemandedElts.isSubsetOf(KnownUndef))
3535 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3536 KnownUndef.clearAllBits();
3537
3538 // zext - if we just need the bottom element then we can mask:
3539 // zext(and(x,c)) -> and(x,c') iff the zext is the only user of the and.
3540 if (IsLE && DemandedSrcElts == 1 && Src.getOpcode() == ISD::AND &&
3541 Op->isOnlyUserOf(Src.getNode()) &&
3542 Op.getValueSizeInBits() == Src.getValueSizeInBits()) {
3543 SDLoc DL(Op);
3544 EVT SrcVT = Src.getValueType();
3545 EVT SrcSVT = SrcVT.getScalarType();
3546 SmallVector<SDValue> MaskElts;
3547 MaskElts.push_back(TLO.DAG.getAllOnesConstant(DL, SrcSVT));
3548 MaskElts.append(NumSrcElts - 1, TLO.DAG.getConstant(0, DL, SrcSVT));
3549 SDValue Mask = TLO.DAG.getBuildVector(SrcVT, DL, MaskElts);
3550 if (SDValue Fold = TLO.DAG.FoldConstantArithmetic(
3551 ISD::AND, DL, SrcVT, {Src.getOperand(1), Mask})) {
3552 Fold = TLO.DAG.getNode(ISD::AND, DL, SrcVT, Src.getOperand(0), Fold);
3553 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Fold));
3554 }
3555 }
3556 }
3557 break;
3558 }
3559
3560 // TODO: There are more binop opcodes that could be handled here - MIN,
3561 // MAX, saturated math, etc.
3562 case ISD::ADD: {
3563 SDValue Op0 = Op.getOperand(0);
3564 SDValue Op1 = Op.getOperand(1);
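 // If both operands are the same single-use value, it can be simplified as if
 // this add were its only user.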
3565 if (Op0 == Op1 && Op->isOnlyUserOf(Op0.getNode())) {
3566 APInt UndefLHS, ZeroLHS;
3567 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3568 Depth + 1, /*AssumeSingleUse*/ true))
3569 return true;
3570 }
3571 [[fallthrough]];
3572 }
3573 case ISD::AVGCEILS:
3574 case ISD::AVGCEILU:
3575 case ISD::AVGFLOORS:
3576 case ISD::AVGFLOORU:
3577 case ISD::OR:
3578 case ISD::XOR:
3579 case ISD::SUB:
3580 case ISD::FADD:
3581 case ISD::FSUB:
3582 case ISD::FMUL:
3583 case ISD::FDIV:
3584 case ISD::FREM: {
3585 SDValue Op0 = Op.getOperand(0);
3586 SDValue Op1 = Op.getOperand(1);
3587
3588 APInt UndefRHS, ZeroRHS;
3589 if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
3590 Depth + 1))
3591 return true;
3592 APInt UndefLHS, ZeroLHS;
3593 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3594 Depth + 1))
3595 return true;
3596
3597 KnownZero = ZeroLHS & ZeroRHS;
3598 KnownUndef = getKnownUndefForVectorBinop(Op, TLO.DAG, UndefLHS, UndefRHS);
3599
3600 // Attempt to avoid multi-use ops if we don't need anything from them.
3601 // TODO - use KnownUndef to relax the demandedelts?
3602 if (!DemandedElts.isAllOnes())
3603 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3604 return true;
3605 break;
3606 }
3607 case ISD::SHL:
3608 case ISD::SRL:
3609 case ISD::SRA:
3610 case ISD::ROTL:
3611 case ISD::ROTR: {
3612 SDValue Op0 = Op.getOperand(0);
3613 SDValue Op1 = Op.getOperand(1);
3614
3615 APInt UndefRHS, ZeroRHS;
3616 if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
3617 Depth + 1))
3618 return true;
3619 APInt UndefLHS, ZeroLHS;
3620 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3621 Depth + 1))
3622 return true;
3623
3624 KnownZero = ZeroLHS;
3625 KnownUndef = UndefLHS & UndefRHS; // TODO: use getKnownUndefForVectorBinop?
3626
3627 // Attempt to avoid multi-use ops if we don't need anything from them.
3628 // TODO - use KnownUndef to relax the demandedelts?
3629 if (!DemandedElts.isAllOnes())
3630 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3631 return true;
3632 break;
3633 }
3634 case ISD::MUL:
3635 case ISD::MULHU:
3636 case ISD::MULHS:
3637 case ISD::AND: {
3638 SDValue Op0 = Op.getOperand(0);
3639 SDValue Op1 = Op.getOperand(1);
3640
3641 APInt SrcUndef, SrcZero;
3642 if (SimplifyDemandedVectorElts(Op1, DemandedElts, SrcUndef, SrcZero, TLO,
3643 Depth + 1))
3644 return true;
3645 // If we know that a demanded element was zero in Op1 we don't need to
3646 // demand it in Op0 - it's guaranteed to be zero.
3647 APInt DemandedElts0 = DemandedElts & ~SrcZero;
3648 if (SimplifyDemandedVectorElts(Op0, DemandedElts0, KnownUndef, KnownZero,
3649 TLO, Depth + 1))
3650 return true;
3651
3652 KnownUndef &= DemandedElts0;
3653 KnownZero &= DemandedElts0;
3654
3655 // If every element pair has a zero/undef then just fold to zero.
3656 // fold (and x, undef) -> 0 / (and x, 0) -> 0
3657 // fold (mul x, undef) -> 0 / (mul x, 0) -> 0
3658 if (DemandedElts.isSubsetOf(SrcZero | KnownZero | SrcUndef | KnownUndef))
3659 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3660
3661 // If either side has a zero element, then the result element is zero, even
3662 // if the other is an UNDEF.
3663 // TODO: Extend getKnownUndefForVectorBinop to also deal with known zeros
3664 // and then handle 'and' nodes with the rest of the binop opcodes.
3665 KnownZero |= SrcZero;
3666 KnownUndef &= SrcUndef;
3667 KnownUndef &= ~KnownZero;
3668
3669 // Attempt to avoid multi-use ops if we don't need anything from them.
3670 if (!DemandedElts.isAllOnes())
3671 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3672 return true;
3673 break;
3674 }
3675 case ISD::TRUNCATE:
3676 case ISD::SIGN_EXTEND:
3677 case ISD::ZERO_EXTEND:
3678 if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef,
3679 KnownZero, TLO, Depth + 1))
3680 return true;
3681
3682 if (Op.getOpcode() == ISD::ZERO_EXTEND) {
3683 // zext(undef) upper bits are guaranteed to be zero.
3684 if (DemandedElts.isSubsetOf(KnownUndef))
3685 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3686 KnownUndef.clearAllBits();
3687 }
3688 break;
3689 default: {
3690 if (Op.getOpcode() >= ISD::BUILTIN_OP_END) {
3691 if (SimplifyDemandedVectorEltsForTargetNode(Op, DemandedElts, KnownUndef,
3692 KnownZero, TLO, Depth))
3693 return true;
3694 } else {
3695 KnownBits Known;
3696 APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
3697 if (SimplifyDemandedBits(Op, DemandedBits, OriginalDemandedElts, Known,
3698 TLO, Depth, AssumeSingleUse))
3699 return true;
3700 }
3701 break;
3702 }
3703 }
3704 assert((KnownUndef & KnownZero) == 0 && "Elements flagged as undef AND zero");
3705
3706 // Constant fold all undef cases.
3707 // TODO: Handle zero cases as well.
3708 if (DemandedElts.isSubsetOf(KnownUndef))
3709 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3710
3711 return false;
3712}
3713
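// Illustrative sketch (not part of this file): a target DAG combine would
// typically drive the SimplifyDemandedVectorElts entry point above through the
// DAGCombinerInfo overload declared in TargetLowering.h. MyTargetLowering and
// combineFoo are hypothetical names used only for this example.
//
//   SDValue MyTargetLowering::combineFoo(SDNode *N, DAGCombinerInfo &DCI) const {
//     EVT VT = N->getValueType(0);
//     if (!VT.isFixedLengthVector())
//       return SDValue();
//     // Ask about every lane; the callee may still rewrite operands in place
//     // via TLO.CombineTo when it proves lanes undef, zero, or unused.
//     APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
//     APInt KnownUndef, KnownZero;
//     if (SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, KnownUndef,
//                                    KnownZero, DCI))
//       return SDValue(N, 0); // The DAG was updated in place.
//     return SDValue();
//   }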
3714/// Determine which of the bits specified in Mask are known to be either zero or
3715/// one and return them in the Known.
3716void TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3717 KnownBits &Known,
3718 const APInt &DemandedElts,
3719 const SelectionDAG &DAG,
3720 unsigned Depth) const {
3721 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3722 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3723 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3724 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3725 "Should use MaskedValueIsZero if you don't know whether Op"
3726 " is a target node!");
3727 Known.resetAll();
3728}
3729
3730void TargetLowering::computeKnownBitsForTargetInstr(
3731 GISelKnownBits &Analysis, Register R, KnownBits &Known,
3732 const APInt &DemandedElts, const MachineRegisterInfo &MRI,
3733 unsigned Depth) const {
3734 Known.resetAll();
3735}
3736
3737void TargetLowering::computeKnownBitsForFrameIndex(
3738 const int FrameIdx, KnownBits &Known, const MachineFunction &MF) const {
3739 // The low bits are known zero if the pointer is aligned.
3740 Known.Zero.setLowBits(Log2(MF.getFrameInfo().getObjectAlign(FrameIdx)));
3741}
3742
3743Align TargetLowering::computeKnownAlignForTargetInstr(
3744 GISelKnownBits &Analysis, Register R, const MachineRegisterInfo &MRI,
3745 unsigned Depth) const {
3746 return Align(1);
3747}
3748
3749/// This method can be implemented by targets that want to expose additional
3750/// information about sign bits to the DAG Combiner.
3751unsigned TargetLowering::ComputeNumSignBitsForTargetNode(SDValue Op,
3752 const APInt &,
3753 const SelectionDAG &,
3754 unsigned Depth) const {
3755 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3756 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3757 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3758 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3759 "Should use ComputeNumSignBits if you don't know whether Op"
3760 " is a target node!");
3761 return 1;
3762}
3763
3764unsigned TargetLowering::computeNumSignBitsForTargetInstr(
3765 GISelKnownBits &Analysis, Register R, const APInt &DemandedElts,
3766 const MachineRegisterInfo &MRI, unsigned Depth) const {
3767 return 1;
3768}
3769
3770bool TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
3771 SDValue Op, const APInt &DemandedElts, APInt &KnownUndef, APInt &KnownZero,
3772 TargetLoweringOpt &TLO, unsigned Depth) const {
3773 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3774 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3775 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3776 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3777 "Should use SimplifyDemandedVectorElts if you don't know whether Op"
3778 " is a target node!");
3779 return false;
3780}
3781
3782bool TargetLowering::SimplifyDemandedBitsForTargetNode(
3783 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3784 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
3785 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3786 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3787 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3788 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3789 "Should use SimplifyDemandedBits if you don't know whether Op"
3790 " is a target node!");
3791 computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
3792 return false;
3793}
3794
3795SDValue TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
3796 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3797 SelectionDAG &DAG, unsigned Depth) const {
3798 assert(
3799 (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3800 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3801 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3802 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3803 "Should use SimplifyMultipleUseDemandedBits if you don't know whether Op"
3804 " is a target node!");
3805 return SDValue();
3806}
3807
3808SDValue
3809TargetLowering::buildLegalVectorShuffle(EVT VT, const SDLoc &DL, SDValue N0,
3810 SDValue N1, MutableArrayRef<int> Mask,
3811 SelectionDAG &DAG) const {
3812 bool LegalMask = isShuffleMaskLegal(Mask, VT);
3813 if (!LegalMask) {
3814 std::swap(N0, N1);
3815 ShuffleVectorSDNode::commuteShuffleMask(Mask, Mask.size());
3816 LegalMask = isShuffleMaskLegal(Mask, VT);
3817 }
3818
3819 if (!LegalMask)
3820 return SDValue();
3821
3822 return DAG.getVectorShuffle(VT, DL, N0, N1, Mask);
3823}
3824
3825const Constant *TargetLowering::getTargetConstantFromLoad(LoadSDNode *) const {
3826 return nullptr;
3827}
3828
3829bool TargetLowering::isGuaranteedNotToBeUndefOrPoisonForTargetNode(
3830 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
3831 bool PoisonOnly, unsigned Depth) const {
3832 assert(
3833 (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3834 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3835 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3836 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3837 "Should use isGuaranteedNotToBeUndefOrPoison if you don't know whether Op"
3838 " is a target node!");
3839
3840 // If Op can't create undef/poison and none of its operands are undef/poison
3841 // then Op is never undef/poison.
3842 return !canCreateUndefOrPoisonForTargetNode(Op, DemandedElts, DAG, PoisonOnly,
3843 /*ConsiderFlags*/ true, Depth) &&
3844 all_of(Op->ops(), [&](SDValue V) {
3845 return DAG.isGuaranteedNotToBeUndefOrPoison(V, PoisonOnly,
3846 Depth + 1);
3847 });
3848}
3849
3850bool TargetLowering::canCreateUndefOrPoisonForTargetNode(
3851 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
3852 bool PoisonOnly, bool ConsiderFlags, unsigned Depth) const {
3853 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3854 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3855 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3856 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3857 "Should use canCreateUndefOrPoison if you don't know whether Op"
3858 " is a target node!");
3859 // Be conservative and return true.
3860 return true;
3861}
3862
3863bool TargetLowering::isKnownNeverNaNForTargetNode(SDValue Op,
3864 const SelectionDAG &DAG,
3865 bool SNaN,
3866 unsigned Depth) const {
3867 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3868 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3869 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3870 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3871 "Should use isKnownNeverNaN if you don't know whether Op"
3872 " is a target node!");
3873 return false;
3874}
3875
3876bool TargetLowering::isSplatValueForTargetNode(SDValue Op,
3877 const APInt &DemandedElts,
3878 APInt &UndefElts,
3879 const SelectionDAG &DAG,
3880 unsigned Depth) const {
3881 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3882 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3883 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3884 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3885 "Should use isSplatValue if you don't know whether Op"
3886 " is a target node!");
3887 return false;
3888}
3889
3890// FIXME: Ideally, this would use ISD::isConstantSplatVector(), but that must
3891// work with truncating build vectors and vectors with elements of less than
3892// 8 bits.
3893bool TargetLowering::isConstTrueVal(SDValue N) const {
3894 if (!N)
3895 return false;
3896
3897 unsigned EltWidth;
3898 APInt CVal;
3899 if (ConstantSDNode *CN = isConstOrConstSplat(N, /*AllowUndefs=*/false,
3900 /*AllowTruncation=*/true)) {
3901 CVal = CN->getAPIntValue();
3902 EltWidth = N.getValueType().getScalarSizeInBits();
3903 } else
3904 return false;
3905
3906 // If this is a truncating splat, truncate the splat value.
3907 // Otherwise, we may fail to match the expected values below.
3908 if (EltWidth < CVal.getBitWidth())
3909 CVal = CVal.trunc(EltWidth);
3910
3911 switch (getBooleanContents(N.getValueType())) {
3912 case UndefinedBooleanContent:
3913 return CVal[0];
3914 case ZeroOrOneBooleanContent:
3915 return CVal.isOne();
3916 case ZeroOrNegativeOneBooleanContent:
3917 return CVal.isAllOnes();
3918 }
3919
3920 llvm_unreachable("Invalid boolean contents");
3921}
3922
3923bool TargetLowering::isConstFalseVal(SDValue N) const {
3924 if (!N)
3925 return false;
3926
3927 const ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N);
3928 if (!CN) {
3929 const BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N);
3930 if (!BV)
3931 return false;
3932
3933 // Only interested in constant splats; we don't care about undef
3934 // elements when identifying boolean constants, and getConstantSplatNode
3935 // returns NULL if all ops are undef.
3936 CN = BV->getConstantSplatNode();
3937 if (!CN)
3938 return false;
3939 }
3940
3941 if (getBooleanContents(N->getValueType(0)) == UndefinedBooleanContent)
3942 return !CN->getAPIntValue()[0];
3943
3944 return CN->isZero();
3945}
3946
3947bool TargetLowering::isExtendedTrueVal(const ConstantSDNode *N, EVT VT,
3948 bool SExt) const {
3949 if (VT == MVT::i1)
3950 return N->isOne();
3951
3952 TargetLowering::BooleanContent Cnt = getBooleanContents(VT);
3953 switch (Cnt) {
3954 case TargetLowering::ZeroOrOneBooleanContent:
3955 // An extended value of 1 is always true, unless its original type is i1,
3956 // in which case it will be sign extended to -1.
3957 return (N->isOne() && !SExt) || (SExt && (N->getValueType(0) != MVT::i1));
3958 case TargetLowering::UndefinedBooleanContent:
3959 case TargetLowering::ZeroOrNegativeOneBooleanContent:
3960 return N->isAllOnes() && SExt;
3961 }
3962 llvm_unreachable("Unexpected enumeration.");
3963}
3964
3965/// This helper function of SimplifySetCC tries to optimize the comparison when
3966/// either operand of the SetCC node is a bitwise-and instruction.
3967SDValue TargetLowering::foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1,
3968 ISD::CondCode Cond, const SDLoc &DL,
3969 DAGCombinerInfo &DCI) const {
3970 if (N1.getOpcode() == ISD::AND && N0.getOpcode() != ISD::AND)
3971 std::swap(N0, N1);
3972
3973 SelectionDAG &DAG = DCI.DAG;
3974 EVT OpVT = N0.getValueType();
3975 if (N0.getOpcode() != ISD::AND || !OpVT.isInteger() ||
3976 (Cond != ISD::SETEQ && Cond != ISD::SETNE))
3977 return SDValue();
3978
3979 // (X & Y) != 0 --> zextOrTrunc(X & Y)
3980 // iff everything but LSB is known zero:
3981 if (Cond == ISD::SETNE && isNullConstant(N1) &&
3984 unsigned NumEltBits = OpVT.getScalarSizeInBits();
3985 APInt UpperBits = APInt::getHighBitsSet(NumEltBits, NumEltBits - 1);
3986 if (DAG.MaskedValueIsZero(N0, UpperBits))
3987 return DAG.getBoolExtOrTrunc(N0, DL, VT, OpVT);
3988 }
3989
3990 // Try to eliminate a power-of-2 mask constant by converting to a signbit
3991 // test in a narrow type that we can truncate to with no cost. Examples:
3992 // (i32 X & 32768) == 0 --> (trunc X to i16) >= 0
3993 // (i32 X & 32768) != 0 --> (trunc X to i16) < 0
3994 // TODO: This conservatively checks for type legality on the source and
3995 // destination types. That may inhibit optimizations, but it also
3996 // allows setcc->shift transforms that may be more beneficial.
3997 auto *AndC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
3998 if (AndC && isNullConstant(N1) && AndC->getAPIntValue().isPowerOf2() &&
3999 isTypeLegal(OpVT) && N0.hasOneUse()) {
4000 EVT NarrowVT = EVT::getIntegerVT(*DAG.getContext(),
4001 AndC->getAPIntValue().getActiveBits());
4002 if (isTruncateFree(OpVT, NarrowVT) && isTypeLegal(NarrowVT)) {
4003 SDValue Trunc = DAG.getZExtOrTrunc(N0.getOperand(0), DL, NarrowVT);
4004 SDValue Zero = DAG.getConstant(0, DL, NarrowVT);
4005 return DAG.getSetCC(DL, VT, Trunc, Zero,
4006 Cond == ISD::SETEQ ? ISD::SETGE : ISD::SETLT);
4007 }
4008 }
4009
4010 // Match these patterns in any of their permutations:
4011 // (X & Y) == Y
4012 // (X & Y) != Y
4013 SDValue X, Y;
4014 if (N0.getOperand(0) == N1) {
4015 X = N0.getOperand(1);
4016 Y = N0.getOperand(0);
4017 } else if (N0.getOperand(1) == N1) {
4018 X = N0.getOperand(0);
4019 Y = N0.getOperand(1);
4020 } else {
4021 return SDValue();
4022 }
4023
4024 // TODO: We should invert (X & Y) eq/ne 0 -> (X & Y) ne/eq Y if
4025 // `isXAndYEqZeroPreferableToXAndYEqY` is false. This is a bit difficult as
4026 // it's liable to create an infinite loop.
4027 SDValue Zero = DAG.getConstant(0, DL, OpVT);
4028 if (isXAndYEqZeroPreferableToXAndYEqY(Cond, OpVT) &&
4029 DAG.isKnownToBeAPowerOfTwo(Y)) {
4030 // Simplify X & Y == Y to X & Y != 0 if Y has exactly one bit set.
4031 // Note that where Y is variable and is known to have at most one bit set
4032 // (for example, if it is Z & 1) we cannot do this; the expressions are not
4033 // equivalent when Y == 0.
4034 assert(OpVT.isInteger());
4035 Cond = ISD::getSetCCInverse(Cond, OpVT);
4036 if (DCI.isBeforeLegalizeOps() ||
4037 isCondCodeLegal(Cond, N0.getSimpleValueType()))
4038 return DAG.getSetCC(DL, VT, N0, Zero, Cond);
4039 } else if (N0.hasOneUse() && hasAndNotCompare(Y)) {
4040 // If the target supports an 'and-not' or 'and-complement' logic operation,
4041 // try to use that to make a comparison operation more efficient.
4042 // But don't do this transform if the mask is a single bit because there are
4043 // more efficient ways to deal with that case (for example, 'bt' on x86 or
4044 // 'rlwinm' on PPC).
4045
4046 // Bail out if the compare operand that we want to turn into a zero is
4047 // already a zero (otherwise, infinite loop).
4048 if (isNullConstant(Y))
4049 return SDValue();
4050
4051 // Transform this into: ~X & Y == 0.
4052 SDValue NotX = DAG.getNOT(SDLoc(X), X, OpVT);
4053 SDValue NewAnd = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, NotX, Y);
4054 return DAG.getSetCC(DL, VT, NewAnd, Zero, Cond);
4055 }
4056
4057 return SDValue();
4058}
4059
4060/// There are multiple IR patterns that could be checking whether certain
4061 /// truncation of a signed number would be lossy or not. The pattern that is
4062 /// best at the IR level may not lower optimally. Thus, we want to unfold it.
4063/// We are looking for the following pattern: (KeptBits is a constant)
4064/// (add %x, (1 << (KeptBits-1))) srccond (1 << KeptBits)
4065 /// KeptBits won't be bitwidth(x); that will be constant-folded to true/false.
4066 /// KeptBits also can't be 1; that would have been folded to %x dstcond 0.
4067/// We will unfold it into the natural trunc+sext pattern:
4068/// ((%x << C) a>> C) dstcond %x
4069/// Where C = bitwidth(x) - KeptBits and C u< bitwidth(x)
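/// For example, with i16 %x and KeptBits == 8:
///   (add %x, 128) u< 256  <=>  ((%x << 8) a>> 8) == %x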
4070SDValue TargetLowering::optimizeSetCCOfSignedTruncationCheck(
4071 EVT SCCVT, SDValue N0, SDValue N1, ISD::CondCode Cond, DAGCombinerInfo &DCI,
4072 const SDLoc &DL) const {
4073 // We must be comparing with a constant.
4074 ConstantSDNode *C1;
4075 if (!(C1 = dyn_cast<ConstantSDNode>(N1)))
4076 return SDValue();
4077
4078 // N0 should be: add %x, (1 << (KeptBits-1))
4079 if (N0->getOpcode() != ISD::ADD)
4080 return SDValue();
4081
4082 // And we must be 'add'ing a constant.
4083 ConstantSDNode *C01;
4084 if (!(C01 = dyn_cast<ConstantSDNode>(N0->getOperand(1))))
4085 return SDValue();
4086
4087 SDValue X = N0->getOperand(0);
4088 EVT XVT = X.getValueType();
4089
4090 // Validate constants ...
4091
4092 APInt I1 = C1->getAPIntValue();
4093
4094 ISD::CondCode NewCond;
4095 if (Cond == ISD::CondCode::SETULT) {
4096 NewCond = ISD::CondCode::SETEQ;
4097 } else if (Cond == ISD::CondCode::SETULE) {
4098 NewCond = ISD::CondCode::SETEQ;
4099 // But need to 'canonicalize' the constant.
4100 I1 += 1;
4101 } else if (Cond == ISD::CondCode::SETUGT) {
4102 NewCond = ISD::CondCode::SETNE;
4103 // But need to 'canonicalize' the constant.
4104 I1 += 1;
4105 } else if (Cond == ISD::CondCode::SETUGE) {
4106 NewCond = ISD::CondCode::SETNE;
4107 } else
4108 return SDValue();
4109
4110 APInt I01 = C01->getAPIntValue();
4111
4112 auto checkConstants = [&I1, &I01]() -> bool {
4113 // Both of them must be power-of-two, and the constant from setcc is bigger.
4114 return I1.ugt(I01) && I1.isPowerOf2() && I01.isPowerOf2();
4115 };
4116
4117 if (checkConstants()) {
4118 // Great, e.g. got icmp ult i16 (add i16 %x, 128), 256
4119 } else {
4120 // What if we invert constants? (and the target predicate)
4121 I1.negate();
4122 I01.negate();
4123 assert(XVT.isInteger());
4124 NewCond = getSetCCInverse(NewCond, XVT);
4125 if (!checkConstants())
4126 return SDValue();
4127 // Great, e.g. got icmp uge i16 (add i16 %x, -128), -256
4128 }
4129
4130 // They are power-of-two, so which bit is set?
4131 const unsigned KeptBits = I1.logBase2();
4132 const unsigned KeptBitsMinusOne = I01.logBase2();
4133
4134 // Magic!
4135 if (KeptBits != (KeptBitsMinusOne + 1))
4136 return SDValue();
4137 assert(KeptBits > 0 && KeptBits < XVT.getSizeInBits() && "unreachable");
4138
4139 // We don't want to do this in every single case.
4140 SelectionDAG &DAG = DCI.DAG;
4141 if (!DAG.getTargetLoweringInfo().shouldTransformSignedTruncationCheck(
4142 XVT, KeptBits))
4143 return SDValue();
4144
4145 // Unfold into: sext_inreg(%x) cond %x
4146 // Where 'cond' will be either 'eq' or 'ne'.
4147 SDValue SExtInReg = DAG.getNode(
4148 ISD::SIGN_EXTEND_INREG, DL, XVT, X,
4149 DAG.getValueType(EVT::getIntegerVT(*DAG.getContext(), KeptBits)));
4150 return DAG.getSetCC(DL, SCCVT, SExtInReg, X, NewCond);
4151}
4152
4153// (X & (C l>>/<< Y)) ==/!= 0 --> ((X <</l>> Y) & C) ==/!= 0
4154SDValue TargetLowering::optimizeSetCCByHoistingAndByConstFromLogicalShift(
4155 EVT SCCVT, SDValue N0, SDValue N1C, ISD::CondCode Cond,
4156 DAGCombinerInfo &DCI, const SDLoc &DL) const {
4158 "Should be a comparison with 0.");
4159 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4160 "Valid only for [in]equality comparisons.");
4161
4162 unsigned NewShiftOpcode;
4163 SDValue X, C, Y;
4164
4165 SelectionDAG &DAG = DCI.DAG;
4166 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
4167
4168 // Look for '(C l>>/<< Y)'.
4169 auto Match = [&NewShiftOpcode, &X, &C, &Y, &TLI, &DAG](SDValue V) {
4170 // The shift should be one-use.
4171 if (!V.hasOneUse())
4172 return false;
4173 unsigned OldShiftOpcode = V.getOpcode();
4174 switch (OldShiftOpcode) {
4175 case ISD::SHL:
4176 NewShiftOpcode = ISD::SRL;
4177 break;
4178 case ISD::SRL:
4179 NewShiftOpcode = ISD::SHL;
4180 break;
4181 default:
4182 return false; // must be a logical shift.
4183 }
4184 // We should be shifting a constant.
4185 // FIXME: best to use isConstantOrConstantVector().
4186 C = V.getOperand(0);
4187 ConstantSDNode *CC =
4188 isConstOrConstSplat(C, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
4189 if (!CC)
4190 return false;
4191 Y = V.getOperand(1);
4192
4193 ConstantSDNode *XC =
4194 isConstOrConstSplat(X, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
4195 return TLI.shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
4196 X, XC, CC, Y, OldShiftOpcode, NewShiftOpcode, DAG);
4197 };
4198
4199 // LHS of comparison should be a one-use 'and'.
4200 if (N0.getOpcode() != ISD::AND || !N0.hasOneUse())
4201 return SDValue();
4202
4203 X = N0.getOperand(0);
4204 SDValue Mask = N0.getOperand(1);
4205
4206 // 'and' is commutative!
4207 if (!Match(Mask)) {
4208 std::swap(X, Mask);
4209 if (!Match(Mask))
4210 return SDValue();
4211 }
4212
4213 EVT VT = X.getValueType();
4214
4215 // Produce:
4216 // ((X 'OppositeShiftOpcode' Y) & C) Cond 0
4217 SDValue T0 = DAG.getNode(NewShiftOpcode, DL, VT, X, Y);
4218 SDValue T1 = DAG.getNode(ISD::AND, DL, VT, T0, C);
4219 SDValue T2 = DAG.getSetCC(DL, SCCVT, T1, N1C, Cond);
4220 return T2;
4221}
4222
4223/// Try to fold an equality comparison with a {add/sub/xor} binary operation as
4224/// the 1st operand (N0). Callers are expected to swap the N0/N1 parameters to
4225/// handle the commuted versions of these patterns.
4226SDValue TargetLowering::foldSetCCWithBinOp(EVT VT, SDValue N0, SDValue N1,
4227 ISD::CondCode Cond, const SDLoc &DL,
4228 DAGCombinerInfo &DCI) const {
4229 unsigned BOpcode = N0.getOpcode();
4230 assert((BOpcode == ISD::ADD || BOpcode == ISD::SUB || BOpcode == ISD::XOR) &&
4231 "Unexpected binop");
4232 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) && "Unexpected condcode");
4233
4234 // (X + Y) == X --> Y == 0
4235 // (X - Y) == X --> Y == 0
4236 // (X ^ Y) == X --> Y == 0
4237 SelectionDAG &DAG = DCI.DAG;
4238 EVT OpVT = N0.getValueType();
4239 SDValue X = N0.getOperand(0);
4240 SDValue Y = N0.getOperand(1);
4241 if (X == N1)
4242 return DAG.getSetCC(DL, VT, Y, DAG.getConstant(0, DL, OpVT), Cond);
4243
4244 if (Y != N1)
4245 return SDValue();
4246
4247 // (X + Y) == Y --> X == 0
4248 // (X ^ Y) == Y --> X == 0
4249 if (BOpcode == ISD::ADD || BOpcode == ISD::XOR)
4250 return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, OpVT), Cond);
4251
4252 // The shift would not be valid if the operands are boolean (i1).
4253 if (!N0.hasOneUse() || OpVT.getScalarSizeInBits() == 1)
4254 return SDValue();
4255
4256 // (X - Y) == Y --> X == Y << 1
4257 SDValue One =
4258 DAG.getShiftAmountConstant(1, OpVT, DL, !DCI.isBeforeLegalize());
4259 SDValue YShl1 = DAG.getNode(ISD::SHL, DL, N1.getValueType(), Y, One);
4260 if (!DCI.isCalledByLegalizer())
4261 DCI.AddToWorklist(YShl1.getNode());
4262 return DAG.getSetCC(DL, VT, X, YShl1, Cond);
4263}
4264
4265static SDValue simplifySetCCWithCTPOP(const TargetLowering &TLI, EVT VT,
4266 SDValue N0, const APInt &C1,
4267 ISD::CondCode Cond, const SDLoc &dl,
4268 SelectionDAG &DAG) {
4269 // Look through truncs that don't change the value of a ctpop.
4270 // FIXME: Add vector support? Need to be careful with setcc result type below.
4271 SDValue CTPOP = N0;
4272 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() && !VT.isVector() &&
4274 CTPOP = N0.getOperand(0);
4275
4276 if (CTPOP.getOpcode() != ISD::CTPOP || !CTPOP.hasOneUse())
4277 return SDValue();
4278
4279 EVT CTVT = CTPOP.getValueType();
4280 SDValue CTOp = CTPOP.getOperand(0);
4281
4282 // Expand a power-of-2-or-zero comparison based on ctpop:
4283 // (ctpop x) u< 2 -> (x & x-1) == 0
4284 // (ctpop x) u> 1 -> (x & x-1) != 0
4285 if (Cond == ISD::SETULT || Cond == ISD::SETUGT) {
4286 // Keep the CTPOP if it is a cheap vector op.
4287 if (CTVT.isVector() && TLI.isCtpopFast(CTVT))
4288 return SDValue();
4289
4290 unsigned CostLimit = TLI.getCustomCtpopCost(CTVT, Cond);
4291 if (C1.ugt(CostLimit + (Cond == ISD::SETULT)))
4292 return SDValue();
4293 if (C1 == 0 && (Cond == ISD::SETULT))
4294 return SDValue(); // This is handled elsewhere.
4295
4296 unsigned Passes = C1.getLimitedValue() - (Cond == ISD::SETULT);
4297
4298 SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT);
4299 SDValue Result = CTOp;
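 // Clearing the lowest set bit "Passes" times leaves zero iff CTOp had at
 // most "Passes" bits set, which is exactly the unsigned comparison above.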
4300 for (unsigned i = 0; i < Passes; i++) {
4301 SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, Result, NegOne);
4302 Result = DAG.getNode(ISD::AND, dl, CTVT, Result, Add);
4303 }
4304 ISD::CondCode CC = Cond == ISD::SETULT ? ISD::SETEQ : ISD::SETNE;
4305 return DAG.getSetCC(dl, VT, Result, DAG.getConstant(0, dl, CTVT), CC);
4306 }
4307
4308 // Expand a power-of-2 comparison based on ctpop
4309 if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) && C1 == 1) {
4310 // Keep the CTPOP if it is cheap.
4311 if (TLI.isCtpopFast(CTVT))
4312 return SDValue();
4313
4314 SDValue Zero = DAG.getConstant(0, dl, CTVT);
4315 SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT);
4316 assert(CTVT.isInteger());
4317 SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, CTOp, NegOne);
4318
4319 // It's not uncommon for known-never-zero X to exist in (ctpop X) eq/ne 1, so
4320 // check before emitting a potentially unnecessary op.
4321 if (DAG.isKnownNeverZero(CTOp)) {
4322 // (ctpop x) == 1 --> (x & x-1) == 0
4323 // (ctpop x) != 1 --> (x & x-1) != 0
4324 SDValue And = DAG.getNode(ISD::AND, dl, CTVT, CTOp, Add);
4325 SDValue RHS = DAG.getSetCC(dl, VT, And, Zero, Cond);
4326 return RHS;
4327 }
4328
4329 // (ctpop x) == 1 --> (x ^ x-1) > x-1
4330 // (ctpop x) != 1 --> (x ^ x-1) <= x-1
4331 SDValue Xor = DAG.getNode(ISD::XOR, dl, CTVT, CTOp, Add);
4332 ISD::CondCode CmpCond = Cond == ISD::SETEQ ? ISD::SETUGT : ISD::SETULE;
4333 return DAG.getSetCC(dl, VT, Xor, Add, CmpCond);
4334 }
4335
4336 return SDValue();
4337}
4338
4339static SDValue foldSetCCWithRotate(EVT VT, SDValue N0, SDValue N1,
4340 ISD::CondCode Cond, const SDLoc &dl,
4341 SelectionDAG &DAG) {
4342 if (Cond != ISD::SETEQ && Cond != ISD::SETNE)
4343 return SDValue();
4344
4345 auto *C1 = isConstOrConstSplat(N1, /* AllowUndefs */ true);
4346 if (!C1 || !(C1->isZero() || C1->isAllOnes()))
4347 return SDValue();
4348
4349 auto getRotateSource = [](SDValue X) {
4350 if (X.getOpcode() == ISD::ROTL || X.getOpcode() == ISD::ROTR)
4351 return X.getOperand(0);
4352 return SDValue();
4353 };
4354
4355 // Peek through a rotated value compared against 0 or -1:
4356 // (rot X, Y) == 0/-1 --> X == 0/-1
4357 // (rot X, Y) != 0/-1 --> X != 0/-1
4358 if (SDValue R = getRotateSource(N0))
4359 return DAG.getSetCC(dl, VT, R, N1,