1//===-- TargetLowering.cpp - Implement the TargetLowering class -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This implements the TargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
14#include "llvm/ADT/STLExtras.h"
25#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/LLVMContext.h"
29#include "llvm/MC/MCAsmInfo.h"
30#include "llvm/MC/MCExpr.h"
36#include <cctype>
37using namespace llvm;
38
39/// NOTE: The TargetMachine owns TLOF.
40TargetLowering::TargetLowering(const TargetMachine &tm)
41 : TargetLoweringBase(tm) {}
42
43const char *TargetLowering::getTargetNodeName(unsigned Opcode) const {
44 return nullptr;
45}
46
49}
50
51/// Check whether a given call node is in tail position within its function. If
52/// so, it sets Chain to the input chain of the tail call.
54 SDValue &Chain) const {
56
57 // First, check if tail calls have been disabled in this function.
58 if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
59 return false;
60
61 // Conservatively require the attributes of the call to match those of
62 // the return. Ignore the following attributes because they don't affect the
63 // call sequence.
64 AttrBuilder CallerAttrs(F.getContext(), F.getAttributes().getRetAttrs());
65 for (const auto &Attr :
66 {Attribute::Alignment, Attribute::Dereferenceable,
67 Attribute::DereferenceableOrNull, Attribute::NoAlias,
68 Attribute::NonNull, Attribute::NoUndef, Attribute::Range})
69 CallerAttrs.removeAttribute(Attr);
70
71 if (CallerAttrs.hasAttributes())
72 return false;
73
74 // It's not safe to eliminate the sign / zero extension of the return value.
75 if (CallerAttrs.contains(Attribute::ZExt) ||
76 CallerAttrs.contains(Attribute::SExt))
77 return false;
78
79 // Check if the only use is a function return node.
80 return isUsedByReturnOnly(Node, Chain);
81}
82
84 const uint32_t *CallerPreservedMask,
85 const SmallVectorImpl<CCValAssign> &ArgLocs,
86 const SmallVectorImpl<SDValue> &OutVals) const {
87 for (unsigned I = 0, E = ArgLocs.size(); I != E; ++I) {
88 const CCValAssign &ArgLoc = ArgLocs[I];
89 if (!ArgLoc.isRegLoc())
90 continue;
91 MCRegister Reg = ArgLoc.getLocReg();
92 // Only look at callee saved registers.
93 if (MachineOperand::clobbersPhysReg(CallerPreservedMask, Reg))
94 continue;
95 // Check that we pass the value used for the caller.
96 // (We look for a CopyFromReg reading a virtual register that is used
97 // for the function live-in value of register Reg)
98 SDValue Value = OutVals[I];
99 if (Value->getOpcode() == ISD::AssertZext)
100 Value = Value.getOperand(0);
101 if (Value->getOpcode() != ISD::CopyFromReg)
102 return false;
103 Register ArgReg = cast<RegisterSDNode>(Value->getOperand(1))->getReg();
104 if (MRI.getLiveInPhysReg(ArgReg) != Reg)
105 return false;
106 }
107 return true;
108}
109
110/// Set CallLoweringInfo attribute flags based on a call instruction
111/// and called function attributes.
113 unsigned ArgIdx) {
114 IsSExt = Call->paramHasAttr(ArgIdx, Attribute::SExt);
115 IsZExt = Call->paramHasAttr(ArgIdx, Attribute::ZExt);
116 IsInReg = Call->paramHasAttr(ArgIdx, Attribute::InReg);
117 IsSRet = Call->paramHasAttr(ArgIdx, Attribute::StructRet);
118 IsNest = Call->paramHasAttr(ArgIdx, Attribute::Nest);
119 IsByVal = Call->paramHasAttr(ArgIdx, Attribute::ByVal);
120 IsPreallocated = Call->paramHasAttr(ArgIdx, Attribute::Preallocated);
121 IsInAlloca = Call->paramHasAttr(ArgIdx, Attribute::InAlloca);
122 IsReturned = Call->paramHasAttr(ArgIdx, Attribute::Returned);
123 IsSwiftSelf = Call->paramHasAttr(ArgIdx, Attribute::SwiftSelf);
124 IsSwiftAsync = Call->paramHasAttr(ArgIdx, Attribute::SwiftAsync);
125 IsSwiftError = Call->paramHasAttr(ArgIdx, Attribute::SwiftError);
126 Alignment = Call->getParamStackAlign(ArgIdx);
127 IndirectType = nullptr;
129 "multiple ABI attributes?");
130 if (IsByVal) {
131 IndirectType = Call->getParamByValType(ArgIdx);
132 if (!Alignment)
133 Alignment = Call->getParamAlign(ArgIdx);
134 }
135 if (IsPreallocated)
136 IndirectType = Call->getParamPreallocatedType(ArgIdx);
137 if (IsInAlloca)
138 IndirectType = Call->getParamInAllocaType(ArgIdx);
139 if (IsSRet)
140 IndirectType = Call->getParamStructRetType(ArgIdx);
141}
142
143/// Generate a libcall taking the given operands as arguments and returning a
144/// result of type RetVT.
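/// A minimal usage sketch (LHS, RHS, Chain and dl stand for the caller's own
/// values, not names defined in this file):
///   SDValue Ops[2] = {LHS, RHS};
///   MakeLibCallOptions CallOptions;
///   std::pair<SDValue, SDValue> Lowered =
///       makeLibCall(DAG, RTLIB::ADD_F32, MVT::f32, Ops, CallOptions, dl, Chain);
///   // Lowered.first is the call's result value, Lowered.second the output chain.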
145std::pair<SDValue, SDValue>
148 MakeLibCallOptions CallOptions,
149 const SDLoc &dl,
150 SDValue InChain) const {
151 if (!InChain)
152 InChain = DAG.getEntryNode();
153
155 Args.reserve(Ops.size());
156
158 for (unsigned i = 0; i < Ops.size(); ++i) {
159 SDValue NewOp = Ops[i];
160 Entry.Node = NewOp;
161 Entry.Ty = Entry.Node.getValueType().getTypeForEVT(*DAG.getContext());
162 Entry.IsSExt = shouldSignExtendTypeInLibCall(NewOp.getValueType(),
163 CallOptions.IsSExt);
164 Entry.IsZExt = !Entry.IsSExt;
165
166 if (CallOptions.IsSoften &&
168 Entry.IsSExt = Entry.IsZExt = false;
169 }
170 Args.push_back(Entry);
171 }
172
173 if (LC == RTLIB::UNKNOWN_LIBCALL)
174 report_fatal_error("Unsupported library call operation!");
177
178 Type *RetTy = RetVT.getTypeForEVT(*DAG.getContext());
180 bool signExtend = shouldSignExtendTypeInLibCall(RetVT, CallOptions.IsSExt);
181 bool zeroExtend = !signExtend;
182
183 if (CallOptions.IsSoften &&
185 signExtend = zeroExtend = false;
186 }
187
188 CLI.setDebugLoc(dl)
189 .setChain(InChain)
190 .setLibCallee(getLibcallCallingConv(LC), RetTy, Callee, std::move(Args))
191 .setNoReturn(CallOptions.DoesNotReturn)
194 .setSExtResult(signExtend)
195 .setZExtResult(zeroExtend);
196 return LowerCallTo(CLI);
197}
198
200 std::vector<EVT> &MemOps, unsigned Limit, const MemOp &Op, unsigned DstAS,
201 unsigned SrcAS, const AttributeList &FuncAttributes) const {
202 if (Limit != ~unsigned(0) && Op.isMemcpyWithFixedDstAlign() &&
203 Op.getSrcAlign() < Op.getDstAlign())
204 return false;
205
206 EVT VT = getOptimalMemOpType(Op, FuncAttributes);
207
208 if (VT == MVT::Other) {
209 // Use the largest integer type whose alignment constraints are satisfied.
210 // We only need to check DstAlign here as SrcAlign is always greater than or
211 // equal to DstAlign (or zero).
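  // For example, with a fixed 4-byte destination alignment and no support for
  // misaligned wider accesses, VT is stepped down from the widest integer type
  // until it reaches i32.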
212 VT = MVT::LAST_INTEGER_VALUETYPE;
213 if (Op.isFixedDstAlign())
214 while (Op.getDstAlign() < (VT.getSizeInBits() / 8) &&
215 !allowsMisalignedMemoryAccesses(VT, DstAS, Op.getDstAlign()))
217 assert(VT.isInteger());
218
219 // Find the largest legal integer type.
220 MVT LVT = MVT::LAST_INTEGER_VALUETYPE;
221 while (!isTypeLegal(LVT))
222 LVT = (MVT::SimpleValueType)(LVT.SimpleTy - 1);
223 assert(LVT.isInteger());
224
225 // If the type we've chosen is larger than the largest legal integer type
226 // then use that instead.
227 if (VT.bitsGT(LVT))
228 VT = LVT;
229 }
230
231 unsigned NumMemOps = 0;
232 uint64_t Size = Op.size();
233 while (Size) {
234 unsigned VTSize = VT.getSizeInBits() / 8;
235 while (VTSize > Size) {
236 // For now, only use non-vector loads / stores for the left-over pieces.
237 EVT NewVT = VT;
238 unsigned NewVTSize;
239
240 bool Found = false;
241 if (VT.isVector() || VT.isFloatingPoint()) {
242 NewVT = (VT.getSizeInBits() > 64) ? MVT::i64 : MVT::i32;
245 Found = true;
246 else if (NewVT == MVT::i64 &&
248 isSafeMemOpType(MVT::f64)) {
249 // i64 is usually not legal on 32-bit targets, but f64 may be.
250 NewVT = MVT::f64;
251 Found = true;
252 }
253 }
254
255 if (!Found) {
256 do {
257 NewVT = (MVT::SimpleValueType)(NewVT.getSimpleVT().SimpleTy - 1);
258 if (NewVT == MVT::i8)
259 break;
260 } while (!isSafeMemOpType(NewVT.getSimpleVT()));
261 }
262 NewVTSize = NewVT.getSizeInBits() / 8;
263
264 // If the new VT cannot cover all of the remaining bits, then consider
265 // issuing a (or a pair of) unaligned and overlapping load / store.
266 unsigned Fast;
267 if (NumMemOps && Op.allowOverlap() && NewVTSize < Size &&
269 VT, DstAS, Op.isFixedDstAlign() ? Op.getDstAlign() : Align(1),
271 Fast)
272 VTSize = Size;
273 else {
274 VT = NewVT;
275 VTSize = NewVTSize;
276 }
277 }
278
279 if (++NumMemOps > Limit)
280 return false;
281
282 MemOps.push_back(VT);
283 Size -= VTSize;
284 }
285
286 return true;
287}
288
289/// Soften the operands of a comparison. This code is shared among BR_CC,
290/// SELECT_CC, and SETCC handlers.
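/// For example, an f32 SETOLT is expanded into a call to the OLT_F32 libcall;
/// the comparison then becomes "call-result <cc> 0", with <cc> taken from
/// getCmpLibcallCC(OLT_F32). Predicates such as SETUEQ need two libcalls
/// (UO_F32 and OEQ_F32) whose results are combined.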
292 SDValue &NewLHS, SDValue &NewRHS,
293 ISD::CondCode &CCCode,
294 const SDLoc &dl, const SDValue OldLHS,
295 const SDValue OldRHS) const {
296 SDValue Chain;
297 return softenSetCCOperands(DAG, VT, NewLHS, NewRHS, CCCode, dl, OldLHS,
298 OldRHS, Chain);
299}
300
302 SDValue &NewLHS, SDValue &NewRHS,
303 ISD::CondCode &CCCode,
304 const SDLoc &dl, const SDValue OldLHS,
305 const SDValue OldRHS,
306 SDValue &Chain,
307 bool IsSignaling) const {
308 // FIXME: Currently we cannot really respect all IEEE predicates due to libgcc
309 // not supporting it. We can update this code when libgcc provides such
310 // functions.
311
312 assert((VT == MVT::f32 || VT == MVT::f64 || VT == MVT::f128 || VT == MVT::ppcf128)
313 && "Unsupported setcc type!");
314
315 // Expand into one or more soft-fp libcall(s).
316 RTLIB::Libcall LC1 = RTLIB::UNKNOWN_LIBCALL, LC2 = RTLIB::UNKNOWN_LIBCALL;
317 bool ShouldInvertCC = false;
318 switch (CCCode) {
319 case ISD::SETEQ:
320 case ISD::SETOEQ:
321 LC1 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
322 (VT == MVT::f64) ? RTLIB::OEQ_F64 :
323 (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
324 break;
325 case ISD::SETNE:
326 case ISD::SETUNE:
327 LC1 = (VT == MVT::f32) ? RTLIB::UNE_F32 :
328 (VT == MVT::f64) ? RTLIB::UNE_F64 :
329 (VT == MVT::f128) ? RTLIB::UNE_F128 : RTLIB::UNE_PPCF128;
330 break;
331 case ISD::SETGE:
332 case ISD::SETOGE:
333 LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
334 (VT == MVT::f64) ? RTLIB::OGE_F64 :
335 (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
336 break;
337 case ISD::SETLT:
338 case ISD::SETOLT:
339 LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
340 (VT == MVT::f64) ? RTLIB::OLT_F64 :
341 (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
342 break;
343 case ISD::SETLE:
344 case ISD::SETOLE:
345 LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
346 (VT == MVT::f64) ? RTLIB::OLE_F64 :
347 (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
348 break;
349 case ISD::SETGT:
350 case ISD::SETOGT:
351 LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
352 (VT == MVT::f64) ? RTLIB::OGT_F64 :
353 (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
354 break;
355 case ISD::SETO:
356 ShouldInvertCC = true;
357 [[fallthrough]];
358 case ISD::SETUO:
359 LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
360 (VT == MVT::f64) ? RTLIB::UO_F64 :
361 (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
362 break;
363 case ISD::SETONE:
364 // SETONE = O && UNE
365 ShouldInvertCC = true;
366 [[fallthrough]];
367 case ISD::SETUEQ:
368 LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
369 (VT == MVT::f64) ? RTLIB::UO_F64 :
370 (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
371 LC2 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
372 (VT == MVT::f64) ? RTLIB::OEQ_F64 :
373 (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
374 break;
375 default:
376 // Invert CC for unordered comparisons
377 ShouldInvertCC = true;
378 switch (CCCode) {
379 case ISD::SETULT:
380 LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
381 (VT == MVT::f64) ? RTLIB::OGE_F64 :
382 (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
383 break;
384 case ISD::SETULE:
385 LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
386 (VT == MVT::f64) ? RTLIB::OGT_F64 :
387 (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
388 break;
389 case ISD::SETUGT:
390 LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
391 (VT == MVT::f64) ? RTLIB::OLE_F64 :
392 (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
393 break;
394 case ISD::SETUGE:
395 LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
396 (VT == MVT::f64) ? RTLIB::OLT_F64 :
397 (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
398 break;
399 default: llvm_unreachable("Do not know how to soften this setcc!");
400 }
401 }
402
403 // Use the target specific return value for comparison lib calls.
405 SDValue Ops[2] = {NewLHS, NewRHS};
407 EVT OpsVT[2] = { OldLHS.getValueType(),
408 OldRHS.getValueType() };
409 CallOptions.setTypeListBeforeSoften(OpsVT, RetVT, true);
410 auto Call = makeLibCall(DAG, LC1, RetVT, Ops, CallOptions, dl, Chain);
411 NewLHS = Call.first;
412 NewRHS = DAG.getConstant(0, dl, RetVT);
413
414 CCCode = getCmpLibcallCC(LC1);
415 if (ShouldInvertCC) {
416 assert(RetVT.isInteger());
417 CCCode = getSetCCInverse(CCCode, RetVT);
418 }
419
420 if (LC2 == RTLIB::UNKNOWN_LIBCALL) {
421 // Update Chain.
422 Chain = Call.second;
423 } else {
424 EVT SetCCVT =
425 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), RetVT);
426 SDValue Tmp = DAG.getSetCC(dl, SetCCVT, NewLHS, NewRHS, CCCode);
427 auto Call2 = makeLibCall(DAG, LC2, RetVT, Ops, CallOptions, dl, Chain);
428 CCCode = getCmpLibcallCC(LC2);
429 if (ShouldInvertCC)
430 CCCode = getSetCCInverse(CCCode, RetVT);
431 NewLHS = DAG.getSetCC(dl, SetCCVT, Call2.first, NewRHS, CCCode);
432 if (Chain)
433 Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Call.second,
434 Call2.second);
435 NewLHS = DAG.getNode(ShouldInvertCC ? ISD::AND : ISD::OR, dl,
436 Tmp.getValueType(), Tmp, NewLHS);
437 NewRHS = SDValue();
438 }
439}
440
441/// Return the entry encoding for a jump table in the current function. The
442/// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
444 // In non-pic modes, just use the address of a block.
445 if (!isPositionIndependent())
447
448 // In PIC mode, if the target supports a GPRel32 directive, use it.
449 if (getTargetMachine().getMCAsmInfo()->getGPRel32Directive() != nullptr)
451
452 // Otherwise, use a label difference.
454}
455
457 SelectionDAG &DAG) const {
458 // If our PIC model is GP relative, use the global offset table as the base.
459 unsigned JTEncoding = getJumpTableEncoding();
460
464
465 return Table;
466}
467
468/// This returns the relocation base for the given PIC jumptable, the same as
469/// getPICJumpTableRelocBase, but as an MCExpr.
470const MCExpr *
472 unsigned JTI,MCContext &Ctx) const{
473 // The normal PIC reloc base is the label at the start of the jump table.
474 return MCSymbolRefExpr::create(MF->getJTISymbol(JTI, Ctx), Ctx);
475}
476
478 SDValue Addr, int JTI,
479 SelectionDAG &DAG) const {
480 SDValue Chain = Value;
481 // Jump table debug info is only needed if CodeView is enabled.
483 Chain = DAG.getJumpTableDebugInfo(JTI, Chain, dl);
484 }
485 return DAG.getNode(ISD::BRIND, dl, MVT::Other, Chain, Addr);
486}
487
488bool
490 const TargetMachine &TM = getTargetMachine();
491 const GlobalValue *GV = GA->getGlobal();
492
493 // If the address is not even local to this DSO we will have to load it from
494 // the GOT and then add the offset.
495 if (!TM.shouldAssumeDSOLocal(GV))
496 return false;
497
498 // If the code is position independent we will have to add a base register.
499 if (isPositionIndependent())
500 return false;
501
502 // Otherwise we can do it.
503 return true;
504}
505
506//===----------------------------------------------------------------------===//
507// Optimization Methods
508//===----------------------------------------------------------------------===//
509
510/// If the specified instruction has a constant integer operand and there are
511/// bits set in that constant that are not demanded, then clear those bits and
512/// return true.
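/// For example, if only the low 8 bits of (and X, 0x1FF) are demanded, the
/// constant is shrunk to 0xFF, since the extra set bit cannot affect any
/// demanded bit of the result.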
514 const APInt &DemandedBits,
515 const APInt &DemandedElts,
516 TargetLoweringOpt &TLO) const {
517 SDLoc DL(Op);
518 unsigned Opcode = Op.getOpcode();
519
520 // Early out if we've ended up calling an undemanded node; leave this to
521 // constant folding.
522 if (DemandedBits.isZero() || DemandedElts.isZero())
523 return false;
524
525 // Do target-specific constant optimization.
526 if (targetShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
527 return TLO.New.getNode();
528
529 // FIXME: ISD::SELECT, ISD::SELECT_CC
530 switch (Opcode) {
531 default:
532 break;
533 case ISD::XOR:
534 case ISD::AND:
535 case ISD::OR: {
536 auto *Op1C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
537 if (!Op1C || Op1C->isOpaque())
538 return false;
539
540 // If this is a 'not' op, don't touch it because that's a canonical form.
541 const APInt &C = Op1C->getAPIntValue();
542 if (Opcode == ISD::XOR && DemandedBits.isSubsetOf(C))
543 return false;
544
545 if (!C.isSubsetOf(DemandedBits)) {
546 EVT VT = Op.getValueType();
547 SDValue NewC = TLO.DAG.getConstant(DemandedBits & C, DL, VT);
548 SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC,
549 Op->getFlags());
550 return TLO.CombineTo(Op, NewOp);
551 }
552
553 break;
554 }
555 }
556
557 return false;
558}
559
561 const APInt &DemandedBits,
562 TargetLoweringOpt &TLO) const {
563 EVT VT = Op.getValueType();
564 APInt DemandedElts = VT.isVector()
566 : APInt(1, 1);
567 return ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO);
568}
569
570/// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free.
571/// This uses isTruncateFree/isZExtFree and ANY_EXTEND for the widening cast,
572/// but it could be generalized for targets with other types of implicit
573/// widening casts.
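/// For example, if only the low 16 bits of an i64 addition are demanded and
/// i16 has free truncate/zext casts, the node is rewritten as
///   (any_extend (add (trunc x), (trunc y)))
/// with the add performed in i16.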
575 const APInt &DemandedBits,
576 TargetLoweringOpt &TLO) const {
577 assert(Op.getNumOperands() == 2 &&
578 "ShrinkDemandedOp only supports binary operators!");
579 assert(Op.getNode()->getNumValues() == 1 &&
580 "ShrinkDemandedOp only supports nodes with one result!");
581
582 EVT VT = Op.getValueType();
583 SelectionDAG &DAG = TLO.DAG;
584 SDLoc dl(Op);
585
586 // Early return, as this function cannot handle vector types.
587 if (VT.isVector())
588 return false;
589
590 assert(Op.getOperand(0).getValueType().getScalarSizeInBits() == BitWidth &&
591 Op.getOperand(1).getValueType().getScalarSizeInBits() == BitWidth &&
592 "ShrinkDemandedOp only supports operands that have the same size!");
593
594 // Don't do this if the node has another user, which may require the
595 // full value.
596 if (!Op.getNode()->hasOneUse())
597 return false;
598
599 // Search for the smallest integer type with free casts to and from
600 // Op's type. For expedience, just check power-of-2 integer types.
601 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
602 unsigned DemandedSize = DemandedBits.getActiveBits();
603 for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
604 SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
605 EVT SmallVT = EVT::getIntegerVT(*DAG.getContext(), SmallVTBits);
606 if (TLI.isTruncateFree(VT, SmallVT) && TLI.isZExtFree(SmallVT, VT)) {
607 // We found a type with free casts.
608 SDValue X = DAG.getNode(
609 Op.getOpcode(), dl, SmallVT,
610 DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
611 DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)));
612 assert(DemandedSize <= SmallVTBits && "Narrowed below demanded bits?");
613 SDValue Z = DAG.getNode(ISD::ANY_EXTEND, dl, VT, X);
614 return TLO.CombineTo(Op, Z);
615 }
616 }
617 return false;
618}
619
621 DAGCombinerInfo &DCI) const {
622 SelectionDAG &DAG = DCI.DAG;
623 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
624 !DCI.isBeforeLegalizeOps());
625 KnownBits Known;
626
627 bool Simplified = SimplifyDemandedBits(Op, DemandedBits, Known, TLO);
628 if (Simplified) {
629 DCI.AddToWorklist(Op.getNode());
631 }
632 return Simplified;
633}
634
636 const APInt &DemandedElts,
637 DAGCombinerInfo &DCI) const {
638 SelectionDAG &DAG = DCI.DAG;
639 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
640 !DCI.isBeforeLegalizeOps());
641 KnownBits Known;
642
643 bool Simplified =
644 SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO);
645 if (Simplified) {
646 DCI.AddToWorklist(Op.getNode());
648 }
649 return Simplified;
650}
651
653 KnownBits &Known,
655 unsigned Depth,
656 bool AssumeSingleUse) const {
657 EVT VT = Op.getValueType();
658
659 // Since the number of lanes in a scalable vector is unknown at compile time,
660 // we track one bit which is implicitly broadcast to all lanes. This means
661 // that all lanes in a scalable vector are considered demanded.
662 APInt DemandedElts = VT.isFixedLengthVector()
664 : APInt(1, 1);
665 return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth,
666 AssumeSingleUse);
667}
668
669// TODO: Under what circumstances can we create nodes? Constant folding?
671 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
672 SelectionDAG &DAG, unsigned Depth) const {
673 EVT VT = Op.getValueType();
674
675 // Limit search depth.
677 return SDValue();
678
679 // Ignore UNDEFs.
680 if (Op.isUndef())
681 return SDValue();
682
683 // Not demanding any bits/elts from Op.
684 if (DemandedBits == 0 || DemandedElts == 0)
685 return DAG.getUNDEF(VT);
686
687 bool IsLE = DAG.getDataLayout().isLittleEndian();
688 unsigned NumElts = DemandedElts.getBitWidth();
689 unsigned BitWidth = DemandedBits.getBitWidth();
690 KnownBits LHSKnown, RHSKnown;
691 switch (Op.getOpcode()) {
692 case ISD::BITCAST: {
693 if (VT.isScalableVector())
694 return SDValue();
695
696 SDValue Src = peekThroughBitcasts(Op.getOperand(0));
697 EVT SrcVT = Src.getValueType();
698 EVT DstVT = Op.getValueType();
699 if (SrcVT == DstVT)
700 return Src;
701
702 unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
703 unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
704 if (NumSrcEltBits == NumDstEltBits)
705 if (SDValue V = SimplifyMultipleUseDemandedBits(
706 Src, DemandedBits, DemandedElts, DAG, Depth + 1))
707 return DAG.getBitcast(DstVT, V);
708
709 if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
710 unsigned Scale = NumDstEltBits / NumSrcEltBits;
711 unsigned NumSrcElts = SrcVT.getVectorNumElements();
712 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
713 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
714 for (unsigned i = 0; i != Scale; ++i) {
715 unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
716 unsigned BitOffset = EltOffset * NumSrcEltBits;
717 APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
718 if (!Sub.isZero()) {
719 DemandedSrcBits |= Sub;
720 for (unsigned j = 0; j != NumElts; ++j)
721 if (DemandedElts[j])
722 DemandedSrcElts.setBit((j * Scale) + i);
723 }
724 }
725
726 if (SDValue V = SimplifyMultipleUseDemandedBits(
727 Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
728 return DAG.getBitcast(DstVT, V);
729 }
730
731 // TODO - bigendian once we have test coverage.
732 if (IsLE && (NumSrcEltBits % NumDstEltBits) == 0) {
733 unsigned Scale = NumSrcEltBits / NumDstEltBits;
734 unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
735 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
736 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
737 for (unsigned i = 0; i != NumElts; ++i)
738 if (DemandedElts[i]) {
739 unsigned Offset = (i % Scale) * NumDstEltBits;
740 DemandedSrcBits.insertBits(DemandedBits, Offset);
741 DemandedSrcElts.setBit(i / Scale);
742 }
743
744 if (SDValue V = SimplifyMultipleUseDemandedBits(
745 Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
746 return DAG.getBitcast(DstVT, V);
747 }
748
749 break;
750 }
751 case ISD::FREEZE: {
752 SDValue N0 = Op.getOperand(0);
753 if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
754 /*PoisonOnly=*/false))
755 return N0;
756 break;
757 }
758 case ISD::AND: {
759 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
760 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
761
762 // If all of the demanded bits are known 1 on one side, return the other.
763 // These bits cannot contribute to the result of the 'and' in this
764 // context.
765 if (DemandedBits.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
766 return Op.getOperand(0);
767 if (DemandedBits.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
768 return Op.getOperand(1);
769 break;
770 }
771 case ISD::OR: {
772 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
773 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
774
775 // If all of the demanded bits are known zero on one side, return the
776 // other. These bits cannot contribute to the result of the 'or' in this
777 // context.
778 if (DemandedBits.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
779 return Op.getOperand(0);
780 if (DemandedBits.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
781 return Op.getOperand(1);
782 break;
783 }
784 case ISD::XOR: {
785 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
786 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
787
788 // If all of the demanded bits are known zero on one side, return the
789 // other.
790 if (DemandedBits.isSubsetOf(RHSKnown.Zero))
791 return Op.getOperand(0);
792 if (DemandedBits.isSubsetOf(LHSKnown.Zero))
793 return Op.getOperand(1);
794 break;
795 }
796 case ISD::SHL: {
797 // If we are only demanding sign bits then we can use the shift source
798 // directly.
799 if (std::optional<uint64_t> MaxSA =
800 DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
801 SDValue Op0 = Op.getOperand(0);
802 unsigned ShAmt = *MaxSA;
803 unsigned NumSignBits =
804 DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
805 unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
806 if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
807 return Op0;
808 }
809 break;
810 }
811 case ISD::SETCC: {
812 SDValue Op0 = Op.getOperand(0);
813 SDValue Op1 = Op.getOperand(1);
814 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
815 // If (1) we only need the sign-bit, (2) the setcc operands are the same
816 // width as the setcc result, and (3) the result of a setcc conforms to 0 or
817 // -1, we may be able to bypass the setcc.
818 if (DemandedBits.isSignMask() &&
822 // If we're testing X < 0, then this compare isn't needed - just use X!
823 // FIXME: We're limiting to integer types here, but this should also work
824 // if we don't care about FP signed-zero. The use of SETLT with FP means
825 // that we don't care about NaNs.
826 if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
828 return Op0;
829 }
830 break;
831 }
833 // If none of the extended bits are demanded, eliminate the sextinreg.
834 SDValue Op0 = Op.getOperand(0);
835 EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
836 unsigned ExBits = ExVT.getScalarSizeInBits();
837 if (DemandedBits.getActiveBits() <= ExBits &&
839 return Op0;
840 // If the input is already sign extended, just drop the extension.
841 unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
842 if (NumSignBits >= (BitWidth - ExBits + 1))
843 return Op0;
844 break;
845 }
849 if (VT.isScalableVector())
850 return SDValue();
851
852 // If we only want the lowest element and none of the extended bits, then we can
853 // return the bitcasted source vector.
854 SDValue Src = Op.getOperand(0);
855 EVT SrcVT = Src.getValueType();
856 EVT DstVT = Op.getValueType();
857 if (IsLE && DemandedElts == 1 &&
858 DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
859 DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
860 return DAG.getBitcast(DstVT, Src);
861 }
862 break;
863 }
865 if (VT.isScalableVector())
866 return SDValue();
867
868 // If we don't demand the inserted element, return the base vector.
869 SDValue Vec = Op.getOperand(0);
870 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
871 EVT VecVT = Vec.getValueType();
872 if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
873 !DemandedElts[CIdx->getZExtValue()])
874 return Vec;
875 break;
876 }
878 if (VT.isScalableVector())
879 return SDValue();
880
881 SDValue Vec = Op.getOperand(0);
882 SDValue Sub = Op.getOperand(1);
883 uint64_t Idx = Op.getConstantOperandVal(2);
884 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
885 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
886 // If we don't demand the inserted subvector, return the base vector.
887 if (DemandedSubElts == 0)
888 return Vec;
889 break;
890 }
891 case ISD::VECTOR_SHUFFLE: {
893 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
894
895 // If all the demanded elts are from one operand and are inline,
896 // then we can use the operand directly.
897 bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
898 for (unsigned i = 0; i != NumElts; ++i) {
899 int M = ShuffleMask[i];
900 if (M < 0 || !DemandedElts[i])
901 continue;
902 AllUndef = false;
903 IdentityLHS &= (M == (int)i);
904 IdentityRHS &= ((M - NumElts) == i);
905 }
906
907 if (AllUndef)
908 return DAG.getUNDEF(Op.getValueType());
909 if (IdentityLHS)
910 return Op.getOperand(0);
911 if (IdentityRHS)
912 return Op.getOperand(1);
913 break;
914 }
915 default:
916 // TODO: Probably okay to remove after audit; here to reduce change size
917 // in initial enablement patch for scalable vectors
918 if (VT.isScalableVector())
919 return SDValue();
920
921 if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
922 if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
923 Op, DemandedBits, DemandedElts, DAG, Depth))
924 return V;
925 break;
926 }
927 return SDValue();
928}
929
932 unsigned Depth) const {
933 EVT VT = Op.getValueType();
934 // Since the number of lanes in a scalable vector is unknown at compile time,
935 // we track one bit which is implicitly broadcast to all lanes. This means
936 // that all lanes in a scalable vector are considered demanded.
937 APInt DemandedElts = VT.isFixedLengthVector()
939 : APInt(1, 1);
940 return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
941 Depth);
942}
943
945 SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
946 unsigned Depth) const {
947 APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
948 return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
949 Depth);
950}
951
952// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
953// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
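// For example (unsigned case):
//   (srl (add (zext i8 A to i32), (zext i8 B to i32)), 1)
//     --> (zext (avgflooru i8 A, B) to i32)
// and with an extra +1 inside the add the pattern maps to avgceilu instead.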
956 const TargetLowering &TLI,
957 const APInt &DemandedBits,
958 const APInt &DemandedElts, unsigned Depth) {
959 assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
960 "SRL or SRA node is required here!");
961 // Is the right shift using an immediate value of 1?
962 ConstantSDNode *N1C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
963 if (!N1C || !N1C->isOne())
964 return SDValue();
965
966 // We are looking for an avgfloor
967 // add(ext, ext)
968 // or one of these as an avgceil
969 // add(add(ext, ext), 1)
970 // add(add(ext, 1), ext)
971 // add(ext, add(ext, 1))
972 SDValue Add = Op.getOperand(0);
973 if (Add.getOpcode() != ISD::ADD)
974 return SDValue();
975
976 SDValue ExtOpA = Add.getOperand(0);
977 SDValue ExtOpB = Add.getOperand(1);
978 SDValue Add2;
979 auto MatchOperands = [&](SDValue Op1, SDValue Op2, SDValue Op3, SDValue A) {
980 ConstantSDNode *ConstOp;
981 if ((ConstOp = isConstOrConstSplat(Op2, DemandedElts)) &&
982 ConstOp->isOne()) {
983 ExtOpA = Op1;
984 ExtOpB = Op3;
985 Add2 = A;
986 return true;
987 }
988 if ((ConstOp = isConstOrConstSplat(Op3, DemandedElts)) &&
989 ConstOp->isOne()) {
990 ExtOpA = Op1;
991 ExtOpB = Op2;
992 Add2 = A;
993 return true;
994 }
995 return false;
996 };
997 bool IsCeil =
998 (ExtOpA.getOpcode() == ISD::ADD &&
999 MatchOperands(ExtOpA.getOperand(0), ExtOpA.getOperand(1), ExtOpB, ExtOpA)) ||
1000 (ExtOpB.getOpcode() == ISD::ADD &&
1001 MatchOperands(ExtOpB.getOperand(0), ExtOpB.getOperand(1), ExtOpA, ExtOpB));
1002
1003 // If the shift is signed (sra):
1004 // - Needs >= 2 sign bits for both operands.
1005 // - Needs >= 2 zero bits.
1006 // If the shift is unsigned (srl):
1007 // - Needs >= 1 zero bit for both operands.
1008 // - Needs 1 demanded bit zero and >= 2 sign bits.
1009 SelectionDAG &DAG = TLO.DAG;
1010 unsigned ShiftOpc = Op.getOpcode();
1011 bool IsSigned = false;
1012 unsigned KnownBits;
1013 unsigned NumSignedA = DAG.ComputeNumSignBits(ExtOpA, DemandedElts, Depth);
1014 unsigned NumSignedB = DAG.ComputeNumSignBits(ExtOpB, DemandedElts, Depth);
1015 unsigned NumSigned = std::min(NumSignedA, NumSignedB) - 1;
1016 unsigned NumZeroA =
1017 DAG.computeKnownBits(ExtOpA, DemandedElts, Depth).countMinLeadingZeros();
1018 unsigned NumZeroB =
1019 DAG.computeKnownBits(ExtOpB, DemandedElts, Depth).countMinLeadingZeros();
1020 unsigned NumZero = std::min(NumZeroA, NumZeroB);
1021
1022 switch (ShiftOpc) {
1023 default:
1024 llvm_unreachable("Unexpected ShiftOpc in combineShiftToAVG");
1025 case ISD::SRA: {
1026 if (NumZero >= 2 && NumSigned < NumZero) {
1027 IsSigned = false;
1028 KnownBits = NumZero;
1029 break;
1030 }
1031 if (NumSigned >= 1) {
1032 IsSigned = true;
1033 KnownBits = NumSigned;
1034 break;
1035 }
1036 return SDValue();
1037 }
1038 case ISD::SRL: {
1039 if (NumZero >= 1 && NumSigned < NumZero) {
1040 IsSigned = false;
1041 KnownBits = NumZero;
1042 break;
1043 }
1044 if (NumSigned >= 1 && DemandedBits.isSignBitClear()) {
1045 IsSigned = true;
1046 KnownBits = NumSigned;
1047 break;
1048 }
1049 return SDValue();
1050 }
1051 }
1052
1053 unsigned AVGOpc = IsCeil ? (IsSigned ? ISD::AVGCEILS : ISD::AVGCEILU)
1054 : (IsSigned ? ISD::AVGFLOORS : ISD::AVGFLOORU);
1055
1056 // Find the smallest power-2 type that is legal for this vector size and
1057 // operation, given the original type size and the number of known sign/zero
1058 // bits.
1059 EVT VT = Op.getValueType();
1060 unsigned MinWidth =
1061 std::max<unsigned>(VT.getScalarSizeInBits() - KnownBits, 8);
1062 EVT NVT = EVT::getIntegerVT(*DAG.getContext(), llvm::bit_ceil(MinWidth));
1064 return SDValue();
1065 if (VT.isVector())
1066 NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
1067 if (TLO.LegalTypes() && !TLI.isOperationLegal(AVGOpc, NVT)) {
1068 // If we could not transform, and (both) adds are nuw/nsw, we can use the
1069 // larger type size to do the transform.
1070 if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, VT))
1071 return SDValue();
1072 if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0),
1073 Add.getOperand(1)) &&
1074 (!Add2 || DAG.willNotOverflowAdd(IsSigned, Add2.getOperand(0),
1075 Add2.getOperand(1))))
1076 NVT = VT;
1077 else
1078 return SDValue();
1079 }
1080
1081 // Don't create an AVGFLOOR node with a scalar constant unless it's legal, as
1082 // this is likely to stop other folds (reassociation, value tracking, etc.).
1083 if (!IsCeil && !TLI.isOperationLegal(AVGOpc, NVT) &&
1084 (isa<ConstantSDNode>(ExtOpA) || isa<ConstantSDNode>(ExtOpB)))
1085 return SDValue();
1086
1087 SDLoc DL(Op);
1088 SDValue ResultAVG =
1089 DAG.getNode(AVGOpc, DL, NVT, DAG.getExtOrTrunc(IsSigned, ExtOpA, DL, NVT),
1090 DAG.getExtOrTrunc(IsSigned, ExtOpB, DL, NVT));
1091 return DAG.getExtOrTrunc(IsSigned, ResultAVG, DL, VT);
1092}
1093
1094/// Look at Op. At this point, we know that only the OriginalDemandedBits of the
1095/// result of Op are ever used downstream. If we can use this information to
1096/// simplify Op, create a new simplified DAG node and return true, returning the
1097/// original and new nodes in Old and New. Otherwise, analyze the expression and
1098/// return a mask of Known bits for the expression (used to simplify the
1099/// caller). The Known bits may only be accurate for those bits in the
1100/// OriginalDemandedBits and OriginalDemandedElts.
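/// For example, if only the sign bit of (setcc X, 0, setlt) is demanded and
/// the setcc result has the same width as X (with true producing all ones),
/// the compare can be replaced by X itself, since the sign bit of X already
/// is the desired result (see the SETCC case below).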
1102 SDValue Op, const APInt &OriginalDemandedBits,
1103 const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
1104 unsigned Depth, bool AssumeSingleUse) const {
1105 unsigned BitWidth = OriginalDemandedBits.getBitWidth();
1106 assert(Op.getScalarValueSizeInBits() == BitWidth &&
1107 "Mask size mismatches value type size!");
1108
1109 // Don't know anything.
1110 Known = KnownBits(BitWidth);
1111
1112 EVT VT = Op.getValueType();
1113 bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
1114 unsigned NumElts = OriginalDemandedElts.getBitWidth();
1115 assert((!VT.isFixedLengthVector() || NumElts == VT.getVectorNumElements()) &&
1116 "Unexpected vector size");
1117
1118 APInt DemandedBits = OriginalDemandedBits;
1119 APInt DemandedElts = OriginalDemandedElts;
1120 SDLoc dl(Op);
1121
1122 // Undef operand.
1123 if (Op.isUndef())
1124 return false;
1125
1126 // We can't simplify target constants.
1127 if (Op.getOpcode() == ISD::TargetConstant)
1128 return false;
1129
1130 if (Op.getOpcode() == ISD::Constant) {
1131 // We know all of the bits for a constant!
1132 Known = KnownBits::makeConstant(Op->getAsAPIntVal());
1133 return false;
1134 }
1135
1136 if (Op.getOpcode() == ISD::ConstantFP) {
1137 // We know all of the bits for a floating point constant!
1139 cast<ConstantFPSDNode>(Op)->getValueAPF().bitcastToAPInt());
1140 return false;
1141 }
1142
1143 // Other users may use these bits.
1144 bool HasMultiUse = false;
1145 if (!AssumeSingleUse && !Op.getNode()->hasOneUse()) {
1147 // Limit search depth.
1148 return false;
1149 }
1150 // Allow multiple uses, just set the DemandedBits/Elts to all bits.
1152 DemandedElts = APInt::getAllOnes(NumElts);
1153 HasMultiUse = true;
1154 } else if (OriginalDemandedBits == 0 || OriginalDemandedElts == 0) {
1155 // Not demanding any bits/elts from Op.
1156 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1157 } else if (Depth >= SelectionDAG::MaxRecursionDepth) {
1158 // Limit search depth.
1159 return false;
1160 }
1161
1162 KnownBits Known2;
1163 switch (Op.getOpcode()) {
1164 case ISD::SCALAR_TO_VECTOR: {
1165 if (VT.isScalableVector())
1166 return false;
1167 if (!DemandedElts[0])
1168 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1169
1170 KnownBits SrcKnown;
1171 SDValue Src = Op.getOperand(0);
1172 unsigned SrcBitWidth = Src.getScalarValueSizeInBits();
1173 APInt SrcDemandedBits = DemandedBits.zext(SrcBitWidth);
1174 if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcKnown, TLO, Depth + 1))
1175 return true;
1176
1177 // Upper elements are undef, so only get the knownbits if we just demand
1178 // the bottom element.
1179 if (DemandedElts == 1)
1180 Known = SrcKnown.anyextOrTrunc(BitWidth);
1181 break;
1182 }
1183 case ISD::BUILD_VECTOR:
1184 // Collect the known bits that are shared by every demanded element.
1185 // TODO: Call SimplifyDemandedBits for non-constant demanded elements.
1186 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1187 return false; // Don't fall through, will infinitely loop.
1188 case ISD::SPLAT_VECTOR: {
1189 SDValue Scl = Op.getOperand(0);
1190 APInt DemandedSclBits = DemandedBits.zextOrTrunc(Scl.getValueSizeInBits());
1191 KnownBits KnownScl;
1192 if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
1193 return true;
1194
1195 // Implicitly truncate the bits to match the official semantics of
1196 // SPLAT_VECTOR.
1197 Known = KnownScl.trunc(BitWidth);
1198 break;
1199 }
1200 case ISD::LOAD: {
1201 auto *LD = cast<LoadSDNode>(Op);
1202 if (getTargetConstantFromLoad(LD)) {
1203 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1204 return false; // Don't fall through, will infinitely loop.
1205 }
1206 if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) {
1207 // If this is a ZEXTLoad and we are looking at the loaded value.
1208 EVT MemVT = LD->getMemoryVT();
1209 unsigned MemBits = MemVT.getScalarSizeInBits();
1210 Known.Zero.setBitsFrom(MemBits);
1211 return false; // Don't fall through, will infinitely loop.
1212 }
1213 break;
1214 }
1216 if (VT.isScalableVector())
1217 return false;
1218 SDValue Vec = Op.getOperand(0);
1219 SDValue Scl = Op.getOperand(1);
1220 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
1221 EVT VecVT = Vec.getValueType();
1222
1223 // If index isn't constant, assume we need all vector elements AND the
1224 // inserted element.
1225 APInt DemandedVecElts(DemandedElts);
1226 if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements())) {
1227 unsigned Idx = CIdx->getZExtValue();
1228 DemandedVecElts.clearBit(Idx);
1229
1230 // Inserted element is not required.
1231 if (!DemandedElts[Idx])
1232 return TLO.CombineTo(Op, Vec);
1233 }
1234
1235 KnownBits KnownScl;
1236 unsigned NumSclBits = Scl.getScalarValueSizeInBits();
1237 APInt DemandedSclBits = DemandedBits.zextOrTrunc(NumSclBits);
1238 if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
1239 return true;
1240
1241 Known = KnownScl.anyextOrTrunc(BitWidth);
1242
1243 KnownBits KnownVec;
1244 if (SimplifyDemandedBits(Vec, DemandedBits, DemandedVecElts, KnownVec, TLO,
1245 Depth + 1))
1246 return true;
1247
1248 if (!!DemandedVecElts)
1249 Known = Known.intersectWith(KnownVec);
1250
1251 return false;
1252 }
1253 case ISD::INSERT_SUBVECTOR: {
1254 if (VT.isScalableVector())
1255 return false;
1256 // Demand any elements from the subvector and the remainder from the src it's
1257 // inserted into.
1258 SDValue Src = Op.getOperand(0);
1259 SDValue Sub = Op.getOperand(1);
1260 uint64_t Idx = Op.getConstantOperandVal(2);
1261 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
1262 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
1263 APInt DemandedSrcElts = DemandedElts;
1264 DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
1265
1266 KnownBits KnownSub, KnownSrc;
1267 if (SimplifyDemandedBits(Sub, DemandedBits, DemandedSubElts, KnownSub, TLO,
1268 Depth + 1))
1269 return true;
1270 if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, KnownSrc, TLO,
1271 Depth + 1))
1272 return true;
1273
1274 Known.Zero.setAllBits();
1275 Known.One.setAllBits();
1276 if (!!DemandedSubElts)
1277 Known = Known.intersectWith(KnownSub);
1278 if (!!DemandedSrcElts)
1279 Known = Known.intersectWith(KnownSrc);
1280
1281 // Attempt to avoid multi-use src if we don't need anything from it.
1282 if (!DemandedBits.isAllOnes() || !DemandedSubElts.isAllOnes() ||
1283 !DemandedSrcElts.isAllOnes()) {
1284 SDValue NewSub = SimplifyMultipleUseDemandedBits(
1285 Sub, DemandedBits, DemandedSubElts, TLO.DAG, Depth + 1);
1286 SDValue NewSrc = SimplifyMultipleUseDemandedBits(
1287 Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1288 if (NewSub || NewSrc) {
1289 NewSub = NewSub ? NewSub : Sub;
1290 NewSrc = NewSrc ? NewSrc : Src;
1291 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc, NewSub,
1292 Op.getOperand(2));
1293 return TLO.CombineTo(Op, NewOp);
1294 }
1295 }
1296 break;
1297 }
1299 if (VT.isScalableVector())
1300 return false;
1301 // Offset the demanded elts by the subvector index.
1302 SDValue Src = Op.getOperand(0);
1303 if (Src.getValueType().isScalableVector())
1304 break;
1305 uint64_t Idx = Op.getConstantOperandVal(1);
1306 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
1307 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
1308
1309 if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, Known, TLO,
1310 Depth + 1))
1311 return true;
1312
1313 // Attempt to avoid multi-use src if we don't need anything from it.
1314 if (!DemandedBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
1315 SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
1316 Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1317 if (DemandedSrc) {
1318 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc,
1319 Op.getOperand(1));
1320 return TLO.CombineTo(Op, NewOp);
1321 }
1322 }
1323 break;
1324 }
1325 case ISD::CONCAT_VECTORS: {
1326 if (VT.isScalableVector())
1327 return false;
1328 Known.Zero.setAllBits();
1329 Known.One.setAllBits();
1330 EVT SubVT = Op.getOperand(0).getValueType();
1331 unsigned NumSubVecs = Op.getNumOperands();
1332 unsigned NumSubElts = SubVT.getVectorNumElements();
1333 for (unsigned i = 0; i != NumSubVecs; ++i) {
1334 APInt DemandedSubElts =
1335 DemandedElts.extractBits(NumSubElts, i * NumSubElts);
1336 if (SimplifyDemandedBits(Op.getOperand(i), DemandedBits, DemandedSubElts,
1337 Known2, TLO, Depth + 1))
1338 return true;
1339 // Known bits are shared by every demanded subvector element.
1340 if (!!DemandedSubElts)
1341 Known = Known.intersectWith(Known2);
1342 }
1343 break;
1344 }
1345 case ISD::VECTOR_SHUFFLE: {
1346 assert(!VT.isScalableVector());
1347 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
1348
1349 // Collect demanded elements from shuffle operands.
1350 APInt DemandedLHS, DemandedRHS;
1351 if (!getShuffleDemandedElts(NumElts, ShuffleMask, DemandedElts, DemandedLHS,
1352 DemandedRHS))
1353 break;
1354
1355 if (!!DemandedLHS || !!DemandedRHS) {
1356 SDValue Op0 = Op.getOperand(0);
1357 SDValue Op1 = Op.getOperand(1);
1358
1359 Known.Zero.setAllBits();
1360 Known.One.setAllBits();
1361 if (!!DemandedLHS) {
1362 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedLHS, Known2, TLO,
1363 Depth + 1))
1364 return true;
1365 Known = Known.intersectWith(Known2);
1366 }
1367 if (!!DemandedRHS) {
1368 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedRHS, Known2, TLO,
1369 Depth + 1))
1370 return true;
1371 Known = Known.intersectWith(Known2);
1372 }
1373
1374 // Attempt to avoid multi-use ops if we don't need anything from them.
1375 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1376 Op0, DemandedBits, DemandedLHS, TLO.DAG, Depth + 1);
1377 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1378 Op1, DemandedBits, DemandedRHS, TLO.DAG, Depth + 1);
1379 if (DemandedOp0 || DemandedOp1) {
1380 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1381 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1382 SDValue NewOp = TLO.DAG.getVectorShuffle(VT, dl, Op0, Op1, ShuffleMask);
1383 return TLO.CombineTo(Op, NewOp);
1384 }
1385 }
1386 break;
1387 }
1388 case ISD::AND: {
1389 SDValue Op0 = Op.getOperand(0);
1390 SDValue Op1 = Op.getOperand(1);
1391
1392 // If the RHS is a constant, check to see if the LHS would be zero without
1393 // using the bits from the RHS. Below, we use knowledge about the RHS to
1394 // simplify the LHS; here we're using information from the LHS to simplify
1395 // the RHS.
1396 if (ConstantSDNode *RHSC = isConstOrConstSplat(Op1, DemandedElts)) {
1397 // Do not increment Depth here; that can cause an infinite loop.
1398 KnownBits LHSKnown = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth);
1399 // If the LHS already has zeros where RHSC does, this 'and' is dead.
1400 if ((LHSKnown.Zero & DemandedBits) ==
1401 (~RHSC->getAPIntValue() & DemandedBits))
1402 return TLO.CombineTo(Op, Op0);
1403
1404 // If any of the set bits in the RHS are known zero on the LHS, shrink
1405 // the constant.
1406 if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits,
1407 DemandedElts, TLO))
1408 return true;
1409
1410 // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its
1411 // constant, but if this 'and' is only clearing bits that were just set by
1412 // the xor, then this 'and' can be eliminated by shrinking the mask of
1413 // the xor. For example, for a 32-bit X:
1414 // and (xor (srl X, 31), -1), 1 --> xor (srl X, 31), 1
1415 if (isBitwiseNot(Op0) && Op0.hasOneUse() &&
1416 LHSKnown.One == ~RHSC->getAPIntValue()) {
1417 SDValue Xor = TLO.DAG.getNode(ISD::XOR, dl, VT, Op0.getOperand(0), Op1);
1418 return TLO.CombineTo(Op, Xor);
1419 }
1420 }
1421
1422 // AND(INSERT_SUBVECTOR(C,X,I),M) -> INSERT_SUBVECTOR(AND(C,M),X,I)
1423 // iff 'C' is Undef/Constant and AND(X,M) == X (for DemandedBits).
1424 if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR && !VT.isScalableVector() &&
1425 (Op0.getOperand(0).isUndef() ||
1427 Op0->hasOneUse()) {
1428 unsigned NumSubElts =
1430 unsigned SubIdx = Op0.getConstantOperandVal(2);
1431 APInt DemandedSub =
1432 APInt::getBitsSet(NumElts, SubIdx, SubIdx + NumSubElts);
1433 KnownBits KnownSubMask =
1434 TLO.DAG.computeKnownBits(Op1, DemandedSub & DemandedElts, Depth + 1);
1435 if (DemandedBits.isSubsetOf(KnownSubMask.One)) {
1436 SDValue NewAnd =
1437 TLO.DAG.getNode(ISD::AND, dl, VT, Op0.getOperand(0), Op1);
1438 SDValue NewInsert =
1439 TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VT, NewAnd,
1440 Op0.getOperand(1), Op0.getOperand(2));
1441 return TLO.CombineTo(Op, NewInsert);
1442 }
1443 }
1444
1445 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1446 Depth + 1))
1447 return true;
1448 if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts,
1449 Known2, TLO, Depth + 1))
1450 return true;
1451
1452 // If all of the demanded bits are known one on one side, return the other.
1453 // These bits cannot contribute to the result of the 'and'.
1454 if (DemandedBits.isSubsetOf(Known2.Zero | Known.One))
1455 return TLO.CombineTo(Op, Op0);
1456 if (DemandedBits.isSubsetOf(Known.Zero | Known2.One))
1457 return TLO.CombineTo(Op, Op1);
1458 // If all of the demanded bits in the inputs are known zeros, return zero.
1459 if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
1460 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT));
1461 // If the RHS is a constant, see if we can simplify it.
1462 if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, DemandedElts,
1463 TLO))
1464 return true;
1465 // If the operation can be done in a smaller type, do so.
1466 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1467 return true;
1468
1469 // Attempt to avoid multi-use ops if we don't need anything from them.
1470 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1471 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1472 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1473 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1474 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1475 if (DemandedOp0 || DemandedOp1) {
1476 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1477 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1478 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1479 return TLO.CombineTo(Op, NewOp);
1480 }
1481 }
1482
1483 Known &= Known2;
1484 break;
1485 }
1486 case ISD::OR: {
1487 SDValue Op0 = Op.getOperand(0);
1488 SDValue Op1 = Op.getOperand(1);
1489 SDNodeFlags Flags = Op.getNode()->getFlags();
1490 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1491 Depth + 1)) {
1492 if (Flags.hasDisjoint()) {
1493 Flags.setDisjoint(false);
1494 Op->setFlags(Flags);
1495 }
1496 return true;
1497 }
1498
1499 if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts,
1500 Known2, TLO, Depth + 1)) {
1501 if (Flags.hasDisjoint()) {
1502 Flags.setDisjoint(false);
1503 Op->setFlags(Flags);
1504 }
1505 return true;
1506 }
1507
1508 // If all of the demanded bits are known zero on one side, return the other.
1509 // These bits cannot contribute to the result of the 'or'.
1510 if (DemandedBits.isSubsetOf(Known2.One | Known.Zero))
1511 return TLO.CombineTo(Op, Op0);
1512 if (DemandedBits.isSubsetOf(Known.One | Known2.Zero))
1513 return TLO.CombineTo(Op, Op1);
1514 // If the RHS is a constant, see if we can simplify it.
1515 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1516 return true;
1517 // If the operation can be done in a smaller type, do so.
1518 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1519 return true;
1520
1521 // Attempt to avoid multi-use ops if we don't need anything from them.
1522 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1523 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1524 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1525 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1526 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1527 if (DemandedOp0 || DemandedOp1) {
1528 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1529 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1530 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1531 return TLO.CombineTo(Op, NewOp);
1532 }
1533 }
1534
1535 // (or (and X, C1), (and (or X, Y), C2)) -> (or (and X, C1|C2), (and Y, C2))
1536 // TODO: Use SimplifyMultipleUseDemandedBits to peek through masks.
1537 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::AND &&
1538 Op0->hasOneUse() && Op1->hasOneUse()) {
1539 // Attempt to match all commutations - m_c_Or would've been useful!
1540 for (int I = 0; I != 2; ++I) {
1541 SDValue X = Op.getOperand(I).getOperand(0);
1542 SDValue C1 = Op.getOperand(I).getOperand(1);
1543 SDValue Alt = Op.getOperand(1 - I).getOperand(0);
1544 SDValue C2 = Op.getOperand(1 - I).getOperand(1);
1545 if (Alt.getOpcode() == ISD::OR) {
1546 for (int J = 0; J != 2; ++J) {
1547 if (X == Alt.getOperand(J)) {
1548 SDValue Y = Alt.getOperand(1 - J);
1549 if (SDValue C12 = TLO.DAG.FoldConstantArithmetic(ISD::OR, dl, VT,
1550 {C1, C2})) {
1551 SDValue MaskX = TLO.DAG.getNode(ISD::AND, dl, VT, X, C12);
1552 SDValue MaskY = TLO.DAG.getNode(ISD::AND, dl, VT, Y, C2);
1553 return TLO.CombineTo(
1554 Op, TLO.DAG.getNode(ISD::OR, dl, VT, MaskX, MaskY));
1555 }
1556 }
1557 }
1558 }
1559 }
1560 }
1561
1562 Known |= Known2;
1563 break;
1564 }
1565 case ISD::XOR: {
1566 SDValue Op0 = Op.getOperand(0);
1567 SDValue Op1 = Op.getOperand(1);
1568
1569 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1570 Depth + 1))
1571 return true;
1572 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO,
1573 Depth + 1))
1574 return true;
1575
1576 // If all of the demanded bits are known zero on one side, return the other.
1577 // These bits cannot contribute to the result of the 'xor'.
1578 if (DemandedBits.isSubsetOf(Known.Zero))
1579 return TLO.CombineTo(Op, Op0);
1580 if (DemandedBits.isSubsetOf(Known2.Zero))
1581 return TLO.CombineTo(Op, Op1);
1582 // If the operation can be done in a smaller type, do so.
1583 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1584 return true;
1585
1586 // If all of the unknown bits are known to be zero on one side or the other
1587 // turn this into an *inclusive* or.
1588 // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
1589 if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
1590 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
1591
1592 ConstantSDNode *C = isConstOrConstSplat(Op1, DemandedElts);
1593 if (C) {
1594 // If one side is a constant, and all of the set bits in the constant are
1595 // also known set on the other side, turn this into an AND, as we know
1596 // the bits will be cleared.
1597 // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
1598 // NB: it is okay if more bits are known than are requested
1599 if (C->getAPIntValue() == Known2.One) {
1600 SDValue ANDC =
1601 TLO.DAG.getConstant(~C->getAPIntValue() & DemandedBits, dl, VT);
1602 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, ANDC));
1603 }
1604
1605 // If the RHS is a constant, see if we can change it. Don't alter a -1
1606 // constant because that's a 'not' op, and that is better for combining
1607 // and codegen.
1608 if (!C->isAllOnes() && DemandedBits.isSubsetOf(C->getAPIntValue())) {
1609 // We're flipping all demanded bits. Flip the undemanded bits too.
1610 SDValue New = TLO.DAG.getNOT(dl, Op0, VT);
1611 return TLO.CombineTo(Op, New);
1612 }
1613
1614 unsigned Op0Opcode = Op0.getOpcode();
1615 if ((Op0Opcode == ISD::SRL || Op0Opcode == ISD::SHL) && Op0.hasOneUse()) {
1616 if (ConstantSDNode *ShiftC =
1617 isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
1618 // Don't crash on an oversized shift. We cannot guarantee that a
1619 // bogus shift has been simplified to undef.
1620 if (ShiftC->getAPIntValue().ult(BitWidth)) {
1621 uint64_t ShiftAmt = ShiftC->getZExtValue();
1623 Ones = Op0Opcode == ISD::SHL ? Ones.shl(ShiftAmt)
1624 : Ones.lshr(ShiftAmt);
1625 const TargetLowering &TLI = TLO.DAG.getTargetLoweringInfo();
1626 if ((DemandedBits & C->getAPIntValue()) == (DemandedBits & Ones) &&
1627 TLI.isDesirableToCommuteXorWithShift(Op.getNode())) {
1628 // If the xor constant is a demanded mask, do a 'not' before the
1629 // shift:
1630 // xor (X << ShiftC), XorC --> (not X) << ShiftC
1631 // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
1632 SDValue Not = TLO.DAG.getNOT(dl, Op0.getOperand(0), VT);
1633 return TLO.CombineTo(Op, TLO.DAG.getNode(Op0Opcode, dl, VT, Not,
1634 Op0.getOperand(1)));
1635 }
1636 }
1637 }
1638 }
1639 }
1640
1641 // If we can't turn this into a 'not', try to shrink the constant.
1642 if (!C || !C->isAllOnes())
1643 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1644 return true;
1645
1646 // Attempt to avoid multi-use ops if we don't need anything from them.
1647 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1648 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1649 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1650 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1651 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1652 if (DemandedOp0 || DemandedOp1) {
1653 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1654 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1655 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1656 return TLO.CombineTo(Op, NewOp);
1657 }
1658 }
1659
1660 Known ^= Known2;
1661 break;
1662 }
1663 case ISD::SELECT:
1664 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1665 Known, TLO, Depth + 1))
1666 return true;
1667 if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedElts,
1668 Known2, TLO, Depth + 1))
1669 return true;
1670
1671 // If the operands are constants, see if we can simplify them.
1672 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1673 return true;
1674
1675 // Only known if known in both the LHS and RHS.
1676 Known = Known.intersectWith(Known2);
1677 break;
1678 case ISD::VSELECT:
1679 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1680 Known, TLO, Depth + 1))
1681 return true;
1682 if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedElts,
1683 Known2, TLO, Depth + 1))
1684 return true;
1685
1686 // Only known if known in both the LHS and RHS.
1687 Known = Known.intersectWith(Known2);
1688 break;
1689 case ISD::SELECT_CC:
1690 if (SimplifyDemandedBits(Op.getOperand(3), DemandedBits, DemandedElts,
1691 Known, TLO, Depth + 1))
1692 return true;
1693 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1694 Known2, TLO, Depth + 1))
1695 return true;
1696
1697 // If the operands are constants, see if we can simplify them.
1698 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1699 return true;
1700
1701 // Only known if known in both the LHS and RHS.
1702 Known = Known.intersectWith(Known2);
1703 break;
1704 case ISD::SETCC: {
1705 SDValue Op0 = Op.getOperand(0);
1706 SDValue Op1 = Op.getOperand(1);
1707 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
1708 // If (1) we only need the sign-bit, (2) the setcc operands are the same
1709 // width as the setcc result, and (3) the result of a setcc conforms to 0 or
1710 // -1, we may be able to bypass the setcc.
1711 if (DemandedBits.isSignMask() &&
1712 Op0.getScalarValueSizeInBits() == BitWidth &&
1713 getBooleanContents(Op0.getValueType()) ==
1714 BooleanContent::ZeroOrNegativeOneBooleanContent) {
1715 // If we're testing X < 0, then this compare isn't needed - just use X!
1716 // FIXME: We're limiting to integer types here, but this should also work
1717 // if we don't care about FP signed-zero. The use of SETLT with FP means
1718 // that we don't care about NaNs.
1719 if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
1720 (isNullConstant(Op1) || ISD::isBuildVectorAllZeros(Op1.getNode())))
1721 return TLO.CombineTo(Op, Op0);
1722
1723 // TODO: Should we check for other forms of sign-bit comparisons?
1724 // Examples: X <= -1, X >= 0
1725 }
1726 if (getBooleanContents(Op0.getValueType()) ==
1727 TargetLowering::ZeroOrOneBooleanContent &&
1728 BitWidth > 1)
1729 Known.Zero.setBitsFrom(1);
1730 break;
1731 }
1732 case ISD::SHL: {
1733 SDValue Op0 = Op.getOperand(0);
1734 SDValue Op1 = Op.getOperand(1);
1735 EVT ShiftVT = Op1.getValueType();
1736
1737 if (std::optional<uint64_t> KnownSA =
1738 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
1739 unsigned ShAmt = *KnownSA;
1740 if (ShAmt == 0)
1741 return TLO.CombineTo(Op, Op0);
1742
1743 // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
1744 // single shift. We can do this if the bottom bits (which are shifted
1745 // out) are never demanded.
1746 // TODO - support non-uniform vector amounts.
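// e.g. ((x >>u 4) << 6) with the low 6 bits not demanded becomes (x << 2);
// if C1 exceeded ShAmt, the combined form would instead be a right shift by
// C1 - ShAmt.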
1747 if (Op0.getOpcode() == ISD::SRL) {
1748 if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
1749 if (std::optional<uint64_t> InnerSA =
1750 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
1751 unsigned C1 = *InnerSA;
1752 unsigned Opc = ISD::SHL;
1753 int Diff = ShAmt - C1;
1754 if (Diff < 0) {
1755 Diff = -Diff;
1756 Opc = ISD::SRL;
1757 }
1758 SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
1759 return TLO.CombineTo(
1760 Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
1761 }
1762 }
1763 }
1764
1765 // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits
1766 // are not demanded. This will likely allow the anyext to be folded away.
1767 // TODO - support non-uniform vector amounts.
1768 if (Op0.getOpcode() == ISD::ANY_EXTEND) {
1769 SDValue InnerOp = Op0.getOperand(0);
1770 EVT InnerVT = InnerOp.getValueType();
1771 unsigned InnerBits = InnerVT.getScalarSizeInBits();
1772 if (ShAmt < InnerBits && DemandedBits.getActiveBits() <= InnerBits &&
1773 isTypeDesirableForOp(ISD::SHL, InnerVT)) {
1774 SDValue NarrowShl = TLO.DAG.getNode(
1775 ISD::SHL, dl, InnerVT, InnerOp,
1776 TLO.DAG.getShiftAmountConstant(ShAmt, InnerVT, dl));
1777 return TLO.CombineTo(
1778 Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
1779 }
1780
1781 // Repeat the SHL optimization above in cases where an extension
1782 // intervenes: (shl (anyext (shr x, c1)), c2) to
1783 // (shl (anyext x), c2-c1). This requires that the bottom c1 bits
1784 // aren't demanded (as above) and that the shifted upper c1 bits of
1785 // x aren't demanded.
1786 // TODO - support non-uniform vector amounts.
1787 if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
1788 InnerOp.hasOneUse()) {
1789 if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
1790 InnerOp, DemandedElts, Depth + 2)) {
1791 unsigned InnerShAmt = *SA2;
1792 if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
1793 DemandedBits.getActiveBits() <=
1794 (InnerBits - InnerShAmt + ShAmt) &&
1795 DemandedBits.countr_zero() >= ShAmt) {
1796 SDValue NewSA =
1797 TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, ShiftVT);
1798 SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT,
1799 InnerOp.getOperand(0));
1800 return TLO.CombineTo(
1801 Op, TLO.DAG.getNode(ISD::SHL, dl, VT, NewExt, NewSA));
1802 }
1803 }
1804 }
1805 }
1806
1807 APInt InDemandedMask = DemandedBits.lshr(ShAmt);
1808 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
1809 Depth + 1)) {
1810 SDNodeFlags Flags = Op.getNode()->getFlags();
1811 if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) {
1812 // Disable the nsw and nuw flags. We can no longer guarantee that we
1813 // won't wrap after simplification.
1814 Flags.setNoSignedWrap(false);
1815 Flags.setNoUnsignedWrap(false);
1816 Op->setFlags(Flags);
1817 }
1818 return true;
1819 }
1820 Known.Zero <<= ShAmt;
1821 Known.One <<= ShAmt;
1822 // low bits known zero.
1823 Known.Zero.setLowBits(ShAmt);
1824
1825 // Attempt to avoid multi-use ops if we don't need anything from them.
1826 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
1827 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1828 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
1829 if (DemandedOp0) {
1830 SDValue NewOp = TLO.DAG.getNode(ISD::SHL, dl, VT, DemandedOp0, Op1);
1831 return TLO.CombineTo(Op, NewOp);
1832 }
1833 }
1834
1835 // TODO: Can we merge this fold with the one below?
1836 // Try shrinking the operation as long as the shift amount will still be
1837 // in range.
1838 if (ShAmt < DemandedBits.getActiveBits() && !VT.isVector() &&
1839 Op.getNode()->hasOneUse()) {
1840 // Search for the smallest integer type with free casts to and from
1841 // Op's type. For expedience, just check power-of-2 integer types.
1842 unsigned DemandedSize = DemandedBits.getActiveBits();
1843 for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
1844 SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
1845 EVT SmallVT = EVT::getIntegerVT(*TLO.DAG.getContext(), SmallVTBits);
1846 if (isNarrowingProfitable(VT, SmallVT) &&
1847 isTypeDesirableForOp(ISD::SHL, SmallVT) &&
1848 isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT) &&
1849 (!TLO.LegalOperations() || isOperationLegal(ISD::SHL, SmallVT))) {
1850 assert(DemandedSize <= SmallVTBits &&
1851 "Narrowed below demanded bits?");
1852 // We found a type with free casts.
1853 SDValue NarrowShl = TLO.DAG.getNode(
1854 ISD::SHL, dl, SmallVT,
1855 TLO.DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
1856 TLO.DAG.getShiftAmountConstant(ShAmt, SmallVT, dl));
1857 return TLO.CombineTo(
1858 Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
1859 }
1860 }
1861 }
1862
1863 // Narrow shift to lower half - similar to ShrinkDemandedOp.
1864 // (shl i64:x, K) -> (i64 zero_extend (shl (i32 (trunc i64:x)), K))
1865 // Only do this if we demand the upper half so the knownbits are correct.
1866 unsigned HalfWidth = BitWidth / 2;
1867 if ((BitWidth % 2) == 0 && !VT.isVector() && ShAmt < HalfWidth &&
1868 DemandedBits.countLeadingOnes() >= HalfWidth) {
1869 EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), HalfWidth);
1870 if (isNarrowingProfitable(VT, HalfVT) &&
1871 isTypeDesirableForOp(ISD::SHL, HalfVT) &&
1872 isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
1873 (!TLO.LegalOperations() || isOperationLegal(ISD::SHL, HalfVT))) {
1874 // If we're demanding the upper bits at all, we must ensure
1875 // that the upper bits of the shift result are known to be zero,
1876 // which is equivalent to the narrow shift being NUW.
1877 if (bool IsNUW = (Known.countMinLeadingZeros() >= HalfWidth)) {
1878 bool IsNSW = Known.countMinSignBits() > HalfWidth;
1879 SDNodeFlags Flags;
1880 Flags.setNoSignedWrap(IsNSW);
1881 Flags.setNoUnsignedWrap(IsNUW);
1882 SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0);
1883 SDValue NewShiftAmt =
1884 TLO.DAG.getShiftAmountConstant(ShAmt, HalfVT, dl);
1885 SDValue NewShift = TLO.DAG.getNode(ISD::SHL, dl, HalfVT, NewOp,
1886 NewShiftAmt, Flags);
1887 SDValue NewExt =
1888 TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, NewShift);
1889 return TLO.CombineTo(Op, NewExt);
1890 }
1891 }
1892 }
1893 } else {
1894 // This is a variable shift, so we can't shift the demand mask by a known
1895 // amount. But if we are not demanding high bits, then we are not
1896 // demanding those bits from the pre-shifted operand either.
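// e.g. if only the low 8 bits of (x << c) are demanded, only the low 8 bits
// of x can affect them, because a left shift never moves bits toward bit 0.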
1897 if (unsigned CTLZ = DemandedBits.countl_zero()) {
1898 APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
1899 if (SimplifyDemandedBits(Op0, DemandedFromOp, DemandedElts, Known, TLO,
1900 Depth + 1)) {
1901 SDNodeFlags Flags = Op.getNode()->getFlags();
1902 if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) {
1903 // Disable the nsw and nuw flags. We can no longer guarantee that we
1904 // won't wrap after simplification.
1905 Flags.setNoSignedWrap(false);
1906 Flags.setNoUnsignedWrap(false);
1907 Op->setFlags(Flags);
1908 }
1909 return true;
1910 }
1911 Known.resetAll();
1912 }
1913 }
1914
1915 // If we are only demanding sign bits then we can use the shift source
1916 // directly.
1917 if (std::optional<uint64_t> MaxSA =
1918 TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
1919 unsigned ShAmt = *MaxSA;
1920 unsigned NumSignBits =
1921 TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
1922 unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
1923 if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
1924 return TLO.CombineTo(Op, Op0);
1925 }
1926 break;
1927 }
1928 case ISD::SRL: {
1929 SDValue Op0 = Op.getOperand(0);
1930 SDValue Op1 = Op.getOperand(1);
1931 EVT ShiftVT = Op1.getValueType();
1932
1933 if (std::optional<uint64_t> KnownSA =
1934 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
1935 unsigned ShAmt = *KnownSA;
1936 if (ShAmt == 0)
1937 return TLO.CombineTo(Op, Op0);
1938
1939 // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
1940 // single shift. We can do this if the top bits (which are shifted out)
1941 // are never demanded.
1942 // TODO - support non-uniform vector amounts.
1943 if (Op0.getOpcode() == ISD::SHL) {
1944 if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
1945 if (std::optional<uint64_t> InnerSA =
1946 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
1947 unsigned C1 = *InnerSA;
1948 unsigned Opc = ISD::SRL;
1949 int Diff = ShAmt - C1;
1950 if (Diff < 0) {
1951 Diff = -Diff;
1952 Opc = ISD::SHL;
1953 }
1954 SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
1955 return TLO.CombineTo(
1956 Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
1957 }
1958 }
1959 }
1960
1961 APInt InDemandedMask = (DemandedBits << ShAmt);
1962
1963 // If the shift is exact, then it does demand the low bits (and knows that
1964 // they are zero).
1965 if (Op->getFlags().hasExact())
1966 InDemandedMask.setLowBits(ShAmt);
1967
1968 // Narrow shift to lower half - similar to ShrinkDemandedOp.
1969 // (srl i64:x, K) -> (i64 zero_extend (srl (i32 (trunc i64:x)), K))
1970 if ((BitWidth % 2) == 0 && !VT.isVector()) {
1971 APInt HiBits = APInt::getHighBitsSet(BitWidth, BitWidth / 2);
1972 EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), BitWidth / 2);
1973 if (isNarrowingProfitable(VT, HalfVT) &&
1974 isTypeDesirableForOp(ISD::SRL, HalfVT) &&
1975 isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
1976 (!TLO.LegalOperations() || isOperationLegal(ISD::SRL, HalfVT)) &&
1977 ((InDemandedMask.countLeadingZeros() >= (BitWidth / 2)) ||
1978 TLO.DAG.MaskedValueIsZero(Op0, HiBits))) {
1979 SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0);
1980 SDValue NewShiftAmt =
1981 TLO.DAG.getShiftAmountConstant(ShAmt, HalfVT, dl);
1982 SDValue NewShift =
1983 TLO.DAG.getNode(ISD::SRL, dl, HalfVT, NewOp, NewShiftAmt);
1984 return TLO.CombineTo(
1985 Op, TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, NewShift));
1986 }
1987 }
1988
1989 // Compute the new bits that are at the top now.
1990 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
1991 Depth + 1))
1992 return true;
1993 Known.Zero.lshrInPlace(ShAmt);
1994 Known.One.lshrInPlace(ShAmt);
1995 // High bits known zero.
1996 Known.Zero.setHighBits(ShAmt);
1997
1998 // Attempt to avoid multi-use ops if we don't need anything from them.
1999 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2000 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2001 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
2002 if (DemandedOp0) {
2003 SDValue NewOp = TLO.DAG.getNode(ISD::SRL, dl, VT, DemandedOp0, Op1);
2004 return TLO.CombineTo(Op, NewOp);
2005 }
2006 }
2007 } else {
2008 // Use generic knownbits computation as it has support for non-uniform
2009 // shift amounts.
2010 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2011 }
2012
2013 // Try to match AVG patterns (after shift simplification).
2014 if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2015 DemandedElts, Depth + 1))
2016 return TLO.CombineTo(Op, AVG);
2017
2018 break;
2019 }
2020 case ISD::SRA: {
2021 SDValue Op0 = Op.getOperand(0);
2022 SDValue Op1 = Op.getOperand(1);
2023 EVT ShiftVT = Op1.getValueType();
2024
2025 // If we only want bits that already match the signbit then we don't need
2026 // to shift.
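// e.g. if x has 16 known sign bits and only the top 8 bits of (x >>s c) are
// demanded, those bits are sign copies in both x and the shifted value, so x
// can be used directly.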
2027 unsigned NumHiDemandedBits = BitWidth - DemandedBits.countr_zero();
2028 if (TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1) >=
2029 NumHiDemandedBits)
2030 return TLO.CombineTo(Op, Op0);
2031
2032 // If this is an arithmetic shift right and only the low-bit is set, we can
2033 // always convert this into a logical shr, even if the shift amount is
2034 // variable. The low bit of the shift cannot be an input sign bit unless
2035 // the shift amount is >= the size of the datatype, which is undefined.
2036 if (DemandedBits.isOne())
2037 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
2038
2039 if (std::optional<uint64_t> KnownSA =
2040 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
2041 unsigned ShAmt = *KnownSA;
2042 if (ShAmt == 0)
2043 return TLO.CombineTo(Op, Op0);
2044
2045 // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
2046 // supports sext_inreg.
2047 if (Op0.getOpcode() == ISD::SHL) {
2048 if (std::optional<uint64_t> InnerSA =
2049 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
2050 unsigned LowBits = BitWidth - ShAmt;
2051 EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
2052 if (VT.isVector())
2053 ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtVT,
2054 VT.getVectorElementCount());
2055
2056 if (*InnerSA == ShAmt) {
2057 if (!TLO.LegalOperations() ||
2058 isOperationLegal(ISD::SIGN_EXTEND_INREG, ExtVT))
2059 return TLO.CombineTo(
2060 Op, TLO.DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, VT,
2061 Op0.getOperand(0),
2062 TLO.DAG.getValueType(ExtVT)));
2063
2064 // Even if we can't convert to sext_inreg, we might be able to
2065 // remove this shift pair if the input is already sign extended.
2066 unsigned NumSignBits =
2067 TLO.DAG.ComputeNumSignBits(Op0.getOperand(0), DemandedElts);
2068 if (NumSignBits > ShAmt)
2069 return TLO.CombineTo(Op, Op0.getOperand(0));
2070 }
2071 }
2072 }
2073
2074 APInt InDemandedMask = (DemandedBits << ShAmt);
2075
2076 // If the shift is exact, then it does demand the low bits (and knows that
2077 // they are zero).
2078 if (Op->getFlags().hasExact())
2079 InDemandedMask.setLowBits(ShAmt);
2080
2081 // If any of the demanded bits are produced by the sign extension, we also
2082 // demand the input sign bit.
2083 if (DemandedBits.countl_zero() < ShAmt)
2084 InDemandedMask.setSignBit();
2085
2086 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
2087 Depth + 1))
2088 return true;
2089 Known.Zero.lshrInPlace(ShAmt);
2090 Known.One.lshrInPlace(ShAmt);
2091
2092 // If the input sign bit is known to be zero, or if none of the top bits
2093 // are demanded, turn this into an unsigned shift right.
2094 if (Known.Zero[BitWidth - ShAmt - 1] ||
2095 DemandedBits.countl_zero() >= ShAmt) {
2096 SDNodeFlags Flags;
2097 Flags.setExact(Op->getFlags().hasExact());
2098 return TLO.CombineTo(
2099 Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1, Flags));
2100 }
2101
2102 int Log2 = DemandedBits.exactLogBase2();
2103 if (Log2 >= 0) {
2104 // The bit must come from the sign.
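// (We only get here after the countl_zero() check above failed, so the lone
// demanded bit lies in the region filled by the shifted-in sign bit; a
// logical shift right by BitWidth - 1 - Log2 moves the sign bit there.)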
2105 SDValue NewSA = TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, ShiftVT);
2106 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, NewSA));
2107 }
2108
2109 if (Known.One[BitWidth - ShAmt - 1])
2110 // New bits are known one.
2111 Known.One.setHighBits(ShAmt);
2112
2113 // Attempt to avoid multi-use ops if we don't need anything from them.
2114 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2115 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2116 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
2117 if (DemandedOp0) {
2118 SDValue NewOp = TLO.DAG.getNode(ISD::SRA, dl, VT, DemandedOp0, Op1);
2119 return TLO.CombineTo(Op, NewOp);
2120 }
2121 }
2122 }
2123
2124 // Try to match AVG patterns (after shift simplification).
2125 if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2126 DemandedElts, Depth + 1))
2127 return TLO.CombineTo(Op, AVG);
2128
2129 break;
2130 }
2131 case ISD::FSHL:
2132 case ISD::FSHR: {
2133 SDValue Op0 = Op.getOperand(0);
2134 SDValue Op1 = Op.getOperand(1);
2135 SDValue Op2 = Op.getOperand(2);
2136 bool IsFSHL = (Op.getOpcode() == ISD::FSHL);
2137
2138 if (ConstantSDNode *SA = isConstOrConstSplat(Op2, DemandedElts)) {
2139 unsigned Amt = SA->getAPIntValue().urem(BitWidth);
2140
2141 // For fshl, 0-shift returns the 1st arg.
2142 // For fshr, 0-shift returns the 2nd arg.
2143 if (Amt == 0) {
2144 if (SimplifyDemandedBits(IsFSHL ? Op0 : Op1, DemandedBits, DemandedElts,
2145 Known, TLO, Depth + 1))
2146 return true;
2147 break;
2148 }
2149
2150 // fshl: (Op0 << Amt) | (Op1 >> (BW - Amt))
2151 // fshr: (Op0 << (BW - Amt)) | (Op1 >> Amt)
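// e.g. fshl with BitWidth = 8 and Amt = 3 computes (Op0 << 3) | (Op1 >> 5):
// result bits 3..7 come from Op0 bits 0..4 and result bits 0..2 come from
// Op1 bits 5..7.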
2152 APInt Demanded0 = DemandedBits.lshr(IsFSHL ? Amt : (BitWidth - Amt));
2153 APInt Demanded1 = DemandedBits << (IsFSHL ? (BitWidth - Amt) : Amt);
2154 if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
2155 Depth + 1))
2156 return true;
2157 if (SimplifyDemandedBits(Op1, Demanded1, DemandedElts, Known, TLO,
2158 Depth + 1))
2159 return true;
2160
2161 Known2.One <<= (IsFSHL ? Amt : (BitWidth - Amt));
2162 Known2.Zero <<= (IsFSHL ? Amt : (BitWidth - Amt));
2163 Known.One.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
2164 Known.Zero.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
2165 Known = Known.unionWith(Known2);
2166
2167 // Attempt to avoid multi-use ops if we don't need anything from them.
2168 if (!Demanded0.isAllOnes() || !Demanded1.isAllOnes() ||
2169 !DemandedElts.isAllOnes()) {
2170 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2171 Op0, Demanded0, DemandedElts, TLO.DAG, Depth + 1);
2172 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
2173 Op1, Demanded1, DemandedElts, TLO.DAG, Depth + 1);
2174 if (DemandedOp0 || DemandedOp1) {
2175 DemandedOp0 = DemandedOp0 ? DemandedOp0 : Op0;
2176 DemandedOp1 = DemandedOp1 ? DemandedOp1 : Op1;
2177 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedOp0,
2178 DemandedOp1, Op2);
2179 return TLO.CombineTo(Op, NewOp);
2180 }
2181 }
2182 }
2183
2184 // For pow-2 bitwidths we only demand the bottom modulo amt bits.
2185 if (isPowerOf2_32(BitWidth)) {
2186 APInt DemandedAmtBits(Op2.getScalarValueSizeInBits(), BitWidth - 1);
2187 if (SimplifyDemandedBits(Op2, DemandedAmtBits, DemandedElts,
2188 Known2, TLO, Depth + 1))
2189 return true;
2190 }
2191 break;
2192 }
2193 case ISD::ROTL:
2194 case ISD::ROTR: {
2195 SDValue Op0 = Op.getOperand(0);
2196 SDValue Op1 = Op.getOperand(1);
2197 bool IsROTL = (Op.getOpcode() == ISD::ROTL);
2198
2199 // If we're rotating a 0/-1 value, then it stays a 0/-1 value.
2200 if (BitWidth == TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1))
2201 return TLO.CombineTo(Op, Op0);
2202
2203 if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
2204 unsigned Amt = SA->getAPIntValue().urem(BitWidth);
2205 unsigned RevAmt = BitWidth - Amt;
2206
2207 // rotl: (Op0 << Amt) | (Op0 >> (BW - Amt))
2208 // rotr: (Op0 << (BW - Amt)) | (Op0 >> Amt)
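// e.g. rotl i8 x, 3 sends source bit i to result bit (i + 3) & 7, so the
// demanded source bits are the demanded result bits rotated right by 3.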
2209 APInt Demanded0 = DemandedBits.rotr(IsROTL ? Amt : RevAmt);
2210 if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
2211 Depth + 1))
2212 return true;
2213
2214 // rot*(x, 0) --> x
2215 if (Amt == 0)
2216 return TLO.CombineTo(Op, Op0);
2217
2218 // See if we don't demand either half of the rotated bits.
2219 if ((!TLO.LegalOperations() || isOperationLegal(ISD::SHL, VT)) &&
2220 DemandedBits.countr_zero() >= (IsROTL ? Amt : RevAmt)) {
2221 Op1 = TLO.DAG.getConstant(IsROTL ? Amt : RevAmt, dl, Op1.getValueType());
2222 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, Op1));
2223 }
2224 if ((!TLO.LegalOperations() || isOperationLegal(ISD::SRL, VT)) &&
2225 DemandedBits.countl_zero() >= (IsROTL ? RevAmt : Amt)) {
2226 Op1 = TLO.DAG.getConstant(IsROTL ? RevAmt : Amt, dl, Op1.getValueType());
2227 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
2228 }
2229 }
2230
2231 // For pow-2 bitwidths we only demand the bottom modulo amt bits.
2232 if (isPowerOf2_32(BitWidth)) {
2233 APInt DemandedAmtBits(Op1.getScalarValueSizeInBits(), BitWidth - 1);
2234 if (SimplifyDemandedBits(Op1, DemandedAmtBits, DemandedElts, Known2, TLO,
2235 Depth + 1))
2236 return true;
2237 }
2238 break;
2239 }
2240 case ISD::SMIN:
2241 case ISD::SMAX:
2242 case ISD::UMIN:
2243 case ISD::UMAX: {
2244 unsigned Opc = Op.getOpcode();
2245 SDValue Op0 = Op.getOperand(0);
2246 SDValue Op1 = Op.getOperand(1);
2247
2248 // If we're only demanding signbits, then we can simplify to OR/AND node.
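// When every demanded bit is a sign bit of both operands, each operand acts
// as 0 or -1 in those bits: smin/umax yield the -1 pattern when either
// operand has it (OR), smax/umin only when both do (AND).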
2249 unsigned BitOp =
2250 (Opc == ISD::SMIN || Opc == ISD::UMAX) ? ISD::OR : ISD::AND;
2251 unsigned NumSignBits =
2252 std::min(TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1),
2253 TLO.DAG.ComputeNumSignBits(Op1, DemandedElts, Depth + 1));
2254 unsigned NumDemandedUpperBits = BitWidth - DemandedBits.countr_zero();
2255 if (NumSignBits >= NumDemandedUpperBits)
2256 return TLO.CombineTo(Op, TLO.DAG.getNode(BitOp, SDLoc(Op), VT, Op0, Op1));
2257
2258 // Check if one arg is always less/greater than (or equal) to the other arg.
2259 KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
2260 KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
2261 switch (Opc) {
2262 case ISD::SMIN:
2263 if (std::optional<bool> IsSLE = KnownBits::sle(Known0, Known1))
2264 return TLO.CombineTo(Op, *IsSLE ? Op0 : Op1);
2265 if (std::optional<bool> IsSLT = KnownBits::slt(Known0, Known1))
2266 return TLO.CombineTo(Op, *IsSLT ? Op0 : Op1);
2267 Known = KnownBits::smin(Known0, Known1);
2268 break;
2269 case ISD::SMAX:
2270 if (std::optional<bool> IsSGE = KnownBits::sge(Known0, Known1))
2271 return TLO.CombineTo(Op, *IsSGE ? Op0 : Op1);
2272 if (std::optional<bool> IsSGT = KnownBits::sgt(Known0, Known1))
2273 return TLO.CombineTo(Op, *IsSGT ? Op0 : Op1);
2274 Known = KnownBits::smax(Known0, Known1);
2275 break;
2276 case ISD::UMIN:
2277 if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
2278 return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
2279 if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
2280 return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
2281 Known = KnownBits::umin(Known0, Known1);
2282 break;
2283 case ISD::UMAX:
2284 if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
2285 return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
2286 if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
2287 return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
2288 Known = KnownBits::umax(Known0, Known1);
2289 break;
2290 }
2291 break;
2292 }
2293 case ISD::BITREVERSE: {
2294 SDValue Src = Op.getOperand(0);
2295 APInt DemandedSrcBits = DemandedBits.reverseBits();
2296 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2297 Depth + 1))
2298 return true;
2299 Known.One = Known2.One.reverseBits();
2300 Known.Zero = Known2.Zero.reverseBits();
2301 break;
2302 }
2303 case ISD::BSWAP: {
2304 SDValue Src = Op.getOperand(0);
2305
2306 // If the only bits demanded come from one byte of the bswap result,
2307 // just shift the input byte into position to eliminate the bswap.
2308 unsigned NLZ = DemandedBits.countl_zero();
2309 unsigned NTZ = DemandedBits.countr_zero();
2310
2311 // Round NTZ down to the next byte. If we have 11 trailing zeros, then
2312 // we need all the bits down to bit 8. Likewise, round NLZ. If we
2313 // have 14 leading zeros, round to 8.
2314 NLZ = alignDown(NLZ, 8);
2315 NTZ = alignDown(NTZ, 8);
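// e.g. for an i32 bswap where only bits [15:8] of the result are demanded
// (NLZ = 16, NTZ = 8), those bits come from source bits [23:16], so a
// logical shift right by NLZ - NTZ = 8 replaces the bswap.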
2316 // If we need exactly one byte, we can do this transformation.
2317 if (BitWidth - NLZ - NTZ == 8) {
2318 // Replace this with either a left or right shift to get the byte into
2319 // the right place.
2320 unsigned ShiftOpcode = NLZ > NTZ ? ISD::SRL : ISD::SHL;
2321 if (!TLO.LegalOperations() || isOperationLegal(ShiftOpcode, VT)) {
2322 unsigned ShiftAmount = NLZ > NTZ ? NLZ - NTZ : NTZ - NLZ;
2323 SDValue ShAmt = TLO.DAG.getShiftAmountConstant(ShiftAmount, VT, dl);
2324 SDValue NewOp = TLO.DAG.getNode(ShiftOpcode, dl, VT, Src, ShAmt);
2325 return TLO.CombineTo(Op, NewOp);
2326 }
2327 }
2328
2329 APInt DemandedSrcBits = DemandedBits.byteSwap();
2330 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2331 Depth + 1))
2332 return true;
2333 Known.One = Known2.One.byteSwap();
2334 Known.Zero = Known2.Zero.byteSwap();
2335 break;
2336 }
2337 case ISD::CTPOP: {
2338 // If only 1 bit is demanded, replace with PARITY as long as we're before
2339 // op legalization.
2340 // FIXME: Limit to scalars for now.
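// (Bit 0 of a population count is the XOR of all input bits, i.e. the
// parity.)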
2341 if (DemandedBits.isOne() && !TLO.LegalOps && !VT.isVector())
2342 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::PARITY, dl, VT,
2343 Op.getOperand(0)));
2344
2345 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2346 break;
2347 }
2348 case ISD::SIGN_EXTEND_INREG: {
2349 SDValue Op0 = Op.getOperand(0);
2350 EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2351 unsigned ExVTBits = ExVT.getScalarSizeInBits();
2352
2353 // If we only care about the highest bit, don't bother shifting right.
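// e.g. for a sext_inreg of an i32 from i8 with only bit 31 demanded, a left
// shift by 24 moves the i8 sign bit (bit 7) straight to bit 31.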
2354 if (DemandedBits.isSignMask()) {
2355 unsigned MinSignedBits =
2356 TLO.DAG.ComputeMaxSignificantBits(Op0, DemandedElts, Depth + 1);
2357 bool AlreadySignExtended = ExVTBits >= MinSignedBits;
2358 // However if the input is already sign extended we expect the sign
2359 // extension to be dropped altogether later and do not simplify.
2360 if (!AlreadySignExtended) {
2361 // Compute the correct shift amount type, which must be getShiftAmountTy
2362 // for scalar types after legalization.
2363 SDValue ShiftAmt =
2364 TLO.DAG.getShiftAmountConstant(BitWidth - ExVTBits, VT, dl);
2365 return TLO.CombineTo(Op,
2366 TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, ShiftAmt));
2367 }
2368 }
2369
2370 // If none of the extended bits are demanded, eliminate the sextinreg.
2371 if (DemandedBits.getActiveBits() <= ExVTBits)
2372 return TLO.CombineTo(Op, Op0);
2373
2374 APInt InputDemandedBits = DemandedBits.getLoBits(ExVTBits);
2375
2376 // Since the sign extended bits are demanded, we know that the sign
2377 // bit is demanded.
2378 InputDemandedBits.setBit(ExVTBits - 1);
2379
2380 if (SimplifyDemandedBits(Op0, InputDemandedBits, DemandedElts, Known, TLO,
2381 Depth + 1))
2382 return true;
2383
2384 // If the sign bit of the input is known set or clear, then we know the
2385 // top bits of the result.
2386
2387 // If the input sign bit is known zero, convert this into a zero extension.
2388 if (Known.Zero[ExVTBits - 1])
2389 return TLO.CombineTo(Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT));
2390
2391 APInt Mask = APInt::getLowBitsSet(BitWidth, ExVTBits);
2392 if (Known.One[ExVTBits - 1]) { // Input sign bit known set
2393 Known.One.setBitsFrom(ExVTBits);
2394 Known.Zero &= Mask;
2395 } else { // Input sign bit unknown
2396 Known.Zero &= Mask;
2397 Known.One &= Mask;
2398 }
2399 break;
2400 }
2401 case ISD::BUILD_PAIR: {
2402 EVT HalfVT = Op.getOperand(0).getValueType();
2403 unsigned HalfBitWidth = HalfVT.getScalarSizeInBits();
2404
2405 APInt MaskLo = DemandedBits.getLoBits(HalfBitWidth).trunc(HalfBitWidth);
2406 APInt MaskHi = DemandedBits.getHiBits(HalfBitWidth).trunc(HalfBitWidth);
2407
2408 KnownBits KnownLo, KnownHi;
2409
2410 if (SimplifyDemandedBits(Op.getOperand(0), MaskLo, KnownLo, TLO, Depth + 1))
2411 return true;
2412
2413 if (SimplifyDemandedBits(Op.getOperand(1), MaskHi, KnownHi, TLO, Depth + 1))
2414 return true;
2415
2416 Known = KnownHi.concat(KnownLo);
2417 break;
2418 }
2419 case ISD::ZERO_EXTEND_VECTOR_INREG:
2420 if (VT.isScalableVector())
2421 return false;
2422 [[fallthrough]];
2423 case ISD::ZERO_EXTEND: {
2424 SDValue Src = Op.getOperand(0);
2425 EVT SrcVT = Src.getValueType();
2426 unsigned InBits = SrcVT.getScalarSizeInBits();
2427 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2428 bool IsVecInReg = Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
2429
2430 // If none of the top bits are demanded, convert this into an any_extend.
2431 if (DemandedBits.getActiveBits() <= InBits) {
2432 // If we only need the non-extended bits of the bottom element
2433 // then we can just bitcast to the result.
2434 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2435 VT.getSizeInBits() == SrcVT.getSizeInBits())
2436 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2437
2438 unsigned Opc =
2439 IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2440 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2441 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2442 }
2443
2444 SDNodeFlags Flags = Op->getFlags();
2445 APInt InDemandedBits = DemandedBits.trunc(InBits);
2446 APInt InDemandedElts = DemandedElts.zext(InElts);
2447 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2448 Depth + 1)) {
2449 if (Flags.hasNonNeg()) {
2450 Flags.setNonNeg(false);
2451 Op->setFlags(Flags);
2452 }
2453 return true;
2454 }
2455 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2456 Known = Known.zext(BitWidth);
2457
2458 // Attempt to avoid multi-use ops if we don't need anything from them.
2459 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2460 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2461 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2462 break;
2463 }
2464 case ISD::SIGN_EXTEND_VECTOR_INREG:
2465 if (VT.isScalableVector())
2466 return false;
2467 [[fallthrough]];
2468 case ISD::SIGN_EXTEND: {
2469 SDValue Src = Op.getOperand(0);
2470 EVT SrcVT = Src.getValueType();
2471 unsigned InBits = SrcVT.getScalarSizeInBits();
2472 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2473 bool IsVecInReg = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
2474
2475 APInt InDemandedElts = DemandedElts.zext(InElts);
2476 APInt InDemandedBits = DemandedBits.trunc(InBits);
2477
2478 // Since some of the sign extended bits are demanded, we know that the sign
2479 // bit is demanded.
2480 InDemandedBits.setBit(InBits - 1);
2481
2482 // If none of the top bits are demanded, convert this into an any_extend.
2483 if (DemandedBits.getActiveBits() <= InBits) {
2484 // If we only need the non-extended bits of the bottom element
2485 // then we can just bitcast to the result.
2486 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2487 VT.getSizeInBits() == SrcVT.getSizeInBits())
2488 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2489
2490 // Don't lose an all signbits 0/-1 splat on targets with 0/-1 booleans.
2491 if (getBooleanContents(VT) != ZeroOrNegativeOneBooleanContent ||
2492 TLO.DAG.ComputeNumSignBits(Src, InDemandedElts, Depth + 1) !=
2493 InBits) {
2494 unsigned Opc =
2495 IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2496 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2497 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2498 }
2499 }
2500
2501 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2502 Depth + 1))
2503 return true;
2504 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2505
2506 // If the sign bit is known one, the top bits match.
2507 Known = Known.sext(BitWidth);
2508
2509 // If the sign bit is known zero, convert this to a zero extend.
2510 if (Known.isNonNegative()) {
2511 unsigned Opc =
2512 IsVecInReg ? ISD::ZERO_EXTEND_VECTOR_INREG : ISD::ZERO_EXTEND;
2513 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT)) {
2514 SDNodeFlags Flags;
2515 if (!IsVecInReg)
2516 Flags.setNonNeg(true);
2517 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src, Flags));
2518 }
2519 }
2520
2521 // Attempt to avoid multi-use ops if we don't need anything from them.
2522 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2523 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2524 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2525 break;
2526 }
2527 case ISD::ANY_EXTEND_VECTOR_INREG:
2528 if (VT.isScalableVector())
2529 return false;
2530 [[fallthrough]];
2531 case ISD::ANY_EXTEND: {
2532 SDValue Src = Op.getOperand(0);
2533 EVT SrcVT = Src.getValueType();
2534 unsigned InBits = SrcVT.getScalarSizeInBits();
2535 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2536 bool IsVecInReg = Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG;
2537
2538 // If we only need the bottom element then we can just bitcast.
2539 // TODO: Handle ANY_EXTEND?
2540 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2541 VT.getSizeInBits() == SrcVT.getSizeInBits())
2542 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2543
2544 APInt InDemandedBits = DemandedBits.trunc(InBits);
2545 APInt InDemandedElts = DemandedElts.zext(InElts);
2546 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2547 Depth + 1))
2548 return true;
2549 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2550 Known = Known.anyext(BitWidth);
2551
2552 // Attempt to avoid multi-use ops if we don't need anything from them.
2553 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2554 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2555 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2556 break;
2557 }
2558 case ISD::TRUNCATE: {
2559 SDValue Src = Op.getOperand(0);
2560
2561 // Simplify the input, using demanded bit information, and compute the known
2562 // zero/one bits live out.
2563 unsigned OperandBitWidth = Src.getScalarValueSizeInBits();
2564 APInt TruncMask = DemandedBits.zext(OperandBitWidth);
2565 if (SimplifyDemandedBits(Src, TruncMask, DemandedElts, Known, TLO,
2566 Depth + 1))
2567 return true;
2568 Known = Known.trunc(BitWidth);
2569
2570 // Attempt to avoid multi-use ops if we don't need anything from them.
2571 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2572 Src, TruncMask, DemandedElts, TLO.DAG, Depth + 1))
2573 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, NewSrc));
2574
2575 // If the input is only used by this truncate, see if we can shrink it based
2576 // on the known demanded bits.
2577 switch (Src.getOpcode()) {
2578 default:
2579 break;
2580 case ISD::SRL:
2581 // Shrink SRL by a constant if none of the high bits shifted in are
2582 // demanded.
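// e.g. when truncating (srl i64 x, 8) to i32, the bits shifted in from
// x[63:32] end up in result bits [31:24]; if those bits are not demanded,
// truncate x first and do the shift in i32.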
2583 if (TLO.LegalTypes() && !isTypeDesirableForOp(ISD::SRL, VT))
2584 // Do not turn (vt1 truncate (vt2 srl)) into (vt1 srl) if vt1 is
2585 // undesirable.
2586 break;
2587
2588 if (Src.getNode()->hasOneUse()) {
2589 if (isTruncateFree(Src, VT) &&
2590 !isTruncateFree(Src.getValueType(), VT)) {
2591 // If truncate is only free at trunc(srl), do not turn it into
2592 // srl(trunc). The check is done by first checking that the truncate is
2593 // free at Src's opcode (srl), then checking that the truncate is not done
2594 // by referencing a sub-register. In testing, if both trunc(srl) and
2595 // srl(trunc)'s trunc are free, srl(trunc) performs better. If only
2596 // trunc(srl)'s trunc is free, trunc(srl) is better.
2597 break;
2598 }
2599
2600 std::optional<uint64_t> ShAmtC =
2601 TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
2602 if (!ShAmtC || *ShAmtC >= BitWidth)
2603 break;
2604 uint64_t ShVal = *ShAmtC;
2605
2606 APInt HighBits =
2607 APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
2608 HighBits.lshrInPlace(ShVal);
2609 HighBits = HighBits.trunc(BitWidth);
2610 if (!(HighBits & DemandedBits)) {
2611 // None of the shifted in bits are needed. Add a truncate of the
2612 // shift input, then shift it.
2613 SDValue NewShAmt = TLO.DAG.getShiftAmountConstant(ShVal, VT, dl);
2614 SDValue NewTrunc =
2615 TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0));
2616 return TLO.CombineTo(
2617 Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, NewShAmt));
2618 }
2619 }
2620 break;
2621 }
2622
2623 break;
2624 }
2625 case ISD::AssertZext: {
2626 // AssertZext demands all of the high bits, plus any of the low bits
2627 // demanded by its users.
2628 EVT ZVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2629 APInt InMask = APInt::getLowBitsSet(BitWidth, ZVT.getSizeInBits());
2630 if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | DemandedBits, Known,
2631 TLO, Depth + 1))
2632 return true;
2633
2634 Known.Zero |= ~InMask;
2635 Known.One &= (~Known.Zero);
2636 break;
2637 }
2638 case ISD::EXTRACT_VECTOR_ELT: {
2639 SDValue Src = Op.getOperand(0);
2640 SDValue Idx = Op.getOperand(1);
2641 ElementCount SrcEltCnt = Src.getValueType().getVectorElementCount();
2642 unsigned EltBitWidth = Src.getScalarValueSizeInBits();
2643
2644 if (SrcEltCnt.isScalable())
2645 return false;
2646
2647 // Demand the bits from every vector element without a constant index.
2648 unsigned NumSrcElts = SrcEltCnt.getFixedValue();
2649 APInt DemandedSrcElts = APInt::getAllOnes(NumSrcElts);
2650 if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx))
2651 if (CIdx->getAPIntValue().ult(NumSrcElts))
2652 DemandedSrcElts = APInt::getOneBitSet(NumSrcElts, CIdx->getZExtValue());
2653
2654 // If BitWidth > EltBitWidth the value is any-extended, so we do not know
2655 // anything about the extended bits.
2656 APInt DemandedSrcBits = DemandedBits;
2657 if (BitWidth > EltBitWidth)
2658 DemandedSrcBits = DemandedSrcBits.trunc(EltBitWidth);
2659
2660 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, Known2, TLO,
2661 Depth + 1))
2662 return true;
2663
2664 // Attempt to avoid multi-use ops if we don't need anything from them.
2665 if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
2666 if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
2667 Src, DemandedSrcBits, DemandedSrcElts, TLO.DAG, Depth + 1)) {
2668 SDValue NewOp =
2669 TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc, Idx);
2670 return TLO.CombineTo(Op, NewOp);
2671 }
2672 }
2673
2674 Known = Known2;
2675 if (BitWidth > EltBitWidth)
2676 Known = Known.anyext(BitWidth);
2677 break;
2678 }
2679 case ISD::BITCAST: {
2680 if (VT.isScalableVector())
2681 return false;
2682 SDValue Src = Op.getOperand(0);
2683 EVT SrcVT = Src.getValueType();
2684 unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
2685
2686 // If this is an FP->Int bitcast and if the sign bit is the only
2687 // thing demanded, turn this into a FGETSIGN.
2688 if (!TLO.LegalOperations() && !VT.isVector() && !SrcVT.isVector() &&
2689 DemandedBits == APInt::getSignMask(Op.getValueSizeInBits()) &&
2690 SrcVT.isFloatingPoint()) {
2691 bool OpVTLegal = isOperationLegalOrCustom(ISD::FGETSIGN, VT);
2692 bool i32Legal = isOperationLegalOrCustom(ISD::FGETSIGN, MVT::i32);
2693 if ((OpVTLegal || i32Legal) && VT.isSimple() && SrcVT != MVT::f16 &&
2694 SrcVT != MVT::f128) {
2695 // Cannot eliminate/lower SHL for f128 yet.
2696 EVT Ty = OpVTLegal ? VT : MVT::i32;
2697 // Make a FGETSIGN + SHL to move the sign bit into the appropriate
2698 // place. We expect the SHL to be eliminated by other optimizations.
2699 SDValue Sign = TLO.DAG.getNode(ISD::FGETSIGN, dl, Ty, Src);
2700 unsigned OpVTSizeInBits = Op.getValueSizeInBits();
2701 if (!OpVTLegal && OpVTSizeInBits > 32)
2702 Sign = TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Sign);
2703 unsigned ShVal = Op.getValueSizeInBits() - 1;
2704 SDValue ShAmt = TLO.DAG.getConstant(ShVal, dl, VT);
2705 return TLO.CombineTo(Op,
2706 TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
2707 }
2708 }
2709
2710 // Bitcast from a vector using SimplifyDemanded Bits/VectorElts.
2711 // Demand the elt/bit if any of the original elts/bits are demanded.
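// e.g. bitcasting v2i32 to i64 gives Scale = 2; on a little-endian target,
// demanding bits [31:0] of the i64 only demands source element 0, while
// bits [63:32] only demand source element 1.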
2712 if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
2713 unsigned Scale = BitWidth / NumSrcEltBits;
2714 unsigned NumSrcElts = SrcVT.getVectorNumElements();
2715 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2716 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2717 for (unsigned i = 0; i != Scale; ++i) {
2718 unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
2719 unsigned BitOffset = EltOffset * NumSrcEltBits;
2720 APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
2721 if (!Sub.isZero()) {
2722 DemandedSrcBits |= Sub;
2723 for (unsigned j = 0; j != NumElts; ++j)
2724 if (DemandedElts[j])
2725 DemandedSrcElts.setBit((j * Scale) + i);
2726 }
2727 }
2728
2729 APInt KnownSrcUndef, KnownSrcZero;
2730 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2731 KnownSrcZero, TLO, Depth + 1))
2732 return true;
2733
2734 KnownBits KnownSrcBits;
2735 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2736 KnownSrcBits, TLO, Depth + 1))
2737 return true;
2738 } else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
2739 // TODO - bigendian once we have test coverage.
2740 unsigned Scale = NumSrcEltBits / BitWidth;
2741 unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
2742 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2743 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2744 for (unsigned i = 0; i != NumElts; ++i)
2745 if (DemandedElts[i]) {
2746 unsigned Offset = (i % Scale) * BitWidth;
2747 DemandedSrcBits.insertBits(DemandedBits, Offset);
2748 DemandedSrcElts.setBit(i / Scale);
2749 }
2750
2751 if (SrcVT.isVector()) {
2752 APInt KnownSrcUndef, KnownSrcZero;
2753 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2754 KnownSrcZero, TLO, Depth + 1))
2755 return true;
2756 }
2757
2758 KnownBits KnownSrcBits;
2759 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2760 KnownSrcBits, TLO, Depth + 1))
2761 return true;
2762
2763 // Attempt to avoid multi-use ops if we don't need anything from them.
2764 if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
2765 if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
2766 Src, DemandedSrcBits, DemandedSrcElts, TLO.DAG, Depth + 1)) {
2767 SDValue NewOp = TLO.DAG.getBitcast(VT, DemandedSrc);
2768 return TLO.CombineTo(Op, NewOp);
2769 }
2770 }
2771 }
2772
2773 // If this is a bitcast, let computeKnownBits handle it. Only do this on a
2774 // recursive call where Known may be useful to the caller.
2775 if (Depth > 0) {
2776 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2777 return false;
2778 }
2779 break;
2780 }
2781 case ISD::MUL:
2782 if (DemandedBits.isPowerOf2()) {
2783 // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
2784 // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
2785 // odd (has LSB set), then the left-shifted low bit of X is the answer.
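// e.g. if only bit 4 of (X * 48) is demanded, then 48 = 3 << 4 with 3 odd,
// so that bit equals bit 4 of (X << 4) and the multiply can be replaced by
// the shift.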
2786 unsigned CTZ = DemandedBits.countr_zero();
2787 ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
2788 if (C && C->getAPIntValue().countr_zero() == CTZ) {
2789 SDValue AmtC = TLO.DAG.getShiftAmountConstant(CTZ, VT, dl);
2790 SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, Op.getOperand(0), AmtC);
2791 return TLO.CombineTo(Op, Shl);
2792 }
2793 }
2794 // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
2795 // X * X is odd iff X is odd.
2796 // 'Quadratic Reciprocity': X * X -> 0 for bit[1]
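// (Writing X = 2A + B with B = X[0]: X * X = 4A^2 + 4AB + B^2 and B^2 == B,
// so X * X == X[0] (mod 4), i.e. bit 1 is always zero.)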
2797 if (Op.getOperand(0) == Op.getOperand(1) && DemandedBits.ult(4)) {
2798 SDValue One = TLO.DAG.getConstant(1, dl, VT);
2799 SDValue And1 = TLO.DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0), One);
2800 return TLO.CombineTo(Op, And1);
2801 }
2802 [[fallthrough]];
2803 case ISD::ADD:
2804 case ISD::SUB: {
2805 // Add, Sub, and Mul don't demand any bits in positions beyond that
2806 // of the highest bit demanded of them.
2807 SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1);
2808 SDNodeFlags Flags = Op.getNode()->getFlags();
2809 unsigned DemandedBitsLZ = DemandedBits.countl_zero();
2810 APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ);
2811 KnownBits KnownOp0, KnownOp1;
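// For a multiply, if the RHS has at least K known trailing zero bits, then
// LHS bit I only feeds result bits at position I + K or above, so the top K
// bits of the LHS are never demanded.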
2812 auto GetDemandedBitsLHSMask = [&](APInt Demanded,
2813 const KnownBits &KnownRHS) {
2814 if (Op.getOpcode() == ISD::MUL)
2815 Demanded.clearHighBits(KnownRHS.countMinTrailingZeros());
2816 return Demanded;
2817 };
2818 if (SimplifyDemandedBits(Op1, LoMask, DemandedElts, KnownOp1, TLO,
2819 Depth + 1) ||
2820 SimplifyDemandedBits(Op0, GetDemandedBitsLHSMask(LoMask, KnownOp1),
2821 DemandedElts, KnownOp0, TLO, Depth + 1) ||
2822 // See if the operation should be performed at a smaller bit width.
2823 ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) {
2824 if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) {
2825 // Disable the nsw and nuw flags. We can no longer guarantee that we
2826 // won't wrap after simplification.
2827 Flags.setNoSignedWrap(false);
2828 Flags.setNoUnsignedWrap(false);
2829 Op->setFlags(Flags);
2830 }
2831 return true;
2832 }
2833
2834 // neg x with only low bit demanded is simply x.
2835 if (Op.getOpcode() == ISD::SUB && DemandedBits.isOne() &&
2836 isNullConstant(Op0))
2837 return TLO.CombineTo(Op, Op1);
2838
2839 // Attempt to avoid multi-use ops if we don't need anything from them.
2840 if (!LoMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2841 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2842 Op0, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2843 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
2844 Op1, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2845 if (DemandedOp0 || DemandedOp1) {
2846 Flags.setNoSignedWrap(false);
2847 Flags.setNoUnsignedWrap(false);
2848 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
2849 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
2850 SDValue NewOp =
2851 TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1, Flags);
2852 return TLO.CombineTo(Op, NewOp);
2853 }
2854 }
2855
2856 // If we have a constant operand, we may be able to turn it into -1 if we
2857 // do not demand the high bits. This can make the constant smaller to
2858 // encode, allow more general folding, or match specialized instruction
2859 // patterns (eg, 'blsr' on x86). Don't bother changing 1 to -1 because that
2860 // is probably not useful (and could be detrimental).
2861 ConstantSDNode *C = isConstOrConstSplat(Op1, DemandedElts);
2862 APInt HighMask = APInt::getHighBitsSet(BitWidth, DemandedBitsLZ);
2863 if (C && !C->isAllOnes() && !C->isOne() &&
2864 (C->getAPIntValue() | HighMask).isAllOnes()) {
2865 SDValue Neg1 = TLO.DAG.getAllOnesConstant(dl, VT);
2866 // Disable the nsw and nuw flags. We can no longer guarantee that we
2867 // won't wrap after simplification.
2868 Flags.setNoSignedWrap(false);
2869 Flags.setNoUnsignedWrap(false);
2870 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Neg1, Flags);
2871 return TLO.CombineTo(Op, NewOp);
2872 }
2873
2874 // Match a multiply with a disguised negated-power-of-2 and convert to
2875 // an equivalent shift-left amount.
2876 // Example: (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
2877 auto getShiftLeftAmt = [&HighMask](SDValue Mul) -> unsigned {
2878 if (Mul.getOpcode() != ISD::MUL || !Mul.hasOneUse())
2879 return 0;
2880
2881 // Don't touch opaque constants. Also, ignore zero and power-of-2
2882 // multiplies. Those will get folded later.
2883 ConstantSDNode *MulC = isConstOrConstSplat(Mul.getOperand(1));
2884 if (MulC && !MulC->isOpaque() && !MulC->isZero() &&
2885 !MulC->getAPIntValue().isPowerOf2()) {
2886 APInt UnmaskedC = MulC->getAPIntValue() | HighMask;
2887 if (UnmaskedC.isNegatedPowerOf2())
2888 return (-UnmaskedC).logBase2();
2889 }
2890 return 0;
2891 };
2892
2893 auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y,
2894 unsigned ShlAmt) {
2895 SDValue ShlAmtC = TLO.DAG.getShiftAmountConstant(ShlAmt, VT, dl);
2896 SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
2897 SDValue Res = TLO.DAG.getNode(NT, dl, VT, Y, Shl);
2898 return TLO.CombineTo(Op, Res);
2899 };
2900
2901 if (isOperationLegalOrCustom(ISD::SHL, VT)) {
2902 if (Op.getOpcode() == ISD::ADD) {
2903 // (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
2904 if (unsigned ShAmt = getShiftLeftAmt(Op0))
2905 return foldMul(ISD::SUB, Op0.getOperand(0), Op1, ShAmt);
2906 // Op0 + (X * MulC) --> Op0 - (X << log2(-MulC))
2907 if (unsigned ShAmt = getShiftLeftAmt(Op1))
2908 return foldMul(ISD::SUB, Op1.getOperand(0), Op0, ShAmt);
2909 }
2910 if (Op.getOpcode() == ISD::SUB) {
2911 // Op0 - (X * MulC) --> Op0 + (X << log2(-MulC))
2912 if (unsigned ShAmt = getShiftLeftAmt(Op1))
2913 return foldMul(ISD::ADD, Op1.getOperand(0), Op0, ShAmt);
2914 }
2915 }
2916
2917 if (Op.getOpcode() == ISD::MUL) {
2918 Known = KnownBits::mul(KnownOp0, KnownOp1);
2919 } else { // Op.getOpcode() is either ISD::ADD or ISD::SUB.
2920 Known = KnownBits::computeForAddSub(
2921 Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
2922 Flags.hasNoUnsignedWrap(), KnownOp0, KnownOp1);
2923 }
2924 break;
2925 }
2926 default:
2927 // We also ask the target about intrinsics (which could be specific to it).
2928 if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
2929 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
2930 // TODO: Probably okay to remove after audit; here to reduce change size
2931 // in initial enablement patch for scalable vectors
2932 if (Op.getValueType().isScalableVector())
2933 break;
2934 if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
2935 Known, TLO, Depth))
2936 return true;
2937 break;
2938 }
2939
2940 // Just use computeKnownBits to compute output bits.
2941 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2942 break;
2943 }
2944
2945 // If we know the value of all of the demanded bits, return this as a
2946 // constant.
2947 if (!isTargetCanonicalConstantNode(Op) &&
2948 DemandedBits.isSubsetOf(Known.Zero | Known.One)) {
2949 // Avoid folding to a constant if any OpaqueConstant is involved.
2950 const SDNode *N = Op.getNode();
2951 for (SDNode *Op :
2952 llvm::make_range(SDNodeIterator::begin(N), SDNodeIterator::end(N))) {
2953 if (auto *C = dyn_cast<ConstantSDNode>(Op))
2954 if (C->isOpaque())
2955 return false;
2956 }
2957 if (VT.isInteger())
2958 return TLO.CombineTo(Op, TLO.DAG.getConstant(Known.One, dl, VT));
2959 if (VT.isFloatingPoint())
2960 return TLO.CombineTo(
2961 Op,
2962 TLO.DAG.getConstantFP(
2963 APFloat(TLO.DAG.EVTToAPFloatSemantics(VT), Known.One), dl, VT));
2964 }
2965
2966 // A multi-use 'all demanded elts' simplification failed to find any known bits.
2967 // Try again just for the original demanded elts.
2968 // Ensure we do this AFTER the constant folding above.
2969 if (HasMultiUse && Known.isUnknown() && !OriginalDemandedElts.isAllOnes())
2970 Known = TLO.DAG.computeKnownBits(Op, OriginalDemandedElts, Depth);
2971
2972 return false;
2973}
2974
2975 bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
2976 const APInt &DemandedElts,
2977 DAGCombinerInfo &DCI) const {
2978 SelectionDAG &DAG = DCI.DAG;
2979 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
2980 !DCI.isBeforeLegalizeOps());
2981
2982 APInt KnownUndef, KnownZero;
2983 bool Simplified =
2984 SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
2985 if (Simplified) {
2986 DCI.AddToWorklist(Op.getNode());
2987 DCI.CommitTargetLoweringOpt(TLO);
2988 }
2989
2990 return Simplified;
2991}
2992
2993/// Given a vector binary operation and known undefined elements for each input
2994/// operand, compute whether each element of the output is undefined.
2996 const APInt &UndefOp0,
2997 const APInt &UndefOp1) {
2998 EVT VT = BO.getValueType();
2999 assert(DAG.getTargetLoweringInfo().isBinOp(BO.getOpcode()) && VT.isVector() &&
3000 "Vector binop only");
3001
3002 EVT EltVT = VT.getVectorElementType();
3003 unsigned NumElts = VT.isFixedLengthVector() ? VT.getVectorNumElements() : 1;
3004 assert(UndefOp0.getBitWidth() == NumElts &&
3005 UndefOp1.getBitWidth() == NumElts && "Bad type for undef analysis");
3006
3007 auto getUndefOrConstantElt = [&](SDValue V, unsigned Index,
3008 const APInt &UndefVals) {
3009 if (UndefVals[Index])
3010 return DAG.getUNDEF(EltVT);
3011
3012 if (auto *BV = dyn_cast<BuildVectorSDNode>(V)) {
3013 // Try hard to make sure that the getNode() call is not creating temporary
3014 // nodes. Ignore opaque integers because they do not constant fold.
3015 SDValue Elt = BV->getOperand(Index);
3016 auto *C = dyn_cast<ConstantSDNode>(Elt);
3017 if (isa<ConstantFPSDNode>(Elt) || Elt.isUndef() || (C && !C->isOpaque()))
3018 return Elt;
3019 }
3020
3021 return SDValue();
3022 };
3023
3024 APInt KnownUndef = APInt::getZero(NumElts);
3025 for (unsigned i = 0; i != NumElts; ++i) {
3026 // If both inputs for this element are either constant or undef and match
3027 // the element type, compute the constant/undef result for this element of
3028 // the vector.
3029 // TODO: Ideally we would use FoldConstantArithmetic() here, but that does
3030 // not handle FP constants. The code within getNode() should be refactored
3031 // to avoid the danger of creating a bogus temporary node here.
3032 SDValue C0 = getUndefOrConstantElt(BO.getOperand(0), i, UndefOp0);
3033 SDValue C1 = getUndefOrConstantElt(BO.getOperand(1), i, UndefOp1);
3034 if (C0 && C1 && C0.getValueType() == EltVT && C1.getValueType() == EltVT)
3035 if (DAG.getNode(BO.getOpcode(), SDLoc(BO), EltVT, C0, C1).isUndef())
3036 KnownUndef.setBit(i);
3037 }
3038 return KnownUndef;
3039}
3040
3041 bool TargetLowering::SimplifyDemandedVectorElts(
3042 SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
3043 APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
3044 bool AssumeSingleUse) const {
3045 EVT VT = Op.getValueType();
3046 unsigned Opcode = Op.getOpcode();
3047 APInt DemandedElts = OriginalDemandedElts;
3048 unsigned NumElts = DemandedElts.getBitWidth();
3049 assert(VT.isVector() && "Expected vector op");
3050
3051 KnownUndef = KnownZero = APInt::getZero(NumElts);
3052
3053 const TargetLowering &TLI = TLO.DAG.getTargetLoweringInfo();
3054 if (!TLI.shouldSimplifyDemandedVectorElts(Op, TLO))
3055 return false;
3056
3057 // TODO: For now we assume we know nothing about scalable vectors.
3058 if (VT.isScalableVector())
3059 return false;
3060
3061 assert(VT.getVectorNumElements() == NumElts &&
3062 "Mask size mismatches value type element count!");
3063
3064 // Undef operand.
3065 if (Op.isUndef()) {
3066 KnownUndef.setAllBits();
3067 return false;
3068 }
3069
3070 // If Op has other users, assume that all elements are needed.
3071 if (!AssumeSingleUse && !Op.getNode()->hasOneUse())
3072 DemandedElts.setAllBits();
3073
3074 // Not demanding any elements from Op.
3075 if (DemandedElts == 0) {
3076 KnownUndef.setAllBits();
3077 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3078 }
3079
3080 // Limit search depth.
3081 if (Depth >= SelectionDAG::MaxRecursionDepth)
3082 return false;
3083
3084 SDLoc DL(Op);
3085 unsigned EltSizeInBits = VT.getScalarSizeInBits();
3086 bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
3087
3088 // Helper for demanding the specified elements and all the bits of both binary
3089 // operands.
3090 auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
3091 SDValue NewOp0 = SimplifyMultipleUseDemandedVectorElts(Op0, DemandedElts,
3092 TLO.DAG, Depth + 1);
3093 SDValue NewOp1 = SimplifyMultipleUseDemandedVectorElts(Op1, DemandedElts,
3094 TLO.DAG, Depth + 1);
3095 if (NewOp0 || NewOp1) {
3096 SDValue NewOp =
3097 TLO.DAG.getNode(Opcode, SDLoc(Op), VT, NewOp0 ? NewOp0 : Op0,
3098 NewOp1 ? NewOp1 : Op1, Op->getFlags());
3099 return TLO.CombineTo(Op, NewOp);
3100 }
3101 return false;
3102 };
3103
3104 switch (Opcode) {
3105 case ISD::SCALAR_TO_VECTOR: {
3106 if (!DemandedElts[0]) {
3107 KnownUndef.setAllBits();
3108 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3109 }
3110 SDValue ScalarSrc = Op.getOperand(0);
3111 if (ScalarSrc.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
3112 SDValue Src = ScalarSrc.getOperand(0);
3113 SDValue Idx = ScalarSrc.getOperand(1);
3114 EVT SrcVT = Src.getValueType();
3115
3116 ElementCount SrcEltCnt = SrcVT.getVectorElementCount();
3117
3118 if (SrcEltCnt.isScalable())
3119 return false;
3120
3121 unsigned NumSrcElts = SrcEltCnt.getFixedValue();
3122 if (isNullConstant(Idx)) {
3123 APInt SrcDemandedElts = APInt::getOneBitSet(NumSrcElts, 0);
3124 APInt SrcUndef = KnownUndef.zextOrTrunc(NumSrcElts);
3125 APInt SrcZero = KnownZero.zextOrTrunc(NumSrcElts);
3126 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3127 TLO, Depth + 1))
3128 return true;
3129 }
3130 }
3131 KnownUndef.setHighBits(NumElts - 1);
3132 break;
3133 }
3134 case ISD::BITCAST: {
3135 SDValue Src = Op.getOperand(0);
3136 EVT SrcVT = Src.getValueType();
3137
3138 // We only handle vectors here.
3139 // TODO - investigate calling SimplifyDemandedBits/ComputeKnownBits?
3140 if (!SrcVT.isVector())
3141 break;
3142
3143 // Fast handling of 'identity' bitcasts.
3144 unsigned NumSrcElts = SrcVT.getVectorNumElements();
3145 if (NumSrcElts == NumElts)
3146 return SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef,
3147 KnownZero, TLO, Depth + 1);
3148
3149 APInt SrcDemandedElts, SrcZero, SrcUndef;
3150
3151 // Bitcast from 'large element' src vector to 'small element' vector, we
3152 // must demand a source element if any DemandedElt maps to it.
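 // e.g. for a v2i64 -> v4i32 bitcast, demanding either i32 lane 2 or lane 3
 // demands i64 source lane 1.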
3153 if ((NumElts % NumSrcElts) == 0) {
3154 unsigned Scale = NumElts / NumSrcElts;
3155 SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
3156 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3157 TLO, Depth + 1))
3158 return true;
3159
3160 // Try calling SimplifyDemandedBits, converting demanded elts to the bits
3161 // of the large element.
3162 // TODO - bigendian once we have test coverage.
3163 if (IsLE) {
3164 unsigned SrcEltSizeInBits = SrcVT.getScalarSizeInBits();
3165 APInt SrcDemandedBits = APInt::getZero(SrcEltSizeInBits);
3166 for (unsigned i = 0; i != NumElts; ++i)
3167 if (DemandedElts[i]) {
3168 unsigned Ofs = (i % Scale) * EltSizeInBits;
3169 SrcDemandedBits.setBits(Ofs, Ofs + EltSizeInBits);
3170 }
3171
3172 KnownBits Known;
3173 if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcDemandedElts, Known,
3174 TLO, Depth + 1))
3175 return true;
3176
3177 // The bitcast has split each wide element into a number of
3178 // narrow subelements. We have just computed the Known bits
3179 // for wide elements. See if element splitting results in
3180 // some subelements being zero. Only for demanded elements!
3181 for (unsigned SubElt = 0; SubElt != Scale; ++SubElt) {
3182 if (!Known.Zero.extractBits(EltSizeInBits, SubElt * EltSizeInBits)
3183 .isAllOnes())
3184 continue;
3185 for (unsigned SrcElt = 0; SrcElt != NumSrcElts; ++SrcElt) {
3186 unsigned Elt = Scale * SrcElt + SubElt;
3187 if (DemandedElts[Elt])
3188 KnownZero.setBit(Elt);
3189 }
3190 }
3191 }
3192
3193 // If the src element is zero/undef then all the output elements that map to
3194 // it will be as well - only demanded elements are guaranteed to be correct.
3195 for (unsigned i = 0; i != NumSrcElts; ++i) {
3196 if (SrcDemandedElts[i]) {
3197 if (SrcZero[i])
3198 KnownZero.setBits(i * Scale, (i + 1) * Scale);
3199 if (SrcUndef[i])
3200 KnownUndef.setBits(i * Scale, (i + 1) * Scale);
3201 }
3202 }
3203 }
3204
3205 // Bitcast from 'small element' src vector to 'large element' vector, we
3206 // demand all smaller source elements covered by the larger demanded element
3207 // of this vector.
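 // e.g. for a v4i32 -> v2i64 bitcast, demanding i64 lane 1 demands i32 source
 // lanes 2 and 3.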
3208 if ((NumSrcElts % NumElts) == 0) {
3209 unsigned Scale = NumSrcElts / NumElts;
3210 SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
3211 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3212 TLO, Depth + 1))
3213 return true;
3214
3215 // If all the src elements covering an output element are zero/undef, then
3216 // the output element will be as well, assuming it was demanded.
3217 for (unsigned i = 0; i != NumElts; ++i) {
3218 if (DemandedElts[i]) {
3219 if (SrcZero.extractBits(Scale, i * Scale).isAllOnes())
3220 KnownZero.setBit(i);
3221 if (SrcUndef.extractBits(Scale, i * Scale).isAllOnes())
3222 KnownUndef.setBit(i);
3223 }
3224 }
3225 }
3226 break;
3227 }
3228 case ISD::FREEZE: {
3229 SDValue N0 = Op.getOperand(0);
3230 if (TLO.DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
3231 /*PoisonOnly=*/false))
3232 return TLO.CombineTo(Op, N0);
3233
3234 // TODO: Replace this with the general fold from DAGCombiner::visitFREEZE
3235 // freeze(op(x, ...)) -> op(freeze(x), ...).
3236 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && DemandedElts == 1)
3237 return TLO.CombineTo(
3238 Op, TLO.DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT,
3239 TLO.DAG.getFreeze(N0.getOperand(0))));
3240 break;
3241 }
3242 case ISD::BUILD_VECTOR: {
3243 // Check all elements and simplify any unused elements with UNDEF.
3244 if (!DemandedElts.isAllOnes()) {
3245 // Don't simplify BROADCASTS.
3246 if (llvm::any_of(Op->op_values(),
3247 [&](SDValue Elt) { return Op.getOperand(0) != Elt; })) {
3248 SmallVector<SDValue, 32> Ops(Op->ops());
3249 bool Updated = false;
3250 for (unsigned i = 0; i != NumElts; ++i) {
3251 if (!DemandedElts[i] && !Ops[i].isUndef()) {
3252 Ops[i] = TLO.DAG.getUNDEF(Ops[0].getValueType());
3253 KnownUndef.setBit(i);
3254 Updated = true;
3255 }
3256 }
3257 if (Updated)
3258 return TLO.CombineTo(Op, TLO.DAG.getBuildVector(VT, DL, Ops));
3259 }
3260 }
3261 for (unsigned i = 0; i != NumElts; ++i) {
3262 SDValue SrcOp = Op.getOperand(i);
3263 if (SrcOp.isUndef()) {
3264 KnownUndef.setBit(i);
3265 } else if (EltSizeInBits == SrcOp.getScalarValueSizeInBits() &&
3266 (isNullConstant(SrcOp) || isNullFPConstant(SrcOp))) {
3267 KnownZero.setBit(i);
3268 }
3269 }
3270 break;
3271 }
3272 case ISD::CONCAT_VECTORS: {
3273 EVT SubVT = Op.getOperand(0).getValueType();
3274 unsigned NumSubVecs = Op.getNumOperands();
3275 unsigned NumSubElts = SubVT.getVectorNumElements();
3276 for (unsigned i = 0; i != NumSubVecs; ++i) {
3277 SDValue SubOp = Op.getOperand(i);
3278 APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
3279 APInt SubUndef, SubZero;
3280 if (SimplifyDemandedVectorElts(SubOp, SubElts, SubUndef, SubZero, TLO,
3281 Depth + 1))
3282 return true;
3283 KnownUndef.insertBits(SubUndef, i * NumSubElts);
3284 KnownZero.insertBits(SubZero, i * NumSubElts);
3285 }
3286
3287 // Attempt to avoid multi-use ops if we don't need anything from them.
3288 if (!DemandedElts.isAllOnes()) {
3289 bool FoundNewSub = false;
3290 SmallVector<SDValue, 2> DemandedSubOps;
3291 for (unsigned i = 0; i != NumSubVecs; ++i) {
3292 SDValue SubOp = Op.getOperand(i);
3293 APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
3294 SDValue NewSubOp = SimplifyMultipleUseDemandedVectorElts(
3295 SubOp, SubElts, TLO.DAG, Depth + 1);
3296 DemandedSubOps.push_back(NewSubOp ? NewSubOp : SubOp);
3297 FoundNewSub = NewSubOp ? true : FoundNewSub;
3298 }
3299 if (FoundNewSub) {
3300 SDValue NewOp =
3301 TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, DemandedSubOps);
3302 return TLO.CombineTo(Op, NewOp);
3303 }
3304 }
3305 break;
3306 }
3307 case ISD::INSERT_SUBVECTOR: {
3308 // Demand any elements from the subvector and the remainder from the src it's
3309 // inserted into.
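 // e.g. inserting a v2i32 subvector at index 4 of a v8i32: result lanes 4-5
 // come from the subvector, so they are cleared from the demanded src lanes.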
3310 SDValue Src = Op.getOperand(0);
3311 SDValue Sub = Op.getOperand(1);
3312 uint64_t Idx = Op.getConstantOperandVal(2);
3313 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
3314 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
3315 APInt DemandedSrcElts = DemandedElts;
3316 DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
3317
3318 APInt SubUndef, SubZero;
3319 if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,
3320 Depth + 1))
3321 return true;
3322
3323 // If none of the src operand elements are demanded, replace it with undef.
3324 if (!DemandedSrcElts && !Src.isUndef())
3325 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
3326 TLO.DAG.getUNDEF(VT), Sub,
3327 Op.getOperand(2)));
3328
3329 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownUndef, KnownZero,
3330 TLO, Depth + 1))
3331 return true;
3332 KnownUndef.insertBits(SubUndef, Idx);
3333 KnownZero.insertBits(SubZero, Idx);
3334
3335 // Attempt to avoid multi-use ops if we don't need anything from them.
3336 if (!DemandedSrcElts.isAllOnes() || !DemandedSubElts.isAllOnes()) {
3337 SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
3338 Src, DemandedSrcElts, TLO.DAG, Depth + 1);
3339 SDValue NewSub = SimplifyMultipleUseDemandedVectorElts(
3340 Sub, DemandedSubElts, TLO.DAG, Depth + 1);
3341 if (NewSrc || NewSub) {
3342 NewSrc = NewSrc ? NewSrc : Src;
3343 NewSub = NewSub ? NewSub : Sub;
3344 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
3345 NewSub, Op.getOperand(2));
3346 return TLO.CombineTo(Op, NewOp);
3347 }
3348 }
3349 break;
3350 }
3351 case ISD::EXTRACT_SUBVECTOR: {
3352 // Offset the demanded elts by the subvector index.
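 // e.g. extracting a v2i32 at index 4 from a v8i32: demanding result lane 1
 // demands source lane 5.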
3353 SDValue Src = Op.getOperand(0);
3354 if (Src.getValueType().isScalableVector())
3355 break;
3356 uint64_t Idx = Op.getConstantOperandVal(1);
3357 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
3358 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
3359
3360 APInt SrcUndef, SrcZero;
3361 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
3362 Depth + 1))
3363 return true;
3364 KnownUndef = SrcUndef.extractBits(NumElts, Idx);
3365 KnownZero = SrcZero.extractBits(NumElts, Idx);
3366
3367 // Attempt to avoid multi-use ops if we don't need anything from them.
3368 if (!DemandedElts.isAllOnes()) {
3369 SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
3370 Src, DemandedSrcElts, TLO.DAG, Depth + 1);
3371 if (NewSrc) {
3372 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
3373 Op.getOperand(1));
3374 return TLO.CombineTo(Op, NewOp);
3375 }
3376 }
3377 break;
3378 }
3379 case ISD::INSERT_VECTOR_ELT: {
3380 SDValue Vec = Op.getOperand(0);
3381 SDValue Scl = Op.getOperand(1);
3382 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
3383
3384 // For a legal, constant insertion index, if we don't need this insertion
3385 // then strip it, else remove it from the demanded elts.
3386 if (CIdx && CIdx->getAPIntValue().ult(NumElts)) {
3387 unsigned Idx = CIdx->getZExtValue();
3388 if (!DemandedElts[Idx])
3389 return TLO.CombineTo(Op, Vec);
3390
3391 APInt DemandedVecElts(DemandedElts);
3392 DemandedVecElts.clearBit(Idx);
3393 if (SimplifyDemandedVectorElts(Vec, DemandedVecElts, KnownUndef,
3394 KnownZero, TLO, Depth + 1))
3395 return true;
3396
3397 KnownUndef.setBitVal(Idx, Scl.isUndef());
3398
3399 KnownZero.setBitVal(Idx, isNullConstant(Scl) || isNullFPConstant(Scl));
3400 break;
3401 }
3402
3403 APInt VecUndef, VecZero;
3404 if (SimplifyDemandedVectorElts(Vec, DemandedElts, VecUndef, VecZero, TLO,
3405 Depth + 1))
3406 return true;
3407 // Without knowing the insertion index we can't set KnownUndef/KnownZero.
3408 break;
3409 }
3410 case ISD::VSELECT: {
3411 SDValue Sel = Op.getOperand(0);
3412 SDValue LHS = Op.getOperand(1);
3413 SDValue RHS = Op.getOperand(2);
3414
3415 // Try to transform the select condition based on the current demanded
3416 // elements.
3417 APInt UndefSel, ZeroSel;
3418 if (SimplifyDemandedVectorElts(Sel, DemandedElts, UndefSel, ZeroSel, TLO,
3419 Depth + 1))
3420 return true;
3421
3422 // See if we can simplify either vselect operand.
3423 APInt DemandedLHS(DemandedElts);
3424 APInt DemandedRHS(DemandedElts);
3425 APInt UndefLHS, ZeroLHS;
3426 APInt UndefRHS, ZeroRHS;
3427 if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
3428 Depth + 1))
3429 return true;
3430 if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
3431 Depth + 1))
3432 return true;
3433
3434 KnownUndef = UndefLHS & UndefRHS;
3435 KnownZero = ZeroLHS & ZeroRHS;
3436
3437 // If we know that the selected element is always zero, we don't need the
3438 // select value element.
3439 APInt DemandedSel = DemandedElts & ~KnownZero;
3440 if (DemandedSel != DemandedElts)
3441 if (SimplifyDemandedVectorElts(Sel, DemandedSel, UndefSel, ZeroSel, TLO,
3442 Depth + 1))
3443 return true;
3444
3445 break;
3446 }
3447 case ISD::VECTOR_SHUFFLE: {
3448 SDValue LHS = Op.getOperand(0);
3449 SDValue RHS = Op.getOperand(1);
3450 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
3451
3452 // Collect demanded elements from shuffle operands.
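 // e.g. for a v4i32 shuffle with mask <0, 5, 2, 7>, demanding result lanes 0
 // and 1 demands LHS lane 0 and RHS lane 1.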
3453 APInt DemandedLHS(NumElts, 0);
3454 APInt DemandedRHS(NumElts, 0);
3455 for (unsigned i = 0; i != NumElts; ++i) {
3456 int M = ShuffleMask[i];
3457 if (M < 0 || !DemandedElts[i])
3458 continue;
3459 assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range");
3460 if (M < (int)NumElts)
3461 DemandedLHS.setBit(M);
3462 else
3463 DemandedRHS.setBit(M - NumElts);
3464 }
3465
3466 // See if we can simplify either shuffle operand.
3467 APInt UndefLHS, ZeroLHS;
3468 APInt UndefRHS, ZeroRHS;
3469 if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
3470 Depth + 1))
3471 return true;
3472 if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
3473 Depth + 1))
3474 return true;
3475
3476 // Simplify mask using undef elements from LHS/RHS.
3477 bool Updated = false;
3478 bool IdentityLHS = true, IdentityRHS = true;
3479 SmallVector<int, 32> NewMask(ShuffleMask);
3480 for (unsigned i = 0; i != NumElts; ++i) {
3481 int &M = NewMask[i];
3482 if (M < 0)
3483 continue;
3484 if (!DemandedElts[i] || (M < (int)NumElts && UndefLHS[M]) ||
3485 (M >= (int)NumElts && UndefRHS[M - NumElts])) {
3486 Updated = true;
3487 M = -1;
3488 }
3489 IdentityLHS &= (M < 0) || (M == (int)i);
3490 IdentityRHS &= (M < 0) || ((M - NumElts) == i);
3491 }
3492
3493 // Update legal shuffle masks based on demanded elements if it won't reduce
3494 // to Identity which can cause premature removal of the shuffle mask.
3495 if (Updated && !IdentityLHS && !IdentityRHS && !TLO.LegalOps) {
3496 SDValue LegalShuffle =
3497 buildLegalVectorShuffle(VT, DL, LHS, RHS, NewMask, TLO.DAG);
3498 if (LegalShuffle)
3499 return TLO.CombineTo(Op, LegalShuffle);
3500 }
3501
3502 // Propagate undef/zero elements from LHS/RHS.
3503 for (unsigned i = 0; i != NumElts; ++i) {
3504 int M = ShuffleMask[i];
3505 if (M < 0) {
3506 KnownUndef.setBit(i);
3507 } else if (M < (int)NumElts) {
3508 if (UndefLHS[M])
3509 KnownUndef.setBit(i);
3510 if (ZeroLHS[M])
3511 KnownZero.setBit(i);
3512 } else {
3513 if (UndefRHS[M - NumElts])
3514 KnownUndef.setBit(i);
3515 if (ZeroRHS[M - NumElts])
3516 KnownZero.setBit(i);
3517 }
3518 }
3519 break;
3520 }
3521 case ISD::ANY_EXTEND_VECTOR_INREG:
3522 case ISD::SIGN_EXTEND_VECTOR_INREG:
3523 case ISD::ZERO_EXTEND_VECTOR_INREG: {
3524 APInt SrcUndef, SrcZero;
3525 SDValue Src = Op.getOperand(0);
3526 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
3527 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts);
3528 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
3529 Depth + 1))
3530 return true;
3531 KnownZero = SrcZero.zextOrTrunc(NumElts);
3532 KnownUndef = SrcUndef.zextOrTrunc(NumElts);
3533
3534 if (IsLE && Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG &&
3535 Op.getValueSizeInBits() == Src.getValueSizeInBits() &&
3536 DemandedSrcElts == 1) {
3537 // aext - if we just need the bottom element then we can bitcast.
3538 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
3539 }
3540
3541 if (Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
3542 // zext(undef) upper bits are guaranteed to be zero.
3543 if (DemandedElts.isSubsetOf(KnownUndef))
3544 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3545 KnownUndef.clearAllBits();
3546
3547 // zext - if we just need the bottom element then we can mask:
3548 // zext(and(x,c)) -> and(x,c') iff the zext is the only user of the and.
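 // e.g. (zext_vector_inreg (and X, C)) with only lane 0 demanded becomes
 // (bitcast (and X, C')), where C' keeps lane 0 of C and zeroes the rest.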
3549 if (IsLE && DemandedSrcElts == 1 && Src.getOpcode() == ISD::AND &&
3550 Op->isOnlyUserOf(Src.getNode()) &&
3551 Op.getValueSizeInBits() == Src.getValueSizeInBits()) {
3552 SDLoc DL(Op);
3553 EVT SrcVT = Src.getValueType();
3554 EVT SrcSVT = SrcVT.getScalarType();
3555 SmallVector<SDValue> MaskElts;
3556 MaskElts.push_back(TLO.DAG.getAllOnesConstant(DL, SrcSVT));
3557 MaskElts.append(NumSrcElts - 1, TLO.DAG.getConstant(0, DL, SrcSVT));
3558 SDValue Mask = TLO.DAG.getBuildVector(SrcVT, DL, MaskElts);
3559 if (SDValue Fold = TLO.DAG.FoldConstantArithmetic(
3560 ISD::AND, DL, SrcVT, {Src.getOperand(1), Mask})) {
3561 Fold = TLO.DAG.getNode(ISD::AND, DL, SrcVT, Src.getOperand(0), Fold);
3562 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Fold));
3563 }
3564 }
3565 }
3566 break;
3567 }
3568
3569 // TODO: There are more binop opcodes that could be handled here - MIN,
3570 // MAX, saturated math, etc.
3571 case ISD::ADD: {
3572 SDValue Op0 = Op.getOperand(0);
3573 SDValue Op1 = Op.getOperand(1);
3574 if (Op0 == Op1 && Op->isOnlyUserOf(Op0.getNode())) {
3575 APInt UndefLHS, ZeroLHS;
3576 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3577 Depth + 1, /*AssumeSingleUse*/ true))
3578 return true;
3579 }
3580 [[fallthrough]];
3581 }
3582 case ISD::AVGCEILS:
3583 case ISD::AVGCEILU:
3584 case ISD::AVGFLOORS:
3585 case ISD::AVGFLOORU:
3586 case ISD::OR:
3587 case ISD::XOR:
3588 case ISD::SUB:
3589 case ISD::FADD:
3590 case ISD::FSUB:
3591 case ISD::FMUL:
3592 case ISD::FDIV:
3593 case ISD::FREM: {
3594 SDValue Op0 = Op.getOperand(0);
3595 SDValue Op1 = Op.getOperand(1);
3596
3597 APInt UndefRHS, ZeroRHS;
3598 if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
3599 Depth + 1))
3600 return true;
3601 APInt UndefLHS, ZeroLHS;
3602 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3603 Depth + 1))
3604 return true;
3605
3606 KnownZero = ZeroLHS & ZeroRHS;
3607 KnownUndef = getKnownUndefForVectorBinop(Op, TLO.DAG, UndefLHS, UndefRHS);
3608
3609 // Attempt to avoid multi-use ops if we don't need anything from them.
3610 // TODO - use KnownUndef to relax the demandedelts?
3611 if (!DemandedElts.isAllOnes())
3612 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3613 return true;
3614 break;
3615 }
3616 case ISD::SHL:
3617 case ISD::SRL:
3618 case ISD::SRA:
3619 case ISD::ROTL:
3620 case ISD::ROTR: {
3621 SDValue Op0 = Op.getOperand(0);
3622 SDValue Op1 = Op.getOperand(1);
3623
3624 APInt UndefRHS, ZeroRHS;
3625 if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
3626 Depth + 1))
3627 return true;
3628 APInt UndefLHS, ZeroLHS;
3629 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3630 Depth + 1))
3631 return true;
3632
3633 KnownZero = ZeroLHS;
3634 KnownUndef = UndefLHS & UndefRHS; // TODO: use getKnownUndefForVectorBinop?
3635
3636 // Attempt to avoid multi-use ops if we don't need anything from them.
3637 // TODO - use KnownUndef to relax the demandedelts?
3638 if (!DemandedElts.isAllOnes())
3639 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3640 return true;
3641 break;
3642 }
3643 case ISD::MUL:
3644 case ISD::MULHU:
3645 case ISD::MULHS:
3646 case ISD::AND: {
3647 SDValue Op0 = Op.getOperand(0);
3648 SDValue Op1 = Op.getOperand(1);
3649
3650 APInt SrcUndef, SrcZero;
3651 if (SimplifyDemandedVectorElts(Op1, DemandedElts, SrcUndef, SrcZero, TLO,
3652 Depth + 1))
3653 return true;
3654 // If we know that a demanded element was zero in Op1 we don't need to
3655 // demand it in Op0 - it's guaranteed to be zero.
3656 APInt DemandedElts0 = DemandedElts & ~SrcZero;
3657 if (SimplifyDemandedVectorElts(Op0, DemandedElts0, KnownUndef, KnownZero,
3658 TLO, Depth + 1))
3659 return true;
3660
3661 KnownUndef &= DemandedElts0;
3662 KnownZero &= DemandedElts0;
3663
3664 // If every element pair has a zero/undef then just fold to zero.
3665 // fold (and x, undef) -> 0 / (and x, 0) -> 0
3666 // fold (mul x, undef) -> 0 / (mul x, 0) -> 0
3667 if (DemandedElts.isSubsetOf(SrcZero | KnownZero | SrcUndef | KnownUndef))
3668 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3669
3670 // If either side has a zero element, then the result element is zero, even
3671 // if the other is an UNDEF.
3672 // TODO: Extend getKnownUndefForVectorBinop to also deal with known zeros
3673 // and then handle 'and' nodes with the rest of the binop opcodes.
3674 KnownZero |= SrcZero;
3675 KnownUndef &= SrcUndef;
3676 KnownUndef &= ~KnownZero;
3677
3678 // Attempt to avoid multi-use ops if we don't need anything from them.
3679 if (!DemandedElts.isAllOnes())
3680 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3681 return true;
3682 break;
3683 }
3684 case ISD::TRUNCATE:
3685 case ISD::SIGN_EXTEND:
3686 case ISD::ZERO_EXTEND:
3687 if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef,
3688 KnownZero, TLO, Depth + 1))
3689 return true;
3690
3691 if (Op.getOpcode() == ISD::ZERO_EXTEND) {
3692 // zext(undef) upper bits are guaranteed to be zero.
3693 if (DemandedElts.isSubsetOf(KnownUndef))
3694 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3695 KnownUndef.clearAllBits();
3696 }
3697 break;
3698 default: {
3699 if (Op.getOpcode() >= ISD::BUILTIN_OP_END) {
3700 if (SimplifyDemandedVectorEltsForTargetNode(Op, DemandedElts, KnownUndef,
3701 KnownZero, TLO, Depth))
3702 return true;
3703 } else {
3704 KnownBits Known;
3705 APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
3706 if (SimplifyDemandedBits(Op, DemandedBits, OriginalDemandedElts, Known,
3707 TLO, Depth, AssumeSingleUse))
3708 return true;
3709 }
3710 break;
3711 }
3712 }
3713 assert((KnownUndef & KnownZero) == 0 && "Elements flagged as undef AND zero");
3714
3715 // Constant fold all undef cases.
3716 // TODO: Handle zero cases as well.
3717 if (DemandedElts.isSubsetOf(KnownUndef))
3718 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3719
3720 return false;
3721}
3722
3723/// Determine which of the bits specified in Mask are known to be either zero or
3724/// one and return them in the Known.
3725void TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3726 KnownBits &Known,
3727 const APInt &DemandedElts,
3728 const SelectionDAG &DAG,
3729 unsigned Depth) const {
3730 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3731 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3732 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3733 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3734 "Should use MaskedValueIsZero if you don't know whether Op"
3735 " is a target node!");
3736 Known.resetAll();
3737}
3738
3739void TargetLowering::computeKnownBitsForTargetInstr(
3740 GISelKnownBits &Analysis, Register R, KnownBits &Known,
3741 const APInt &DemandedElts, const MachineRegisterInfo &MRI,
3742 unsigned Depth) const {
3743 Known.resetAll();
3744}
3745
3746void TargetLowering::computeKnownBitsForFrameIndex(
3747 const int FrameIdx, KnownBits &Known, const MachineFunction &MF) const {
3748 // The low bits are known zero if the pointer is aligned.
3749 Known.Zero.setLowBits(Log2(MF.getFrameInfo().getObjectAlign(FrameIdx)));
3750}
3751
3752Align TargetLowering::computeKnownAlignForTargetInstr(
3753 GISelKnownBits &Analysis, Register R, const MachineRegisterInfo &MRI,
3754 unsigned Depth) const {
3755 return Align(1);
3756}
3757
3758/// This method can be implemented by targets that want to expose additional
3759/// information about sign bits to the DAG Combiner.
3760unsigned TargetLowering::ComputeNumSignBitsForTargetNode(SDValue Op,
3761 const APInt &,
3762 const SelectionDAG &,
3763 unsigned Depth) const {
3764 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3765 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3766 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3767 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3768 "Should use ComputeNumSignBits if you don't know whether Op"
3769 " is a target node!");
3770 return 1;
3771}
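// A target that wants to report extra sign bits for its own nodes overrides the
// hook above. An illustrative sketch (not part of this file; MyISD opcode and
// MyTargetLowering are hypothetical):
//
//   unsigned MyTargetLowering::ComputeNumSignBitsForTargetNode(
//       SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
//       unsigned Depth) const {
//     // A node that sign-extends from a narrower source has
//     // (dst bits - src bits + 1) known sign bits.
//     if (Op.getOpcode() == MyISD::SIGN_EXTEND_NARROW)
//       return Op.getScalarValueSizeInBits() -
//              Op.getOperand(0).getScalarValueSizeInBits() + 1;
//     return 1;
//   }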
3772
3773unsigned TargetLowering::computeNumSignBitsForTargetInstr(
3774 GISelKnownBits &Analysis, Register R, const APInt &DemandedElts,
3775 const MachineRegisterInfo &MRI, unsigned Depth) const {
3776 return 1;
3777}
3778
3779bool TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
3780 SDValue Op, const APInt &DemandedElts, APInt &KnownUndef, APInt &KnownZero,
3781 TargetLoweringOpt &TLO, unsigned Depth) const {
3782 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3783 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3784 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3785 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3786 "Should use SimplifyDemandedVectorElts if you don't know whether Op"
3787 " is a target node!");
3788 return false;
3789}
3790
3791bool TargetLowering::SimplifyDemandedBitsForTargetNode(
3792 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3793 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
3794 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3795 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3796 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3797 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3798 "Should use SimplifyDemandedBits if you don't know whether Op"
3799 " is a target node!");
3800 computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
3801 return false;
3802}
3803
3804SDValue TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
3805 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3806 SelectionDAG &DAG, unsigned Depth) const {
3807 assert(
3808 (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3809 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3810 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3811 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3812 "Should use SimplifyMultipleUseDemandedBits if you don't know whether Op"
3813 " is a target node!");
3814 return SDValue();
3815}
3816
3817SDValue
3818TargetLowering::buildLegalVectorShuffle(EVT VT, const SDLoc &DL, SDValue N0,
3819 SDValue N1, MutableArrayRef<int> Mask,
3820 SelectionDAG &DAG) const {
3821 bool LegalMask = isShuffleMaskLegal(Mask, VT);
3822 if (!LegalMask) {
3823 std::swap(N0, N1);
3824 ShuffleVectorSDNode::commuteMask(Mask);
3825 LegalMask = isShuffleMaskLegal(Mask, VT);
3826 }
3827
3828 if (!LegalMask)
3829 return SDValue();
3830
3831 return DAG.getVectorShuffle(VT, DL, N0, N1, Mask);
3832}
3833
3834const Constant *TargetLowering::getTargetConstantFromLoad(LoadSDNode *) const {
3835 return nullptr;
3836}
3837
3838bool TargetLowering::isGuaranteedNotToBeUndefOrPoisonForTargetNode(
3839 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
3840 bool PoisonOnly, unsigned Depth) const {
3841 assert(
3842 (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3843 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3844 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3845 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3846 "Should use isGuaranteedNotToBeUndefOrPoison if you don't know whether Op"
3847 " is a target node!");
3848
3849 // If Op can't create undef/poison and none of its operands are undef/poison
3850 // then Op is never undef/poison.
3851 return !canCreateUndefOrPoisonForTargetNode(Op, DemandedElts, DAG, PoisonOnly,
3852 /*ConsiderFlags*/ true, Depth) &&
3853 all_of(Op->ops(), [&](SDValue V) {
3854 return DAG.isGuaranteedNotToBeUndefOrPoison(V, PoisonOnly,
3855 Depth + 1);
3856 });
3857}
3858
3859bool TargetLowering::canCreateUndefOrPoisonForTargetNode(
3860 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
3861 bool PoisonOnly, bool ConsiderFlags, unsigned Depth) const {
3862 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3863 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3864 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3865 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3866 "Should use canCreateUndefOrPoison if you don't know whether Op"
3867 " is a target node!");
3868 // Be conservative and return true.
3869 return true;
3870}
3871
3872bool TargetLowering::isKnownNeverNaNForTargetNode(SDValue Op,
3873 const SelectionDAG &DAG,
3874 bool SNaN,
3875 unsigned Depth) const {
3876 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3877 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3878 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3879 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3880 "Should use isKnownNeverNaN if you don't know whether Op"
3881 " is a target node!");
3882 return false;
3883}
3884
3885bool TargetLowering::isSplatValueForTargetNode(SDValue Op,
3886 const APInt &DemandedElts,
3887 APInt &UndefElts,
3888 const SelectionDAG &DAG,
3889 unsigned Depth) const {
3890 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3891 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3892 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3893 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3894 "Should use isSplatValue if you don't know whether Op"
3895 " is a target node!");
3896 return false;
3897}
3898
3899// FIXME: Ideally, this would use ISD::isConstantSplatVector(), but that must
3900// work with truncating build vectors and vectors with elements of less than
3901// 8 bits.
3902bool TargetLowering::isConstTrueVal(SDValue N) const {
3903 if (!N)
3904 return false;
3905
3906 unsigned EltWidth;
3907 APInt CVal;
3908 if (ConstantSDNode *CN = isConstOrConstSplat(N, /*AllowUndefs=*/false,
3909 /*AllowTruncation=*/true)) {
3910 CVal = CN->getAPIntValue();
3911 EltWidth = N.getValueType().getScalarSizeInBits();
3912 } else
3913 return false;
3914
3915 // If this is a truncating splat, truncate the splat value.
3916 // Otherwise, we may fail to match the expected values below.
3917 if (EltWidth < CVal.getBitWidth())
3918 CVal = CVal.trunc(EltWidth);
3919
3920 switch (getBooleanContents(N.getValueType())) {
3921 case UndefinedBooleanContent:
3922 return CVal[0];
3923 case ZeroOrOneBooleanContent:
3924 return CVal.isOne();
3925 case ZeroOrNegativeOneBooleanContent:
3926 return CVal.isAllOnes();
3927 }
3928
3929 llvm_unreachable("Invalid boolean contents");
3930}
3931
3932bool TargetLowering::isConstFalseVal(SDValue N) const {
3933 if (!N)
3934 return false;
3935
3936 const ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N);
3937 if (!CN) {
3938 const BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N);
3939 if (!BV)
3940 return false;
3941
3942 // Only interested in constant splats; we don't care about undef
3943 // elements in identifying boolean constants, and getConstantSplatNode
3944 // returns NULL if all ops are undef.
3945 CN = BV->getConstantSplatNode();
3946 if (!CN)
3947 return false;
3948 }
3949
3950 if (getBooleanContents(N->getValueType(0)) == UndefinedBooleanContent)
3951 return !CN->getAPIntValue()[0];
3952
3953 return CN->isZero();
3954}
3955
3956bool TargetLowering::isExtendedTrueVal(const ConstantSDNode *N, EVT VT,
3957 bool SExt) const {
3958 if (VT == MVT::i1)
3959 return N->isOne();
3960
3961 TargetLowering::BooleanContent Cnt = getBooleanContents(VT);
3962 switch (Cnt) {
3963 case TargetLowering::ZeroOrOneBooleanContent:
3964 // An extended value of 1 is always true, unless its original type is i1,
3965 // in which case it will be sign extended to -1.
3966 return (N->isOne() && !SExt) || (SExt && (N->getValueType(0) != MVT::i1));
3967 case TargetLowering::UndefinedBooleanContent:
3968 case TargetLowering::ZeroOrNegativeOneBooleanContent:
3969 return N->isAllOnes() && SExt;
3970 }
3971 llvm_unreachable("Unexpected enumeration.");
3972}
3973
3974/// This helper function of SimplifySetCC tries to optimize the comparison when
3975/// either operand of the SetCC node is a bitwise-and instruction.
3976SDValue TargetLowering::foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1,
3977 ISD::CondCode Cond, const SDLoc &DL,
3978 DAGCombinerInfo &DCI) const {
3979 if (N1.getOpcode() == ISD::AND && N0.getOpcode() != ISD::AND)
3980 std::swap(N0, N1);
3981
3982 SelectionDAG &DAG = DCI.DAG;
3983 EVT OpVT = N0.getValueType();
3984 if (N0.getOpcode() != ISD::AND || !OpVT.isInteger() ||
3985 (Cond != ISD::SETEQ && Cond != ISD::SETNE))
3986 return SDValue();
3987
3988 // (X & Y) != 0 --> zextOrTrunc(X & Y)
3989 // iff everything but LSB is known zero:
3990 if (Cond == ISD::SETNE && isNullConstant(N1) &&
3991 (getBooleanContents(OpVT) == TargetLowering::UndefinedBooleanContent ||
3992 getBooleanContents(OpVT) == TargetLowering::ZeroOrOneBooleanContent)) {
3993 unsigned NumEltBits = OpVT.getScalarSizeInBits();
3994 APInt UpperBits = APInt::getHighBitsSet(NumEltBits, NumEltBits - 1);
3995 if (DAG.MaskedValueIsZero(N0, UpperBits))
3996 return DAG.getBoolExtOrTrunc(N0, DL, VT, OpVT);
3997 }
3998
3999 // Try to eliminate a power-of-2 mask constant by converting to a signbit
4000 // test in a narrow type that we can truncate to with no cost. Examples:
4001 // (i32 X & 32768) == 0 --> (trunc X to i16) >= 0
4002 // (i32 X & 32768) != 0 --> (trunc X to i16) < 0
4003 // TODO: This conservatively checks for type legality on the source and
4004 // destination types. That may inhibit optimizations, but it also
4005 // allows setcc->shift transforms that may be more beneficial.
4006 auto *AndC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
4007 if (AndC && isNullConstant(N1) && AndC->getAPIntValue().isPowerOf2() &&
4008 isTypeLegal(OpVT) && N0.hasOneUse()) {
4009 EVT NarrowVT = EVT::getIntegerVT(*DAG.getContext(),
4010 AndC->getAPIntValue().getActiveBits());
4011 if (isTruncateFree(OpVT, NarrowVT) && isTypeLegal(NarrowVT)) {
4012 SDValue Trunc = DAG.getZExtOrTrunc(N0.getOperand(0), DL, NarrowVT);
4013 SDValue Zero = DAG.getConstant(0, DL, NarrowVT);
4014 return DAG.getSetCC(DL, VT, Trunc, Zero,
4015 Cond == ISD::SETEQ ? ISD::SETGE : ISD::SETLT);
4016 }
4017 }
4018
4019 // Match these patterns in any of their permutations:
4020 // (X & Y) == Y
4021 // (X & Y) != Y
4022 SDValue X, Y;
4023 if (N0.getOperand(0) == N1) {
4024 X = N0.getOperand(1);
4025 Y = N0.getOperand(0);
4026 } else if (N0.getOperand(1) == N1) {
4027 X = N0.getOperand(0);
4028 Y = N0.getOperand(1);
4029 } else {
4030 return SDValue();
4031 }
4032
4033 // TODO: We should invert (X & Y) eq/ne 0 -> (X & Y) ne/eq Y if
4034 // `isXAndYEqZeroPreferableToXAndYEqY` is false. This is a bit difficult as
4035 // it's liable to create an infinite loop.
4036 SDValue Zero = DAG.getConstant(0, DL, OpVT);
4037 if (isXAndYEqZeroPreferableToXAndYEqY(Cond, OpVT) &&
4038 DAG.isKnownToBeAPowerOfTwo(Y)) {
4039 // Simplify X & Y == Y to X & Y != 0 if Y has exactly one bit set.
4040 // Note that where Y is variable and is known to have at most one bit set
4041 // (for example, if it is Z & 1) we cannot do this; the expressions are not
4042 // equivalent when Y == 0.
4043 assert(OpVT.isInteger());
4044 Cond = ISD::getSetCCInverse(Cond, OpVT);
4045 if (DCI.isBeforeLegalizeOps() ||
4046 isCondCodeLegal(Cond, N0.getSimpleValueType()))
4047 return DAG.getSetCC(DL, VT, N0, Zero, Cond);
4048 } else if (N0.hasOneUse() && hasAndNotCompare(Y)) {
4049 // If the target supports an 'and-not' or 'and-complement' logic operation,
4050 // try to use that to make a comparison operation more efficient.
4051 // But don't do this transform if the mask is a single bit because there are
4052 // more efficient ways to deal with that case (for example, 'bt' on x86 or
4053 // 'rlwinm' on PPC).
4054
4055 // Bail out if the compare operand that we want to turn into a zero is
4056 // already a zero (otherwise, infinite loop).
4057 if (isNullConstant(Y))
4058 return SDValue();
4059
4060 // Transform this into: ~X & Y == 0.
4061 SDValue NotX = DAG.getNOT(SDLoc(X), X, OpVT);
4062 SDValue NewAnd = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, NotX, Y);
4063 return DAG.getSetCC(DL, VT, NewAnd, Zero, Cond);
4064 }
4065
4066 return SDValue();
4067}
4068
4069/// There are multiple IR patterns that could be checking whether certain
4070/// truncation of a signed number would be lossy or not. The pattern which is
4071/// best at IR level may not lower optimally. Thus, we want to unfold it.
4072/// We are looking for the following pattern: (KeptBits is a constant)
4073/// (add %x, (1 << (KeptBits-1))) srccond (1 << KeptBits)
4074/// KeptBits won't be bitwidth(x), that will be constant-folded to true/false.
4075/// KeptBits also can't be 1, that would have been folded to %x dstcond 0
4076/// We will unfold it into the natural trunc+sext pattern:
4077/// ((%x << C) a>> C) dstcond %x
4078/// Where C = bitwidth(x) - KeptBits and C u< bitwidth(x)
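/// For example, with i16 %x and KeptBits == 8:
///   (icmp ult (add %x, 128), 256) -> (icmp eq (sext_inreg %x, i8), %x)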
4079SDValue TargetLowering::optimizeSetCCOfSignedTruncationCheck(
4080 EVT SCCVT, SDValue N0, SDValue N1, ISD::CondCode Cond, DAGCombinerInfo &DCI,
4081 const SDLoc &DL) const {
4082 // We must be comparing with a constant.
4083 ConstantSDNode *C1;
4084 if (!(C1 = dyn_cast<ConstantSDNode>(N1)))
4085 return SDValue();
4086
4087 // N0 should be: add %x, (1 << (KeptBits-1))
4088 if (N0->getOpcode() != ISD::ADD)
4089 return SDValue();
4090
4091 // And we must be 'add'ing a constant.
4092 ConstantSDNode *C01;
4093 if (!(C01 = dyn_cast<ConstantSDNode>(N0->getOperand(1))))
4094 return SDValue();
4095
4096 SDValue X = N0->getOperand(0);
4097 EVT XVT = X.getValueType();
4098
4099 // Validate constants ...
4100
4101 APInt I1 = C1->getAPIntValue();
4102
4103 ISD::CondCode NewCond;
4104 if (Cond == ISD::CondCode::SETULT) {
4105 NewCond = ISD::CondCode::SETEQ;
4106 } else if (Cond == ISD::CondCode::SETULE) {
4107 NewCond = ISD::CondCode::SETEQ;
4108 // But need to 'canonicalize' the constant.
4109 I1 += 1;
4110 } else if (Cond == ISD::CondCode::SETUGT) {
4111 NewCond = ISD::CondCode::SETNE;
4112 // But need to 'canonicalize' the constant.
4113 I1 += 1;
4114 } else if (Cond == ISD::CondCode::SETUGE) {
4115 NewCond = ISD::CondCode::SETNE;
4116 } else
4117 return SDValue();
4118
4119 APInt I01 = C01->getAPIntValue();
4120
4121 auto checkConstants = [&I1, &I01]() -> bool {
4122 // Both of them must be power-of-two, and the constant from setcc is bigger.
4123 return I1.ugt(I01) && I1.isPowerOf2() && I01.isPowerOf2();
4124 };
4125
4126 if (checkConstants()) {
4127 // Great, e.g. got icmp ult i16 (add i16 %x, 128), 256
4128 } else {
4129 // What if we invert constants? (and the target predicate)
4130 I1.negate();
4131 I01.negate();
4132 assert(XVT.isInteger());
4133 NewCond = getSetCCInverse(NewCond, XVT);
4134 if (!checkConstants())
4135 return SDValue();
4136 // Great, e.g. got icmp uge i16 (add i16 %x, -128), -256
4137 }
4138
4139 // They are power-of-two, so which bit is set?
4140 const unsigned KeptBits = I1.logBase2();
4141 const unsigned KeptBitsMinusOne = I01.logBase2();
4142
4143 // Magic!
4144 if (KeptBits != (KeptBitsMinusOne + 1))
4145 return SDValue();
4146 assert(KeptBits > 0 && KeptBits < XVT.getSizeInBits() && "unreachable");
4147
4148 // We don't want to do this in every single case.
4149 SelectionDAG &DAG = DCI.DAG;
4150 if (!shouldTransformSignedTruncationCheck(
4151 XVT, KeptBits))
4152 return SDValue();
4153
4154 // Unfold into: sext_inreg(%x) cond %x
4155 // Where 'cond' will be either 'eq' or 'ne'.
4156 SDValue SExtInReg = DAG.getNode(
4157 ISD::SIGN_EXTEND_INREG, DL, XVT, X,
4158 DAG.getValueType(EVT::getIntegerVT(*DAG.getContext(), KeptBits)));
4159 return DAG.getSetCC(DL, SCCVT, SExtInReg, X, NewCond);
4160}
4161
4162// (X & (C l>>/<< Y)) ==/!= 0 --> ((X <</l>> Y) & C) ==/!= 0
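// e.g. ((X & (32768 l>> Y)) == 0) -> (((X << Y) & 32768) == 0), moving the
// constant mask out of the shift.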
4163SDValue TargetLowering::optimizeSetCCByHoistingAndByConstFromLogicalShift(
4164 EVT SCCVT, SDValue N0, SDValue N1C, ISD::CondCode Cond,
4165 DAGCombinerInfo &DCI, const SDLoc &DL) const {
4166 assert(isConstOrConstSplat(N1C) && isConstOrConstSplat(N1C)->isZero() &&
4167 "Should be a comparison with 0.");
4168 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4169 "Valid only for [in]equality comparisons.");
4170
4171 unsigned NewShiftOpcode;
4172 SDValue X, C, Y;
4173
4174 SelectionDAG &DAG = DCI.DAG;
4175 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
4176
4177 // Look for '(C l>>/<< Y)'.
4178 auto Match = [&NewShiftOpcode, &X, &C, &Y, &TLI, &DAG](SDValue V) {
4179 // The shift should be one-use.
4180 if (!V.hasOneUse())
4181 return false;
4182 unsigned OldShiftOpcode = V.getOpcode();
4183 switch (OldShiftOpcode) {
4184 case ISD::SHL:
4185 NewShiftOpcode = ISD::SRL;
4186 break;
4187 case ISD::SRL:
4188 NewShiftOpcode = ISD::SHL;
4189 break;
4190 default:
4191 return false; // must be a logical shift.
4192 }
4193 // We should be shifting a constant.
4194 // FIXME: best to use isConstantOrConstantVector().
4195 C = V.getOperand(0);
4196 ConstantSDNode *CC =
4197 isConstOrConstSplat(C, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
4198 if (!CC)
4199 return false;
4200 Y = V.getOperand(1);
4201
4203 isConstOrConstSplat(X, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
4204 return TLI.shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
4205 X, XC, CC, Y, OldShiftOpcode, NewShiftOpcode, DAG);
4206 };
4207
4208 // LHS of comparison should be a one-use 'and'.
4209 if (N0.getOpcode() != ISD::AND || !N0.hasOneUse())
4210 return SDValue();
4211
4212 X = N0.getOperand(0);
4213 SDValue Mask = N0.getOperand(1);
4214
4215 // 'and' is commutative!
4216 if (!Match(Mask)) {
4217 std::swap(X, Mask);
4218 if (!Match(Mask))
4219 return SDValue();
4220 }
4221
4222 EVT VT = X.getValueType();
4223
4224 // Produce:
4225 // ((X 'OppositeShiftOpcode' Y) & C) Cond 0
4226 SDValue T0 = DAG.getNode(NewShiftOpcode, DL, VT, X, Y);
4227 SDValue T1 = DAG.getNode(ISD::AND, DL, VT, T0, C);
4228 SDValue T2 = DAG.getSetCC(DL, SCCVT, T1, N1C, Cond);
4229 return T2;
4230}
4231
4232/// Try to fold an equality comparison with a {add/sub/xor} binary operation as
4233/// the 1st operand (N0). Callers are expected to swap the N0/N1 parameters to
4234/// handle the commuted versions of these patterns.
4235SDValue TargetLowering::foldSetCCWithBinOp(EVT VT, SDValue N0, SDValue N1,
4236 ISD::CondCode Cond, const SDLoc &DL,
4237 DAGCombinerInfo &DCI) const {
4238 unsigned BOpcode = N0.getOpcode();
4239 assert((BOpcode == ISD::ADD || BOpcode == ISD::SUB || BOpcode == ISD::XOR) &&
4240 "Unexpected binop");
4241 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) && "Unexpected condcode");
4242
4243 // (X + Y) == X --> Y == 0
4244 // (X - Y) == X --> Y == 0
4245 // (X ^ Y) == X --> Y == 0
4246 SelectionDAG &DAG = DCI.DAG;
4247 EVT OpVT = N0.getValueType();
4248 SDValue X = N0.getOperand(0);
4249 SDValue Y = N0.getOperand(1);
4250 if (X == N1)
4251 return DAG.getSetCC(DL, VT, Y, DAG.getConstant(0, DL, OpVT), Cond);
4252
4253 if (Y != N1)
4254 return SDValue();
4255
4256 // (X + Y) == Y --> X == 0
4257 // (X ^ Y) == Y --> X == 0
4258 if (BOpcode == ISD::ADD || BOpcode == ISD::XOR)
4259 return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, OpVT), Cond);
4260
4261 // The shift would not be valid if the operands are boolean (i1).
4262 if (!N0.hasOneUse() || OpVT.getScalarSizeInBits() == 1)
4263 return SDValue();
4264
4265 // (X - Y) == Y --> X == Y << 1
4266 SDValue One = DAG.getShiftAmountConstant(1, OpVT, DL);
4267 SDValue YShl1 = DAG.getNode(ISD::SHL, DL, N1.getValueType(), Y, One);
4268 if (!DCI.isCalledByLegalizer())
4269 DCI.AddToWorklist(YShl1.getNode());
4270 return DAG.getSetCC(DL, VT, X, YShl1, Cond);
4271}
4272
4273static SDValue simplifySetCCWithCTPOP(const TargetLowering &TLI, EVT VT,
4274 SDValue N0, const APInt &C1,
4275 ISD::CondCode Cond, const SDLoc &dl,
4276 SelectionDAG &DAG) {
4277 // Look through truncs that don't change the value of a ctpop.
4278 // FIXME: Add vector support? Need to be careful with setcc result type below.
4279 SDValue CTPOP = N0;
4280 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() && !VT.isVector() &&
4282 CTPOP = N0.getOperand(0);
4283
4284 if (CTPOP.getOpcode() != ISD::CTPOP || !CTPOP.hasOneUse())
4285 return SDValue();
4286
4287 EVT CTVT = CTPOP.getValueType();
4288 SDValue CTOp = CTPOP.getOperand(0);
4289
4290 // Expand a power-of-2-or-zero comparison based on ctpop:
4291 // (ctpop x) u< 2 -> (x & x-1) == 0
4292 // (ctpop x) u> 1 -> (x & x-1) != 0
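 // More generally, (ctpop x) u< C is expanded into C-1 rounds of clearing the
 // lowest set bit, e.g. (ctpop x) u< 3 -> ((x & (x-1)) & ((x & (x-1)) - 1)) == 0.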
4293 if (Cond == ISD::SETULT || Cond == ISD::SETUGT) {
4294 // Keep the CTPOP if it is a cheap vector op.
4295 if (CTVT.isVector() && TLI.isCtpopFast(CTVT))
4296 return SDValue();
4297
4298 unsigned CostLimit = TLI.getCustomCtpopCost(CTVT, Cond);
4299 if (C1.ugt(CostLimit + (Cond == ISD::SETULT)))
4300 return SDValue();
4301 if (C1 == 0 && (Cond == ISD::SETULT))
4302 return SDValue(); // This is handled elsewhere.
4303
4304 unsigned Passes = C1.getLimitedValue() - (Cond == ISD::SETULT);
4305
4306 SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT);
4307 SDValue Result = CTOp;
4308 for (unsigned i = 0; i < Passes; i++) {
4309 SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, Result, NegOne);
4310 Result = DAG.getNode(ISD::AND, dl, CTVT, Result, Add);
4311 }
4312 ISD::CondCode CC = Cond == ISD::SETULT ? ISD::SETEQ : ISD::SETNE;
4313 return DAG.getSetCC(dl, VT, Result, DAG.getConstant(0, dl, CTVT), CC);
4314 }
4315
4316 // Expand a power-of-2 comparison based on ctpop
4317 if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) && C1 == 1) {
4318 // Keep the CTPOP if it is cheap.
4319 if (TLI.isCtpopFast(CTVT))
4320 return SDValue();
4321
4322 SDValue Zero = DAG.getConstant(0, dl, CTVT);
4323 SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT);
4324 assert(CTVT.isInteger());
4325 SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, CTOp, NegOne);
4326
4327 // It's not uncommon for known-never-zero X to exist in (ctpop X) eq/ne 1, so
4328 // check before emitting a potentially unnecessary op.
4329 if (DAG.isKnownNeverZero(CTOp)) {
4330 // (ctpop x) == 1 --> (x & x-1) == 0
4331 // (ctpop x) != 1 --> (x & x-1) != 0
4332 SDValue And = DAG.getNode(ISD::AND, dl, CTVT, CTOp, Add);
4333 SDValue RHS = DAG.getSetCC(dl, VT, And, Zero, Cond);
4334 return RHS;
4335 }
4336
4337 // (ctpop x) == 1 --> (x ^ x-1) > x-1
4338 // (ctpop x) != 1 --> (x ^ x-1) <= x-1
4339 SDValue Xor = DAG.getNode(ISD::XOR, dl, CTVT, CTOp, Add);
4340 ISD::CondCode CmpCond = Cond == ISD::SETEQ ? ISD::SETUGT : ISD::SETULE;
4341 return DAG.getSetCC(dl, VT, Xor, Add, CmpCond);
4342 }
4343
4344 return SDValue();
4345}
4346
4347static SDValue foldSetCCWithRotate(EVT VT, SDValue N0, SDValue N1,
4348 ISD::CondCode Cond, const SDLoc &dl,
4349 SelectionDAG &DAG) {
4350 if (Cond != ISD::SETEQ && Cond != ISD::SETNE)
4351 return SDValue();
4352
4353 auto *C1 = isConstOrConstSplat(N1, /* AllowUndefs */ true);
4354 if (!C1 || !(C1->isZero() || C1->isAllOnes()))
4355 return SDValue();
4356
4357 auto getRotateSource = [](SDValue X) {
4358 if (X.getOpcode() == ISD::ROTL || X.getOpcode() == ISD::ROTR)
4359 return X.getOperand(0);
4360 return SDValue();
4361 };
4362