LLVM 19.0.0git
AArch64PreLegalizerCombiner.cpp
Go to the documentation of this file.
1//=== lib/CodeGen/GlobalISel/AArch64PreLegalizerCombiner.cpp --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass does combining of machine instructions at the generic MI level,
10// before the legalizer.
11//
12//===----------------------------------------------------------------------===//
13
31#include "llvm/Support/Debug.h"
32
33#define GET_GICOMBINER_DEPS
34#include "AArch64GenPreLegalizeGICombiner.inc"
35#undef GET_GICOMBINER_DEPS
36
37#define DEBUG_TYPE "aarch64-prelegalizer-combiner"
38
39using namespace llvm;
40using namespace MIPatternMatch;
41
42namespace {
43
44#define GET_GICOMBINER_TYPES
45#include "AArch64GenPreLegalizeGICombiner.inc"
46#undef GET_GICOMBINER_TYPES
47
48/// Return true if a G_FCONSTANT instruction is known to be better-represented
49/// as a G_CONSTANT.
50bool matchFConstantToConstant(MachineInstr &MI, MachineRegisterInfo &MRI) {
51 assert(MI.getOpcode() == TargetOpcode::G_FCONSTANT);
52 Register DstReg = MI.getOperand(0).getReg();
53 const unsigned DstSize = MRI.getType(DstReg).getSizeInBits();
54 if (DstSize != 32 && DstSize != 64)
55 return false;
56
57 // When we're storing a value, it doesn't matter what register bank it's on.
58 // Since not all floating point constants can be materialized using a fmov,
59 // it makes more sense to just use a GPR.
60 return all_of(MRI.use_nodbg_instructions(DstReg),
61 [](const MachineInstr &Use) { return Use.mayStore(); });
62}
63
64/// Change a G_FCONSTANT into a G_CONSTANT.
65void applyFConstantToConstant(MachineInstr &MI) {
66 assert(MI.getOpcode() == TargetOpcode::G_FCONSTANT);
68 const APFloat &ImmValAPF = MI.getOperand(1).getFPImm()->getValueAPF();
69 MIB.buildConstant(MI.getOperand(0).getReg(), ImmValAPF.bitcastToAPInt());
70 MI.eraseFromParent();
71}
72
73/// Try to match a G_ICMP of a G_TRUNC with zero, in which the truncated bits
74/// are sign bits. In this case, we can transform the G_ICMP to directly compare
75/// the wide value with a zero.
76bool matchICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
77 GISelKnownBits *KB, Register &MatchInfo) {
78 assert(MI.getOpcode() == TargetOpcode::G_ICMP && KB);
79
80 auto Pred = (CmpInst::Predicate)MI.getOperand(1).getPredicate();
81 if (!ICmpInst::isEquality(Pred))
82 return false;
83
84 Register LHS = MI.getOperand(2).getReg();
85 LLT LHSTy = MRI.getType(LHS);
86 if (!LHSTy.isScalar())
87 return false;
88
89 Register RHS = MI.getOperand(3).getReg();
90 Register WideReg;
91
92 if (!mi_match(LHS, MRI, m_GTrunc(m_Reg(WideReg))) ||
93 !mi_match(RHS, MRI, m_SpecificICst(0)))
94 return false;
95
96 LLT WideTy = MRI.getType(WideReg);
97 if (KB->computeNumSignBits(WideReg) <=
98 WideTy.getSizeInBits() - LHSTy.getSizeInBits())
99 return false;
100
101 MatchInfo = WideReg;
102 return true;
103}
104
105void applyICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
106 MachineIRBuilder &Builder,
107 GISelChangeObserver &Observer, Register &WideReg) {
108 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
109
110 LLT WideTy = MRI.getType(WideReg);
111 // We're going to directly use the wide register as the LHS, and then use an
112 // equivalent size zero for RHS.
113 Builder.setInstrAndDebugLoc(MI);
114 auto WideZero = Builder.buildConstant(WideTy, 0);
115 Observer.changingInstr(MI);
116 MI.getOperand(2).setReg(WideReg);
117 MI.getOperand(3).setReg(WideZero.getReg(0));
118 Observer.changedInstr(MI);
119}
120
121/// \returns true if it is possible to fold a constant into a G_GLOBAL_VALUE.
122///
123/// e.g.
124///
125/// %g = G_GLOBAL_VALUE @x -> %g = G_GLOBAL_VALUE @x + cst
126bool matchFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
127 std::pair<uint64_t, uint64_t> &MatchInfo) {
128 assert(MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
129 MachineFunction &MF = *MI.getMF();
130 auto &GlobalOp = MI.getOperand(1);
131 auto *GV = GlobalOp.getGlobal();
132 if (GV->isThreadLocal())
133 return false;
134
135 // Don't allow anything that could represent offsets etc.
137 GV, MF.getTarget()) != AArch64II::MO_NO_FLAG)
138 return false;
139
140 // Look for a G_GLOBAL_VALUE only used by G_PTR_ADDs against constants:
141 //
142 // %g = G_GLOBAL_VALUE @x
143 // %ptr1 = G_PTR_ADD %g, cst1
144 // %ptr2 = G_PTR_ADD %g, cst2
145 // ...
146 // %ptrN = G_PTR_ADD %g, cstN
147 //
148 // Identify the *smallest* constant. We want to be able to form this:
149 //
150 // %offset_g = G_GLOBAL_VALUE @x + min_cst
151 // %g = G_PTR_ADD %offset_g, -min_cst
152 // %ptr1 = G_PTR_ADD %g, cst1
153 // ...
154 Register Dst = MI.getOperand(0).getReg();
155 uint64_t MinOffset = -1ull;
156 for (auto &UseInstr : MRI.use_nodbg_instructions(Dst)) {
157 if (UseInstr.getOpcode() != TargetOpcode::G_PTR_ADD)
158 return false;
160 UseInstr.getOperand(2).getReg(), MRI);
161 if (!Cst)
162 return false;
163 MinOffset = std::min(MinOffset, Cst->Value.getZExtValue());
164 }
165
166 // Require that the new offset is larger than the existing one to avoid
167 // infinite loops.
168 uint64_t CurrOffset = GlobalOp.getOffset();
169 uint64_t NewOffset = MinOffset + CurrOffset;
170 if (NewOffset <= CurrOffset)
171 return false;
172
173 // Check whether folding this offset is legal. It must not go out of bounds of
174 // the referenced object to avoid violating the code model, and must be
175 // smaller than 2^20 because this is the largest offset expressible in all
176 // object formats. (The IMAGE_REL_ARM64_PAGEBASE_REL21 relocation in COFF
177 // stores an immediate signed 21 bit offset.)
178 //
179 // This check also prevents us from folding negative offsets, which will end
180 // up being treated in the same way as large positive ones. They could also
181 // cause code model violations, and aren't really common enough to matter.
182 if (NewOffset >= (1 << 20))
183 return false;
184
185 Type *T = GV->getValueType();
186 if (!T->isSized() ||
187 NewOffset > GV->getParent()->getDataLayout().getTypeAllocSize(T))
188 return false;
189 MatchInfo = std::make_pair(NewOffset, MinOffset);
190 return true;
191}
192
193void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
195 std::pair<uint64_t, uint64_t> &MatchInfo) {
196 // Change:
197 //
198 // %g = G_GLOBAL_VALUE @x
199 // %ptr1 = G_PTR_ADD %g, cst1
200 // %ptr2 = G_PTR_ADD %g, cst2
201 // ...
202 // %ptrN = G_PTR_ADD %g, cstN
203 //
204 // To:
205 //
206 // %offset_g = G_GLOBAL_VALUE @x + min_cst
207 // %g = G_PTR_ADD %offset_g, -min_cst
208 // %ptr1 = G_PTR_ADD %g, cst1
209 // ...
210 // %ptrN = G_PTR_ADD %g, cstN
211 //
212 // Then, the original G_PTR_ADDs should be folded later on so that they look
213 // like this:
214 //
215 // %ptrN = G_PTR_ADD %offset_g, cstN - min_cst
216 uint64_t Offset, MinOffset;
217 std::tie(Offset, MinOffset) = MatchInfo;
218 B.setInstrAndDebugLoc(*std::next(MI.getIterator()));
219 Observer.changingInstr(MI);
220 auto &GlobalOp = MI.getOperand(1);
221 auto *GV = GlobalOp.getGlobal();
222 GlobalOp.ChangeToGA(GV, Offset, GlobalOp.getTargetFlags());
223 Register Dst = MI.getOperand(0).getReg();
224 Register NewGVDst = MRI.cloneVirtualRegister(Dst);
225 MI.getOperand(0).setReg(NewGVDst);
226 Observer.changedInstr(MI);
227 B.buildPtrAdd(
228 Dst, NewGVDst,
229 B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
230}
231
232// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y))
233// Or vecreduce_add(ext(x)) -> vecreduce_add(udot(x, 1))
234// Similar to performVecReduceAddCombine in SelectionDAG
235bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
236 const AArch64Subtarget &STI,
237 std::tuple<Register, Register, bool> &MatchInfo) {
238 assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
239 "Expected a G_VECREDUCE_ADD instruction");
240 assert(STI.hasDotProd() && "Target should have Dot Product feature");
241
242 MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
243 Register DstReg = MI.getOperand(0).getReg();
244 Register MidReg = I1->getOperand(0).getReg();
245 LLT DstTy = MRI.getType(DstReg);
246 LLT MidTy = MRI.getType(MidReg);
247 if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
248 return false;
249
250 LLT SrcTy;
251 auto I1Opc = I1->getOpcode();
252 if (I1Opc == TargetOpcode::G_MUL) {
253 // If result of this has more than 1 use, then there is no point in creating
254 // udot instruction
255 if (!MRI.hasOneNonDBGUse(MidReg))
256 return false;
257
258 MachineInstr *ExtMI1 =
259 getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
260 MachineInstr *ExtMI2 =
261 getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
262 LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
263 LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
264
265 if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
266 return false;
267 I1Opc = ExtMI1->getOpcode();
268 SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
269 std::get<0>(MatchInfo) = ExtMI1->getOperand(1).getReg();
270 std::get<1>(MatchInfo) = ExtMI2->getOperand(1).getReg();
271 } else {
272 SrcTy = MRI.getType(I1->getOperand(1).getReg());
273 std::get<0>(MatchInfo) = I1->getOperand(1).getReg();
274 std::get<1>(MatchInfo) = 0;
275 }
276
277 if (I1Opc == TargetOpcode::G_ZEXT)
278 std::get<2>(MatchInfo) = 0;
279 else if (I1Opc == TargetOpcode::G_SEXT)
280 std::get<2>(MatchInfo) = 1;
281 else
282 return false;
283
284 if (SrcTy.getScalarSizeInBits() != 8 || SrcTy.getNumElements() % 8 != 0)
285 return false;
286
287 return true;
288}
289
290void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
291 MachineIRBuilder &Builder,
292 GISelChangeObserver &Observer,
293 const AArch64Subtarget &STI,
294 std::tuple<Register, Register, bool> &MatchInfo) {
295 assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
296 "Expected a G_VECREDUCE_ADD instruction");
297 assert(STI.hasDotProd() && "Target should have Dot Product feature");
298
299 // Initialise the variables
300 unsigned DotOpcode =
301 std::get<2>(MatchInfo) ? AArch64::G_SDOT : AArch64::G_UDOT;
302 Register Ext1SrcReg = std::get<0>(MatchInfo);
303
304 // If there is one source register, create a vector of 0s as the second
305 // source register
306 Register Ext2SrcReg;
307 if (std::get<1>(MatchInfo) == 0)
308 Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
309 ->getOperand(0)
310 .getReg();
311 else
312 Ext2SrcReg = std::get<1>(MatchInfo);
313
314 // Find out how many DOT instructions are needed
315 LLT SrcTy = MRI.getType(Ext1SrcReg);
316 LLT MidTy;
317 unsigned NumOfDotMI;
318 if (SrcTy.getNumElements() % 16 == 0) {
319 NumOfDotMI = SrcTy.getNumElements() / 16;
320 MidTy = LLT::fixed_vector(4, 32);
321 } else if (SrcTy.getNumElements() % 8 == 0) {
322 NumOfDotMI = SrcTy.getNumElements() / 8;
323 MidTy = LLT::fixed_vector(2, 32);
324 } else {
325 llvm_unreachable("Source type number of elements is not multiple of 8");
326 }
327
328 // Handle case where one DOT instruction is needed
329 if (NumOfDotMI == 1) {
330 auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
331 auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
332 {Zeroes, Ext1SrcReg, Ext2SrcReg});
333 Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
334 } else {
335 // If not pad the last v8 element with 0s to a v16
336 SmallVector<Register, 4> Ext1UnmergeReg;
337 SmallVector<Register, 4> Ext2UnmergeReg;
338 if (SrcTy.getNumElements() % 16 != 0) {
339 SmallVector<Register> Leftover1;
340 SmallVector<Register> Leftover2;
341
342 // Split the elements into v16i8 and v8i8
343 LLT MainTy = LLT::fixed_vector(16, 8);
344 LLT LeftoverTy1, LeftoverTy2;
345 if ((!extractParts(Ext1SrcReg, MRI.getType(Ext1SrcReg), MainTy,
346 LeftoverTy1, Ext1UnmergeReg, Leftover1, Builder,
347 MRI)) ||
348 (!extractParts(Ext2SrcReg, MRI.getType(Ext2SrcReg), MainTy,
349 LeftoverTy2, Ext2UnmergeReg, Leftover2, Builder,
350 MRI))) {
351 llvm_unreachable("Unable to split this vector properly");
352 }
353
354 // Pad the leftover v8i8 vector with register of 0s of type v8i8
355 Register v8Zeroes = Builder.buildConstant(LLT::fixed_vector(8, 8), 0)
356 ->getOperand(0)
357 .getReg();
358
359 Ext1UnmergeReg.push_back(
360 Builder
361 .buildMergeLikeInstr(LLT::fixed_vector(16, 8),
362 {Leftover1[0], v8Zeroes})
363 .getReg(0));
364 Ext2UnmergeReg.push_back(
365 Builder
366 .buildMergeLikeInstr(LLT::fixed_vector(16, 8),
367 {Leftover2[0], v8Zeroes})
368 .getReg(0));
369
370 } else {
371 // Unmerge the source vectors to v16i8
372 unsigned SrcNumElts = SrcTy.getNumElements();
373 extractParts(Ext1SrcReg, LLT::fixed_vector(16, 8), SrcNumElts / 16,
374 Ext1UnmergeReg, Builder, MRI);
375 extractParts(Ext2SrcReg, LLT::fixed_vector(16, 8), SrcNumElts / 16,
376 Ext2UnmergeReg, Builder, MRI);
377 }
378
379 // Build the UDOT instructions
381 unsigned NumElements = 0;
382 for (unsigned i = 0; i < Ext1UnmergeReg.size(); i++) {
383 LLT ZeroesLLT;
384 // Check if it is 16 or 8 elements. Set Zeroes to the according size
385 if (MRI.getType(Ext1UnmergeReg[i]).getNumElements() == 16) {
386 ZeroesLLT = LLT::fixed_vector(4, 32);
387 NumElements += 4;
388 } else {
389 ZeroesLLT = LLT::fixed_vector(2, 32);
390 NumElements += 2;
391 }
392 auto Zeroes = Builder.buildConstant(ZeroesLLT, 0)->getOperand(0).getReg();
393 DotReg.push_back(
394 Builder
395 .buildInstr(DotOpcode, {MRI.getType(Zeroes)},
396 {Zeroes, Ext1UnmergeReg[i], Ext2UnmergeReg[i]})
397 .getReg(0));
398 }
399
400 // Merge the output
401 auto ConcatMI =
402 Builder.buildConcatVectors(LLT::fixed_vector(NumElements, 32), DotReg);
403
404 // Put it through a vector reduction
405 Builder.buildVecReduceAdd(MI.getOperand(0).getReg(),
406 ConcatMI->getOperand(0).getReg());
407 }
408
409 // Erase the dead instructions
410 MI.eraseFromParent();
411}
412
413// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
414// Ensure that the type coming from the extend instruction is the right size
415bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
416 std::pair<Register, bool> &MatchInfo) {
417 assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
418 "Expected G_VECREDUCE_ADD Opcode");
419
420 // Check if the last instruction is an extend
421 MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
422 auto ExtOpc = ExtMI->getOpcode();
423
424 if (ExtOpc == TargetOpcode::G_ZEXT)
425 std::get<1>(MatchInfo) = 0;
426 else if (ExtOpc == TargetOpcode::G_SEXT)
427 std::get<1>(MatchInfo) = 1;
428 else
429 return false;
430
431 // Check if the source register is a valid type
432 Register ExtSrcReg = ExtMI->getOperand(1).getReg();
433 LLT ExtSrcTy = MRI.getType(ExtSrcReg);
434 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
435 if ((DstTy.getScalarSizeInBits() == 16 &&
436 ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) ||
437 (DstTy.getScalarSizeInBits() == 32 &&
438 ExtSrcTy.getNumElements() % 4 == 0) ||
439 (DstTy.getScalarSizeInBits() == 64 &&
440 ExtSrcTy.getNumElements() % 4 == 0)) {
441 std::get<0>(MatchInfo) = ExtSrcReg;
442 return true;
443 }
444 return false;
445}
446
447void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
449 std::pair<Register, bool> &MatchInfo) {
450 assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
451 "Expected G_VECREDUCE_ADD Opcode");
452
453 unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
454 Register SrcReg = std::get<0>(MatchInfo);
455 Register DstReg = MI.getOperand(0).getReg();
456 LLT SrcTy = MRI.getType(SrcReg);
457 LLT DstTy = MRI.getType(DstReg);
458
459 // If SrcTy has more elements than expected, split them into multiple
460 // insructions and sum the results
461 LLT MainTy;
462 SmallVector<Register, 1> WorkingRegisters;
463 unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
464 unsigned SrcNumElem = SrcTy.getNumElements();
465 if ((SrcScalSize == 8 && SrcNumElem > 16) ||
466 (SrcScalSize == 16 && SrcNumElem > 8) ||
467 (SrcScalSize == 32 && SrcNumElem > 4)) {
468
469 LLT LeftoverTy;
470 SmallVector<Register, 4> LeftoverRegs;
471 if (SrcScalSize == 8)
472 MainTy = LLT::fixed_vector(16, 8);
473 else if (SrcScalSize == 16)
474 MainTy = LLT::fixed_vector(8, 16);
475 else if (SrcScalSize == 32)
476 MainTy = LLT::fixed_vector(4, 32);
477 else
478 llvm_unreachable("Source's Scalar Size not supported");
479
480 // Extract the parts and put each extracted sources through U/SADDLV and put
481 // the values inside a small vec
482 extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
483 LeftoverRegs, B, MRI);
484 for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
485 WorkingRegisters.push_back(LeftoverRegs[I]);
486 }
487 } else {
488 WorkingRegisters.push_back(SrcReg);
489 MainTy = SrcTy;
490 }
491
492 unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
493 LLT MidScalarLLT = LLT::scalar(MidScalarSize);
494 Register zeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0);
495 for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
496 // If the number of elements is too small to build an instruction, extend
497 // its size before applying addlv
498 LLT WorkingRegTy = MRI.getType(WorkingRegisters[I]);
499 if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
500 (WorkingRegTy.getNumElements() == 4)) {
501 WorkingRegisters[I] =
502 B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
503 : TargetOpcode::G_ZEXT,
504 {LLT::fixed_vector(4, 16)}, {WorkingRegisters[I]})
505 .getReg(0);
506 }
507
508 // Generate the {U/S}ADDLV instruction, whose output is always double of the
509 // Src's Scalar size
510 LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
511 : LLT::fixed_vector(2, 64);
512 Register addlvReg =
513 B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]}).getReg(0);
514
515 // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
516 // v2i64 register.
517 // i16, i32 results uses v4i32 registers
518 // i64 results uses v2i64 registers
519 // Therefore we have to extract/truncate the the value to the right type
520 if (MidScalarSize == 32 || MidScalarSize == 64) {
521 WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
522 {MidScalarLLT}, {addlvReg, zeroReg})
523 .getReg(0);
524 } else {
525 Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
526 {LLT::scalar(32)}, {addlvReg, zeroReg})
527 .getReg(0);
528 WorkingRegisters[I] =
529 B.buildTrunc({MidScalarLLT}, {extractReg}).getReg(0);
530 }
531 }
532
533 Register outReg;
534 if (WorkingRegisters.size() > 1) {
535 outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
536 .getReg(0);
537 for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
538 outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I]).getReg(0);
539 }
540 } else {
541 outReg = WorkingRegisters[0];
542 }
543
544 if (DstTy.getScalarSizeInBits() > MidScalarSize) {
545 // Handle the scalar value if the DstTy's Scalar Size is more than double
546 // Src's ScalarType
547 B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
548 : TargetOpcode::G_ZEXT,
549 {DstReg}, {outReg});
550 } else {
551 B.buildCopy(DstReg, outReg);
552 }
553
554 MI.eraseFromParent();
555}
556
557bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
558 CombinerHelper &Helper, GISelChangeObserver &Observer) {
559 // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if
560 // result is only used in the no-overflow case. It is restricted to cases
561 // where we know that the high-bits of the operands are 0. If there's an
562 // overflow, then the 9th or 17th bit must be set, which can be checked
563 // using TBNZ.
564 //
565 // Change (for UADDOs on 8 and 16 bits):
566 //
567 // %z0 = G_ASSERT_ZEXT _
568 // %op0 = G_TRUNC %z0
569 // %z1 = G_ASSERT_ZEXT _
570 // %op1 = G_TRUNC %z1
571 // %val, %cond = G_UADDO %op0, %op1
572 // G_BRCOND %cond, %error.bb
573 //
574 // error.bb:
575 // (no successors and no uses of %val)
576 //
577 // To:
578 //
579 // %z0 = G_ASSERT_ZEXT _
580 // %z1 = G_ASSERT_ZEXT _
581 // %add = G_ADD %z0, %z1
582 // %val = G_TRUNC %add
583 // %bit = G_AND %add, 1 << scalar-size-in-bits(%op1)
584 // %cond = G_ICMP NE, %bit, 0
585 // G_BRCOND %cond, %error.bb
586
587 auto &MRI = *B.getMRI();
588
589 MachineOperand *DefOp0 = MRI.getOneDef(MI.getOperand(2).getReg());
590 MachineOperand *DefOp1 = MRI.getOneDef(MI.getOperand(3).getReg());
591 Register Op0Wide;
592 Register Op1Wide;
593 if (!mi_match(DefOp0->getParent(), MRI, m_GTrunc(m_Reg(Op0Wide))) ||
594 !mi_match(DefOp1->getParent(), MRI, m_GTrunc(m_Reg(Op1Wide))))
595 return false;
596 LLT WideTy0 = MRI.getType(Op0Wide);
597 LLT WideTy1 = MRI.getType(Op1Wide);
598 Register ResVal = MI.getOperand(0).getReg();
599 LLT OpTy = MRI.getType(ResVal);
600 MachineInstr *Op0WideDef = MRI.getVRegDef(Op0Wide);
601 MachineInstr *Op1WideDef = MRI.getVRegDef(Op1Wide);
602
603 unsigned OpTySize = OpTy.getScalarSizeInBits();
604 // First check that the G_TRUNC feeding the G_UADDO are no-ops, because the
605 // inputs have been zero-extended.
606 if (Op0WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
607 Op1WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
608 OpTySize != Op0WideDef->getOperand(2).getImm() ||
609 OpTySize != Op1WideDef->getOperand(2).getImm())
610 return false;
611
612 // Only scalar UADDO with either 8 or 16 bit operands are handled.
613 if (!WideTy0.isScalar() || !WideTy1.isScalar() || WideTy0 != WideTy1 ||
614 OpTySize >= WideTy0.getScalarSizeInBits() ||
615 (OpTySize != 8 && OpTySize != 16))
616 return false;
617
618 // The overflow-status result must be used by a branch only.
619 Register ResStatus = MI.getOperand(1).getReg();
620 if (!MRI.hasOneNonDBGUse(ResStatus))
621 return false;
622 MachineInstr *CondUser = &*MRI.use_instr_nodbg_begin(ResStatus);
623 if (CondUser->getOpcode() != TargetOpcode::G_BRCOND)
624 return false;
625
626 // Make sure the computed result is only used in the no-overflow blocks.
627 MachineBasicBlock *CurrentMBB = MI.getParent();
628 MachineBasicBlock *FailMBB = CondUser->getOperand(1).getMBB();
629 if (!FailMBB->succ_empty() || CondUser->getParent() != CurrentMBB)
630 return false;
631 if (any_of(MRI.use_nodbg_instructions(ResVal),
632 [&MI, FailMBB, CurrentMBB](MachineInstr &I) {
633 return &MI != &I &&
634 (I.getParent() == FailMBB || I.getParent() == CurrentMBB);
635 }))
636 return false;
637
638 // Remove G_ADDO.
639 B.setInstrAndDebugLoc(*MI.getNextNode());
640 MI.eraseFromParent();
641
642 // Emit wide add.
643 Register AddDst = MRI.cloneVirtualRegister(Op0Wide);
644 B.buildInstr(TargetOpcode::G_ADD, {AddDst}, {Op0Wide, Op1Wide});
645
646 // Emit check of the 9th or 17th bit and update users (the branch). This will
647 // later be folded to TBNZ.
648 Register CondBit = MRI.cloneVirtualRegister(Op0Wide);
649 B.buildAnd(
650 CondBit, AddDst,
651 B.buildConstant(LLT::scalar(32), OpTySize == 8 ? 1 << 8 : 1 << 16));
652 B.buildICmp(CmpInst::ICMP_NE, ResStatus, CondBit,
653 B.buildConstant(LLT::scalar(32), 0));
654
655 // Update ZEXts users of the result value. Because all uses are in the
656 // no-overflow case, we know that the top bits are 0 and we can ignore ZExts.
657 B.buildZExtOrTrunc(ResVal, AddDst);
658 for (MachineOperand &U : make_early_inc_range(MRI.use_operands(ResVal))) {
659 Register WideReg;
660 if (mi_match(U.getParent(), MRI, m_GZExt(m_Reg(WideReg)))) {
661 auto OldR = U.getParent()->getOperand(0).getReg();
662 Observer.erasingInstr(*U.getParent());
663 U.getParent()->eraseFromParent();
664 Helper.replaceRegWith(MRI, OldR, AddDst);
665 }
666 }
667
668 return true;
669}
670
671class AArch64PreLegalizerCombinerImpl : public Combiner {
672protected:
673 // TODO: Make CombinerHelper methods const.
674 mutable CombinerHelper Helper;
675 const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig;
676 const AArch64Subtarget &STI;
677
678public:
679 AArch64PreLegalizerCombinerImpl(
680 MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
681 GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
682 const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig,
683 const AArch64Subtarget &STI, MachineDominatorTree *MDT,
684 const LegalizerInfo *LI);
685
686 static const char *getName() { return "AArch6400PreLegalizerCombiner"; }
687
688 bool tryCombineAll(MachineInstr &I) const override;
689
690 bool tryCombineAllImpl(MachineInstr &I) const;
691
692private:
693#define GET_GICOMBINER_CLASS_MEMBERS
694#include "AArch64GenPreLegalizeGICombiner.inc"
695#undef GET_GICOMBINER_CLASS_MEMBERS
696};
697
698#define GET_GICOMBINER_IMPL
699#include "AArch64GenPreLegalizeGICombiner.inc"
700#undef GET_GICOMBINER_IMPL
701
702AArch64PreLegalizerCombinerImpl::AArch64PreLegalizerCombinerImpl(
703 MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
704 GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
705 const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig,
706 const AArch64Subtarget &STI, MachineDominatorTree *MDT,
707 const LegalizerInfo *LI)
708 : Combiner(MF, CInfo, TPC, &KB, CSEInfo),
709 Helper(Observer, B, /*IsPreLegalize*/ true, &KB, MDT, LI),
710 RuleConfig(RuleConfig), STI(STI),
712#include "AArch64GenPreLegalizeGICombiner.inc"
714{
715}
716
717bool AArch64PreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
718 if (tryCombineAllImpl(MI))
719 return true;
720
721 unsigned Opc = MI.getOpcode();
722 switch (Opc) {
723 case TargetOpcode::G_SHUFFLE_VECTOR:
724 return Helper.tryCombineShuffleVector(MI);
725 case TargetOpcode::G_UADDO:
726 return tryToSimplifyUADDO(MI, B, Helper, Observer);
727 case TargetOpcode::G_MEMCPY_INLINE:
728 return Helper.tryEmitMemcpyInline(MI);
729 case TargetOpcode::G_MEMCPY:
730 case TargetOpcode::G_MEMMOVE:
731 case TargetOpcode::G_MEMSET: {
732 // If we're at -O0 set a maxlen of 32 to inline, otherwise let the other
733 // heuristics decide.
734 unsigned MaxLen = CInfo.EnableOpt ? 0 : 32;
735 // Try to inline memcpy type calls if optimizations are enabled.
736 if (Helper.tryCombineMemCpyFamily(MI, MaxLen))
737 return true;
738 if (Opc == TargetOpcode::G_MEMSET)
739 return llvm::AArch64GISelUtils::tryEmitBZero(MI, B, CInfo.EnableMinSize);
740 return false;
741 }
742 }
743
744 return false;
745}
746
747// Pass boilerplate
748// ================
749
750class AArch64PreLegalizerCombiner : public MachineFunctionPass {
751public:
752 static char ID;
753
754 AArch64PreLegalizerCombiner();
755
756 StringRef getPassName() const override {
757 return "AArch64PreLegalizerCombiner";
758 }
759
760 bool runOnMachineFunction(MachineFunction &MF) override;
761
762 void getAnalysisUsage(AnalysisUsage &AU) const override;
763
764private:
765 AArch64PreLegalizerCombinerImplRuleConfig RuleConfig;
766};
767} // end anonymous namespace
768
// Declares the analyses this pass uses; the visible body only records that
// the CFG is preserved.
// NOTE(review): the embedded source numbering jumps 769 -> 771 -> 780, so
// most statements of this function appear to have been lost in extraction.
// runOnMachineFunction calls getAnalysis<> for TargetPassConfig,
// GISelCSEAnalysisWrapperPass, GISelKnownBitsAnalysis and
// MachineDominatorTree, so presumably those are required here - restore the
// body from the upstream file before relying on this.
769void AArch64PreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
771 AU.setPreservesCFG();
780}
781
// Pass constructor. The visible part parses the combiner rule command line
// option into RuleConfig and aborts with a fatal error if that fails.
// NOTE(review): the embedded source numbering jumps 782 -> 785, so the member
// initializer list and the opening brace are missing from this extraction -
// restore them from the upstream file.
782AArch64PreLegalizerCombiner::AArch64PreLegalizerCombiner()
785
786 if (!RuleConfig.parseCommandLineOption())
787 report_fatal_error("Invalid rule identifier");
788}
789
790bool AArch64PreLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
791 if (MF.getProperties().hasProperty(
792 MachineFunctionProperties::Property::FailedISel))
793 return false;
794 auto &TPC = getAnalysis<TargetPassConfig>();
795
796 // Enable CSE.
798 getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
799 auto *CSEInfo = &Wrapper.get(TPC.getCSEConfig());
800
802 const auto *LI = ST.getLegalizerInfo();
803
804 const Function &F = MF.getFunction();
805 bool EnableOpt =
806 MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);
807 GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
808 MachineDominatorTree *MDT = &getAnalysis<MachineDominatorTree>();
809 CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
810 /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
811 F.hasMinSize());
812 AArch64PreLegalizerCombinerImpl Impl(MF, CInfo, &TPC, *KB, CSEInfo,
813 RuleConfig, ST, MDT, LI);
814 return Impl.combineMachineInstrs();
815}
816
// Pass identification and registration under DEBUG_TYPE
// ("aarch64-prelegalizer-combiner").
// NOTE(review): the embedded source numbering jumps 820 -> 824, so the lines
// between INITIALIZE_PASS_BEGIN and INITIALIZE_PASS_END (presumably
// INITIALIZE_PASS_DEPENDENCY entries) are missing from this extraction.
817char AArch64PreLegalizerCombiner::ID = 0;
818INITIALIZE_PASS_BEGIN(AArch64PreLegalizerCombiner, DEBUG_TYPE,
819 "Combine AArch64 machine instrs before legalization",
820 false, false)
824INITIALIZE_PASS_END(AArch64PreLegalizerCombiner, DEBUG_TYPE,
825 "Combine AArch64 machine instrs before legalization", false,
826 false)
827
// Factory that returns a newly allocated AArch64PreLegalizerCombiner pass.
// NOTE(review): the embedded source numbering jumps 828 -> 830, so the
// function signature line is missing from this extraction - restore it from
// the upstream file.
828namespace llvm {
830 return new AArch64PreLegalizerCombiner();
831}
832} // end namespace llvm
unsigned const MachineRegisterInfo * MRI
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#define DEBUG_TYPE
Combine AArch64 machine instrs before legalization
amdgpu aa AMDGPU Address space based Alias Analysis Wrapper
basic Basic Alias true
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Provides analysis for continuously CSEing during GISel passes.
This contains common combine transformations that may be used in a combine pass,or by the target else...
Option class for Targets to specify which operations are combined how and when.
This contains the base class for all Combiners generated by TableGen.
Provides analysis for querying information about KnownBits during GISel passes.
Hexagon Vector Combine
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
Contains matchers for matching SSA Machine Instructions.
This file declares the MachineIRBuilder class.
static unsigned getReg(const MCDisassembler *D, unsigned RC, unsigned RegNo)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
static StringRef getName(Value *V)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
Target-Independent Code Generator Pass Configuration Options pass.
Value * RHS
Value * LHS
unsigned ClassifyGlobalReference(const GlobalValue *GV, const TargetMachine &TM) const
ClassifyGlobalReference - Find the target operand flags that describe how a global value should be re...
APInt bitcastToAPInt() const
Definition: APFloat.h:1210
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:269
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:960
@ ICMP_NE
not equal
Definition: InstrTypes.h:982
void replaceRegWith(MachineRegisterInfo &MRI, Register FromReg, Register ToReg) const
MachineRegisterInfo::replaceRegWith() and inform the observer of the changes.
bool tryCombineMemCpyFamily(MachineInstr &MI, unsigned MaxLen=0)
Optimize memcpy intrinsics et al, e.g.
bool tryEmitMemcpyInline(MachineInstr &MI)
Emit loads and stores that perform the given memcpy.
bool tryCombineShuffleVector(MachineInstr &MI)
Try to combine G_SHUFFLE_VECTOR into G_CONCAT_VECTORS.
Combiner implementation.
Definition: Combiner.h:34
virtual bool tryCombineAll(MachineInstr &I) const =0
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
The actual analysis pass wrapper.
Definition: CSEInfo.h:222
Simple wrapper that does the following.
Definition: CSEInfo.h:204
The CSE Analysis object.
Definition: CSEInfo.h:69
Abstract class that contains various methods for clients to notify about changes.
virtual void changingInstr(MachineInstr &MI)=0
This instruction is about to be mutated in some way.
virtual void changedInstr(MachineInstr &MI)=0
This instruction was mutated in some way.
virtual void erasingInstr(MachineInstr &MI)=0
An instruction is about to be erased.
To use KnownBitsInfo analysis in a pass, KnownBitsInfo &Info = getAnalysis<GISelKnownBitsInfoAnalysis...
unsigned computeNumSignBits(Register R, const APInt &DemandedElts, unsigned Depth=0)
bool isEquality() const
Return true if this predicate is either EQ or NE.
constexpr unsigned getScalarSizeInBits() const
Definition: LowLevelType.h:267
constexpr bool isScalar() const
Definition: LowLevelType.h:146
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelType.h:42
constexpr uint16_t getNumElements() const
Returns the number of elements in a vector LLT.
Definition: LowLevelType.h:159
constexpr TypeSize getSizeInBits() const
Returns the total size of the type. Must only be called on sized types.
Definition: LowLevelType.h:193
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
Definition: LowLevelType.h:100
DominatorTree Class - Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of passes that operate on the MachineFunction representation.
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformation or analysis.
bool hasProperty(Property P) const
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
Function & getFunction()
Return the LLVM function that this machine code represents.
const LLVMTargetMachine & getTarget() const
getTarget - Return the target machine this machine code is compiled with
const MachineFunctionProperties & getProperties() const
Get the function properties.
Helper class to build MachineInstr.
MachineInstrBuilder buildConcatVectors(const DstOp &Res, ArrayRef< Register > Ops)
Build and insert Res = G_CONCAT_VECTORS Op0, ...
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
void setInstrAndDebugLoc(MachineInstr &MI)
Set the insertion point to before MI, and set the debug loc to MI's loc.
MachineInstrBuilder buildVecReduceAdd(const DstOp &Dst, const SrcOp &Src)
Build and insert Res = G_VECREDUCE_ADD Src.
virtual MachineInstrBuilder buildConstant(const DstOp &Res, const ConstantInt &Val)
Build and insert Res = G_CONSTANT Val.
Representation of each machine instruction.
Definition: MachineInstr.h:69
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
Definition: MachineInstr.h:546
const MachineBasicBlock * getParent() const
Definition: MachineInstr.h:329
const MachineOperand & getOperand(unsigned i) const
Definition: MachineInstr.h:556
MachineOperand class - Representation of each machine instruction operand.
int64_t getImm() const
MachineBasicBlock * getMBB() const
MachineInstr * getParent()
getParent - Return the instruction that this operand belongs to.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at application launch and destroyed by llvm_shutdown.
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
size_t size() const
Definition: SmallVector.h:91
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition: StringRef.h:50
CodeGenOptLevel getOptLevel() const
Returns the optimization level: None, Less, Default, or Aggressive.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
bool tryEmitBZero(MachineInstr &MI, MachineIRBuilder &MIRBuilder, bool MinSize)
Replace a G_MEMSET with a value of 0 with a G_BZERO instruction if it is supported and beneficial to do so.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
operand_type_match m_Reg()
SpecificConstantMatch m_SpecificICst(int64_t RequestedValue)
Matches a constant equal to RequestedValue.
UnaryOp_match< SrcTy, TargetOpcode::G_ZEXT > m_GZExt(const SrcTy &Src)
bool mi_match(Reg R, const MachineRegisterInfo &MRI, Pattern &&P)
UnaryOp_match< SrcTy, TargetOpcode::G_TRUNC > m_GTrunc(const SrcTy &Src)
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
FunctionPass * createAArch64PreLegalizerCombiner()
@ Offset
Definition: DWP.cpp:456
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1731
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
Definition: STLExtras.h:665
MachineInstr * getDefIgnoringCopies(Register Reg, const MachineRegisterInfo &MRI)
Find the def instruction for Reg, folding away any trivial copies.
Definition: Utils.cpp:465
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1738
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:156
void initializeAArch64PreLegalizerCombinerPass(PassRegistry &)
void extractParts(Register Reg, LLT Ty, int NumParts, SmallVectorImpl< Register > &VRegs, MachineIRBuilder &MIRBuilder, MachineRegisterInfo &MRI)
Helper function to split a wide generic register into bitwise blocks with the given Type (which implies the number of blocks needed).
Definition: Utils.cpp:479
void getSelectionDAGFallbackAnalysisUsage(AnalysisUsage &AU)
Modify analysis usage so it preserves passes required for the SelectionDAG fallback.
Definition: Utils.cpp:1072
std::optional< ValueAndVReg > getIConstantVRegValWithLookThrough(Register VReg, const MachineRegisterInfo &MRI, bool LookThroughInstrs=true)
If VReg is defined by a statically evaluable chain of instructions rooted on a G_CONSTANT, returns its APInt value and def register.
Definition: Utils.cpp:413
auto instrs(const MachineBasicBlock &BB)