//===-- X86FixupVectorConstants.cpp - optimize constant generation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file examines all full size vector constant pool loads and attempts to
// replace them with smaller constant pool entries, including:
// * Converting AVX512 memory-fold instructions to their broadcast-fold form.
// * Using vzload scalar loads.
// * Broadcasting of full width loads.
// * Sign/Zero extension of full width loads.
//
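// For example, a full-width 256-bit load of a splatted dword constant:
//
//   vmovaps .LCPI0_0(%rip), %ymm0      # 32-byte constant pool entry
//
// can typically be replaced by a 4-byte entry and a broadcast:
//
//   vbroadcastss .LCPI0_0(%rip), %ymm0
//
// (Illustrative assembly; the replacement chosen depends on the subtarget
// features encoded in the fixup tables below.)
//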
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrFoldTables.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineConstantPool.h"

using namespace llvm;

#define DEBUG_TYPE "x86-fixup-vector-constants"

STATISTIC(NumInstChanges, "Number of instruction changes");

namespace {
class X86FixupVectorConstantsPass : public MachineFunctionPass {
public:
  static char ID;

  X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {}

  StringRef getPassName() const override {
    return "X86 Fixup Vector Constants";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB,
                          MachineInstr &MI);

  // This pass runs after regalloc and doesn't support VReg operands.
  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::NoVRegs);
  }

private:
  const X86InstrInfo *TII = nullptr;
  const X86Subtarget *ST = nullptr;
  const MCSchedModel *SM = nullptr;
};
} // end anonymous namespace

char X86FixupVectorConstantsPass::ID = 0;

INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false, false)

FunctionPass *llvm::createX86FixupVectorConstants() {
  return new X86FixupVectorConstantsPass();
}

// Attempt to extract the full width of bits data from the constant.
static std::optional<APInt> extractConstantBits(const Constant *C) {
  unsigned NumBits = C->getType()->getPrimitiveSizeInBits();

  if (isa<UndefValue>(C))
    return APInt::getZero(NumBits);

  if (auto *CInt = dyn_cast<ConstantInt>(C))
    return CInt->getValue();

  if (auto *CFP = dyn_cast<ConstantFP>(C))
    return CFP->getValue().bitcastToAPInt();

  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
      if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
        assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
        return APInt::getSplat(NumBits, *Bits);
      }
    }

    APInt Bits = APInt::getZero(NumBits);
    for (unsigned I = 0, E = CV->getNumOperands(); I != E; ++I) {
      Constant *Elt = CV->getOperand(I);
      std::optional<APInt> SubBits = extractConstantBits(Elt);
      if (!SubBits)
        return std::nullopt;
      assert(NumBits == (E * SubBits->getBitWidth()) &&
             "Illegal vector element size");
      Bits.insertBits(*SubBits, I * SubBits->getBitWidth());
    }
    return Bits;
  }

  if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
    bool IsInteger = CDS->getElementType()->isIntegerTy();
    bool IsFloat = CDS->getElementType()->isHalfTy() ||
                   CDS->getElementType()->isBFloatTy() ||
                   CDS->getElementType()->isFloatTy() ||
                   CDS->getElementType()->isDoubleTy();
    if (IsInteger || IsFloat) {
      APInt Bits = APInt::getZero(NumBits);
      unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits();
      for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
        if (IsInteger)
          Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
        else
          Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
                          I * EltBits);
      }
      return Bits;
    }
  }

  return std::nullopt;
}
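
// A small worked example: for <4 x i8> <i8 1, i8 2, i8 3, i8 4> this returns
// the 32-bit APInt 0x04030201 (element 0 occupies the least significant bits),
// and for <2 x i32> <i32 7, i32 7> it returns the 64-bit APInt
// 0x0000000700000007.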

static std::optional<APInt> extractConstantBits(const Constant *C,
                                                unsigned NumBits) {
  if (std::optional<APInt> Bits = extractConstantBits(C))
    return Bits->zextOrTrunc(NumBits);
  return std::nullopt;
}

// Attempt to compute the splat width of bits data by normalizing the splat to
// remove undefs.
static std::optional<APInt> getSplatableConstant(const Constant *C,
                                                 unsigned SplatBitWidth) {
  const Type *Ty = C->getType();
  assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
         "Illegal splat width");

  if (std::optional<APInt> Bits = extractConstantBits(C))
    if (Bits->isSplat(SplatBitWidth))
      return Bits->trunc(SplatBitWidth);

  // Detect general splats with undefs.
  // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting?
  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    unsigned NumOps = CV->getNumOperands();
    unsigned NumEltsBits = Ty->getScalarSizeInBits();
    unsigned NumScaleOps = SplatBitWidth / NumEltsBits;
    if ((SplatBitWidth % NumEltsBits) == 0) {
      // Collect the elements and ensure that within the repeated splat
      // sequence they either match or are undef.
      SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr);
      for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
        if (Constant *Elt = CV->getAggregateElement(Idx)) {
          if (isa<UndefValue>(Elt))
            continue;
          unsigned SplatIdx = Idx % NumScaleOps;
          if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) {
            Sequence[SplatIdx] = Elt;
            continue;
          }
        }
        return std::nullopt;
      }
      // Extract the constant bits forming the splat and insert into the bits
      // data, leave undef as zero.
      APInt SplatBits = APInt::getZero(SplatBitWidth);
      for (unsigned I = 0; I != NumScaleOps; ++I) {
        if (!Sequence[I])
          continue;
        if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) {
          SplatBits.insertBits(*Bits, I * Bits->getBitWidth());
          continue;
        }
        return std::nullopt;
      }
      return SplatBits;
    }
  }

  return std::nullopt;
}
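
// For example, with SplatBitWidth = 32 the constant
// <4 x i32> <i32 5, i32 undef, i32 5, i32 undef> normalizes to the 32-bit
// value 5 (undefs are ignored when matching and become zero when rebuilt),
// while <4 x i32> <i32 5, i32 6, i32 5, i32 6> only splats at
// SplatBitWidth = 64, yielding 0x0000000600000005.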

// Split raw bits into a constant vector of elements of a specific bit width.
// NOTE: We don't always bother converting to scalars if the vector length is 1.
static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
                                 const APInt &Bits, unsigned NumSclBits) {
  unsigned BitWidth = Bits.getBitWidth();

  if (NumSclBits == 8) {
    SmallVector<uint8_t> RawBits;
    for (unsigned I = 0; I != BitWidth; I += 8)
      RawBits.push_back(Bits.extractBits(8, I).getZExtValue());
    return ConstantDataVector::get(Ctx, RawBits);
  }

  if (NumSclBits == 16) {
    SmallVector<uint16_t> RawBits;
    for (unsigned I = 0; I != BitWidth; I += 16)
      RawBits.push_back(Bits.extractBits(16, I).getZExtValue());
    if (SclTy->is16bitFPTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(Ctx, RawBits);
  }

  if (NumSclBits == 32) {
    SmallVector<uint32_t> RawBits;
    for (unsigned I = 0; I != BitWidth; I += 32)
      RawBits.push_back(Bits.extractBits(32, I).getZExtValue());
    if (SclTy->isFloatTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(Ctx, RawBits);
  }

  assert(NumSclBits == 64 && "Unhandled vector element width");

  SmallVector<uint64_t> RawBits;
  for (unsigned I = 0; I != BitWidth; I += 64)
    RawBits.push_back(Bits.extractBits(64, I).getZExtValue());
  if (SclTy->isDoubleTy())
    return ConstantDataVector::getFP(SclTy, RawBits);
  return ConstantDataVector::get(Ctx, RawBits);
}
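
// For example, rebuildConstant(Ctx, Int32Ty, APInt(64, 0x0000000700000007),
// 32) produces the constant data vector <2 x i32> <i32 7, i32 7>; with a
// float scalar type it would instead emit the equivalent <2 x float> bit
// pattern via ConstantDataVector::getFP.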

// Attempt to rebuild a normalized splat vector constant of the requested splat
// width, built up of potentially smaller scalar values.
static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
                                 unsigned /*NumElts*/, unsigned SplatBitWidth) {
  // TODO: Truncate to NumBits once ConvertToBroadcastAVX512 supports this.
  std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
  if (!Splat)
    return nullptr;

  // Determine scalar size to use for the constant splat vector, clamping as we
  // might have found a splat smaller than the original constant data.
  Type *SclTy = C->getType()->getScalarType();
  unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
  NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);

  // Fallback to i64 / double.
  NumSclBits = (NumSclBits == 8 || NumSclBits == 16 || NumSclBits == 32)
                   ? NumSclBits
                   : 64;

  // Extract per-element bits.
  return rebuildConstant(C->getContext(), SclTy, *Splat, NumSclBits);
}
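
// For example, rebuilding <8 x i32> <1, 1, 1, 1, 1, 1, 1, 1> with
// SplatBitWidth = 32 yields the single-element vector <1 x i32> <i32 1>,
// small enough to be loaded with a 4-byte broadcast.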
248
249static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
250 unsigned /*NumElts*/,
251 unsigned ScalarBitWidth) {
252 Type *SclTy = C->getType()->getScalarType();
253 unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
254 LLVMContext &Ctx = C->getContext();
255
256 if (NumBits > ScalarBitWidth) {
257 // Determine if the upper bits are all zero.
258 if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
259 if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
260 // If the original constant was made of smaller elements, try to retain
261 // those types.
262 if (ScalarBitWidth > NumSclBits && (ScalarBitWidth % NumSclBits) == 0)
263 return rebuildConstant(Ctx, SclTy, *Bits, NumSclBits);
264
265 // Fallback to raw integer bits.
266 APInt RawBits = Bits->zextOrTrunc(ScalarBitWidth);
267 return ConstantInt::get(Ctx, RawBits);
268 }
269 }
270 }
271
272 return nullptr;
273}
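
// For example, a 128-bit constant with the value 42 in its low 64 bits and
// zeros above can be rebuilt as the scalar i64 42 and reloaded with a vzload
// (e.g. MOVQ), which implicitly zeros the upper lanes of the register.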
274
275static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
276 unsigned NumBits, unsigned NumElts,
277 unsigned SrcEltBitWidth) {
278 unsigned DstEltBitWidth = NumBits / NumElts;
279 assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
280 (DstEltBitWidth % SrcEltBitWidth) == 0 &&
281 (DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
282
283 if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
284 assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
285 (Bits->getBitWidth() % DstEltBitWidth) == 0 &&
286 "Unexpected constant extension");
287
288 // Ensure every vector element can be represented by the src bitwidth.
289 APInt TruncBits = APInt::getZero(NumElts * SrcEltBitWidth);
290 for (unsigned I = 0; I != NumElts; ++I) {
291 APInt Elt = Bits->extractBits(DstEltBitWidth, I * DstEltBitWidth);
292 if ((IsSExt && Elt.getSignificantBits() > SrcEltBitWidth) ||
293 (!IsSExt && Elt.getActiveBits() > SrcEltBitWidth))
294 return nullptr;
295 TruncBits.insertBits(Elt.trunc(SrcEltBitWidth), I * SrcEltBitWidth);
296 }
297
298 Type *Ty = C->getType();
299 return rebuildConstant(Ty->getContext(), Ty->getScalarType(), TruncBits,
300 SrcEltBitWidth);
301 }
302
303 return nullptr;
304}
305static Constant *rebuildSExtCst(const Constant *C, unsigned NumBits,
306 unsigned NumElts, unsigned SrcEltBitWidth) {
307 return rebuildExtCst(C, true, NumBits, NumElts, SrcEltBitWidth);
308}
309static Constant *rebuildZExtCst(const Constant *C, unsigned NumBits,
310 unsigned NumElts, unsigned SrcEltBitWidth) {
311 return rebuildExtCst(C, false, NumBits, NumElts, SrcEltBitWidth);
312}
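
// For example, rebuildSExtCst can shrink the 128-bit constant
// <2 x i64> <i64 1, i64 -1> with SrcEltBitWidth = 32 to <2 x i32>
// <i32 1, i32 -1>, reloaded via a sign-extending load such as PMOVSXDQ;
// rebuildZExtCst rejects it because -1 does not zero-extend from 32 bits.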

bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
                                                     MachineBasicBlock &MBB,
                                                     MachineInstr &MI) {
  unsigned Opc = MI.getOpcode();
  MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
  bool HasSSE41 = ST->hasSSE41();
  bool HasAVX2 = ST->hasAVX2();
  bool HasDQI = ST->hasDQI();
  bool HasBWI = ST->hasBWI();
  bool HasVLX = ST->hasVLX();

  struct FixupEntry {
    int Op;
    int NumCstElts;
    int MemBitWidth;
    std::function<Constant *(const Constant *, unsigned, unsigned, unsigned)>
        RebuildConstant;
  };
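
  // For example, the entry {X86::VBROADCASTSSrm, 1, 32, rebuildSplatCst} used
  // below means: if the constant can be rebuilt as a single 32-bit splat
  // element, switch the instruction to VBROADCASTSSrm and point it at the new
  // 4-byte pool entry (NumCstElts * MemBitWidth bits of constant pool data).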
  auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned RegBitWidth,
                           unsigned OperandNo) {
#ifdef EXPENSIVE_CHECKS
    assert(llvm::is_sorted(Fixups,
                           [](const FixupEntry &A, const FixupEntry &B) {
                             return (A.NumCstElts * A.MemBitWidth) <
                                    (B.NumCstElts * B.MemBitWidth);
                           }) &&
           "Constant fixup table not sorted in ascending constant size");
#endif
    assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
           "Unexpected number of operands!");
    if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
      RegBitWidth =
          RegBitWidth ? RegBitWidth : C->getType()->getPrimitiveSizeInBits();
      for (const FixupEntry &Fixup : Fixups) {
        if (Fixup.Op) {
          // Construct a suitable constant and adjust the MI to use the new
          // constant pool entry.
          if (Constant *NewCst = Fixup.RebuildConstant(
                  C, RegBitWidth, Fixup.NumCstElts, Fixup.MemBitWidth)) {
            unsigned NewCPI =
                CP->getConstantPoolIndex(NewCst, Align(Fixup.MemBitWidth / 8));
            MI.setDesc(TII->get(Fixup.Op));
            MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
            return true;
          }
        }
      }
    }
    return false;
  };

  // Attempt to detect a suitable vzload/broadcast/vextload from increasing
  // constant bitwidths. Prefer vzload/broadcast/vextload for same bitwidth:
  // - vzload shouldn't ever need a shuffle port to zero the upper elements and
  //   the fp/int domain versions are equally available so we don't introduce a
  //   domain crossing penalty.
  // - broadcast sometimes needs a shuffle port (especially for 8/16-bit
  //   variants), AVX1 only has fp domain broadcasts but AVX2+ have good fp/int
  //   domain equivalents.
  // - vextload always needs a shuffle port and is only ever int domain.
  switch (Opc) {
  /* FP Loads */
  case X86::MOVAPDrm:
  case X86::MOVAPSrm:
  case X86::MOVUPDrm:
  case X86::MOVUPSrm:
    // TODO: SSE3 MOVDDUP Handling
    return FixupConstant({{X86::MOVSSrm, 1, 32, rebuildZeroUpperCst},
                          {X86::MOVSDrm, 1, 64, rebuildZeroUpperCst}},
                         128, 1);
  case X86::VMOVAPDrm:
  case X86::VMOVAPSrm:
  case X86::VMOVUPDrm:
  case X86::VMOVUPSrm:
    return FixupConstant({{X86::VMOVSSrm, 1, 32, rebuildZeroUpperCst},
                          {X86::VBROADCASTSSrm, 1, 32, rebuildSplatCst},
                          {X86::VMOVSDrm, 1, 64, rebuildZeroUpperCst},
                          {X86::VMOVDDUPrm, 1, 64, rebuildSplatCst}},
                         128, 1);
  case X86::VMOVAPDYrm:
  case X86::VMOVAPSYrm:
  case X86::VMOVUPDYrm:
  case X86::VMOVUPSYrm:
    return FixupConstant({{X86::VBROADCASTSSYrm, 1, 32, rebuildSplatCst},
                          {X86::VBROADCASTSDYrm, 1, 64, rebuildSplatCst},
                          {X86::VBROADCASTF128rm, 1, 128, rebuildSplatCst}},
                         256, 1);
  case X86::VMOVAPDZ128rm:
  case X86::VMOVAPSZ128rm:
  case X86::VMOVUPDZ128rm:
  case X86::VMOVUPSZ128rm:
    return FixupConstant({{X86::VMOVSSZrm, 1, 32, rebuildZeroUpperCst},
                          {X86::VBROADCASTSSZ128rm, 1, 32, rebuildSplatCst},
                          {X86::VMOVSDZrm, 1, 64, rebuildZeroUpperCst},
                          {X86::VMOVDDUPZ128rm, 1, 64, rebuildSplatCst}},
                         128, 1);
  case X86::VMOVAPDZ256rm:
  case X86::VMOVAPSZ256rm:
  case X86::VMOVUPDZ256rm:
  case X86::VMOVUPSZ256rm:
    return FixupConstant(
        {{X86::VBROADCASTSSZ256rm, 1, 32, rebuildSplatCst},
         {X86::VBROADCASTSDZ256rm, 1, 64, rebuildSplatCst},
         {X86::VBROADCASTF32X4Z256rm, 1, 128, rebuildSplatCst}},
        256, 1);
  case X86::VMOVAPDZrm:
  case X86::VMOVAPSZrm:
  case X86::VMOVUPDZrm:
  case X86::VMOVUPSZrm:
    return FixupConstant({{X86::VBROADCASTSSZrm, 1, 32, rebuildSplatCst},
                          {X86::VBROADCASTSDZrm, 1, 64, rebuildSplatCst},
                          {X86::VBROADCASTF32X4rm, 1, 128, rebuildSplatCst},
                          {X86::VBROADCASTF64X4rm, 1, 256, rebuildSplatCst}},
                         512, 1);
  /* Integer Loads */
  case X86::MOVDQArm:
  case X86::MOVDQUrm: {
    FixupEntry Fixups[] = {
        {HasSSE41 ? X86::PMOVSXBQrm : 0, 2, 8, rebuildSExtCst},
        {HasSSE41 ? X86::PMOVZXBQrm : 0, 2, 8, rebuildZExtCst},
        {X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
        {HasSSE41 ? X86::PMOVSXBDrm : 0, 4, 8, rebuildSExtCst},
        {HasSSE41 ? X86::PMOVZXBDrm : 0, 4, 8, rebuildZExtCst},
        {HasSSE41 ? X86::PMOVSXWQrm : 0, 2, 16, rebuildSExtCst},
        {HasSSE41 ? X86::PMOVZXWQrm : 0, 2, 16, rebuildZExtCst},
        {X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
        {HasSSE41 ? X86::PMOVSXBWrm : 0, 8, 8, rebuildSExtCst},
        {HasSSE41 ? X86::PMOVZXBWrm : 0, 8, 8, rebuildZExtCst},
        {HasSSE41 ? X86::PMOVSXWDrm : 0, 4, 16, rebuildSExtCst},
        {HasSSE41 ? X86::PMOVZXWDrm : 0, 4, 16, rebuildZExtCst},
        {HasSSE41 ? X86::PMOVSXDQrm : 0, 2, 32, rebuildSExtCst},
        {HasSSE41 ? X86::PMOVZXDQrm : 0, 2, 32, rebuildZExtCst}};
    return FixupConstant(Fixups, 128, 1);
  }
  case X86::VMOVDQArm:
  case X86::VMOVDQUrm: {
    FixupEntry Fixups[] = {
        {HasAVX2 ? X86::VPBROADCASTBrm : 0, 1, 8, rebuildSplatCst},
        {HasAVX2 ? X86::VPBROADCASTWrm : 0, 1, 16, rebuildSplatCst},
        {X86::VPMOVSXBQrm, 2, 8, rebuildSExtCst},
        {X86::VPMOVZXBQrm, 2, 8, rebuildZExtCst},
        {X86::VMOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
        {HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1, 32,
         rebuildSplatCst},
        {X86::VPMOVSXBDrm, 4, 8, rebuildSExtCst},
        {X86::VPMOVZXBDrm, 4, 8, rebuildZExtCst},
        {X86::VPMOVSXWQrm, 2, 16, rebuildSExtCst},
        {X86::VPMOVZXWQrm, 2, 16, rebuildZExtCst},
        {X86::VMOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
        {HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1, 64,
         rebuildSplatCst},
        {X86::VPMOVSXBWrm, 8, 8, rebuildSExtCst},
        {X86::VPMOVZXBWrm, 8, 8, rebuildZExtCst},
        {X86::VPMOVSXWDrm, 4, 16, rebuildSExtCst},
        {X86::VPMOVZXWDrm, 4, 16, rebuildZExtCst},
        {X86::VPMOVSXDQrm, 2, 32, rebuildSExtCst},
        {X86::VPMOVZXDQrm, 2, 32, rebuildZExtCst}};
    return FixupConstant(Fixups, 128, 1);
  }
  case X86::VMOVDQAYrm:
  case X86::VMOVDQUYrm: {
    FixupEntry Fixups[] = {
        {HasAVX2 ? X86::VPBROADCASTBYrm : 0, 1, 8, rebuildSplatCst},
        {HasAVX2 ? X86::VPBROADCASTWYrm : 0, 1, 16, rebuildSplatCst},
        {HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1, 32,
         rebuildSplatCst},
        {HasAVX2 ? X86::VPMOVSXBQYrm : 0, 4, 8, rebuildSExtCst},
        {HasAVX2 ? X86::VPMOVZXBQYrm : 0, 4, 8, rebuildZExtCst},
        {HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1, 64,
         rebuildSplatCst},
        {HasAVX2 ? X86::VPMOVSXBDYrm : 0, 8, 8, rebuildSExtCst},
        {HasAVX2 ? X86::VPMOVZXBDYrm : 0, 8, 8, rebuildZExtCst},
        {HasAVX2 ? X86::VPMOVSXWQYrm : 0, 4, 16, rebuildSExtCst},
        {HasAVX2 ? X86::VPMOVZXWQYrm : 0, 4, 16, rebuildZExtCst},
        {HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1, 128,
         rebuildSplatCst},
        {HasAVX2 ? X86::VPMOVSXBWYrm : 0, 16, 8, rebuildSExtCst},
        {HasAVX2 ? X86::VPMOVZXBWYrm : 0, 16, 8, rebuildZExtCst},
        {HasAVX2 ? X86::VPMOVSXWDYrm : 0, 8, 16, rebuildSExtCst},
        {HasAVX2 ? X86::VPMOVZXWDYrm : 0, 8, 16, rebuildZExtCst},
        {HasAVX2 ? X86::VPMOVSXDQYrm : 0, 4, 32, rebuildSExtCst},
        {HasAVX2 ? X86::VPMOVZXDQYrm : 0, 4, 32, rebuildZExtCst}};
    return FixupConstant(Fixups, 256, 1);
  }
  case X86::VMOVDQA32Z128rm:
  case X86::VMOVDQA64Z128rm:
  case X86::VMOVDQU32Z128rm:
  case X86::VMOVDQU64Z128rm: {
    FixupEntry Fixups[] = {
        {HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1, 8, rebuildSplatCst},
        {HasBWI ? X86::VPBROADCASTWZ128rm : 0, 1, 16, rebuildSplatCst},
        {X86::VPMOVSXBQZ128rm, 2, 8, rebuildSExtCst},
        {X86::VPMOVZXBQZ128rm, 2, 8, rebuildZExtCst},
        {X86::VMOVDI2PDIZrm, 1, 32, rebuildZeroUpperCst},
        {X86::VPBROADCASTDZ128rm, 1, 32, rebuildSplatCst},
        {X86::VPMOVSXBDZ128rm, 4, 8, rebuildSExtCst},
        {X86::VPMOVZXBDZ128rm, 4, 8, rebuildZExtCst},
        {X86::VPMOVSXWQZ128rm, 2, 16, rebuildSExtCst},
        {X86::VPMOVZXWQZ128rm, 2, 16, rebuildZExtCst},
        {X86::VMOVQI2PQIZrm, 1, 64, rebuildZeroUpperCst},
        {X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst},
        {HasBWI ? X86::VPMOVSXBWZ128rm : 0, 8, 8, rebuildSExtCst},
        {HasBWI ? X86::VPMOVZXBWZ128rm : 0, 8, 8, rebuildZExtCst},
        {X86::VPMOVSXWDZ128rm, 4, 16, rebuildSExtCst},
        {X86::VPMOVZXWDZ128rm, 4, 16, rebuildZExtCst},
        {X86::VPMOVSXDQZ128rm, 2, 32, rebuildSExtCst},
        {X86::VPMOVZXDQZ128rm, 2, 32, rebuildZExtCst}};
    return FixupConstant(Fixups, 128, 1);
  }
  case X86::VMOVDQA32Z256rm:
  case X86::VMOVDQA64Z256rm:
  case X86::VMOVDQU32Z256rm:
  case X86::VMOVDQU64Z256rm: {
    FixupEntry Fixups[] = {
        {HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1, 8, rebuildSplatCst},
        {HasBWI ? X86::VPBROADCASTWZ256rm : 0, 1, 16, rebuildSplatCst},
        {X86::VPBROADCASTDZ256rm, 1, 32, rebuildSplatCst},
        {X86::VPMOVSXBQZ256rm, 4, 8, rebuildSExtCst},
        {X86::VPMOVZXBQZ256rm, 4, 8, rebuildZExtCst},
        {X86::VPBROADCASTQZ256rm, 1, 64, rebuildSplatCst},
        {X86::VPMOVSXBDZ256rm, 8, 8, rebuildSExtCst},
        {X86::VPMOVZXBDZ256rm, 8, 8, rebuildZExtCst},
        {X86::VPMOVSXWQZ256rm, 4, 16, rebuildSExtCst},
        {X86::VPMOVZXWQZ256rm, 4, 16, rebuildZExtCst},
        {X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst},
        {HasBWI ? X86::VPMOVSXBWZ256rm : 0, 16, 8, rebuildSExtCst},
        {HasBWI ? X86::VPMOVZXBWZ256rm : 0, 16, 8, rebuildZExtCst},
        {X86::VPMOVSXWDZ256rm, 8, 16, rebuildSExtCst},
        {X86::VPMOVZXWDZ256rm, 8, 16, rebuildZExtCst},
        {X86::VPMOVSXDQZ256rm, 4, 32, rebuildSExtCst},
        {X86::VPMOVZXDQZ256rm, 4, 32, rebuildZExtCst}};
    return FixupConstant(Fixups, 256, 1);
  }
  case X86::VMOVDQA32Zrm:
  case X86::VMOVDQA64Zrm:
  case X86::VMOVDQU32Zrm:
  case X86::VMOVDQU64Zrm: {
    FixupEntry Fixups[] = {
        {HasBWI ? X86::VPBROADCASTBZrm : 0, 1, 8, rebuildSplatCst},
        {HasBWI ? X86::VPBROADCASTWZrm : 0, 1, 16, rebuildSplatCst},
        {X86::VPBROADCASTDZrm, 1, 32, rebuildSplatCst},
        {X86::VPBROADCASTQZrm, 1, 64, rebuildSplatCst},
        {X86::VPMOVSXBQZrm, 8, 8, rebuildSExtCst},
        {X86::VPMOVZXBQZrm, 8, 8, rebuildZExtCst},
        {X86::VBROADCASTI32X4rm, 1, 128, rebuildSplatCst},
        {X86::VPMOVSXBDZrm, 16, 8, rebuildSExtCst},
        {X86::VPMOVZXBDZrm, 16, 8, rebuildZExtCst},
        {X86::VPMOVSXWQZrm, 8, 16, rebuildSExtCst},
        {X86::VPMOVZXWQZrm, 8, 16, rebuildZExtCst},
        {X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst},
        {HasBWI ? X86::VPMOVSXBWZrm : 0, 32, 8, rebuildSExtCst},
        {HasBWI ? X86::VPMOVZXBWZrm : 0, 32, 8, rebuildZExtCst},
        {X86::VPMOVSXWDZrm, 16, 16, rebuildSExtCst},
        {X86::VPMOVZXWDZrm, 16, 16, rebuildZExtCst},
        {X86::VPMOVSXDQZrm, 8, 32, rebuildSExtCst},
        {X86::VPMOVZXDQZrm, 8, 32, rebuildZExtCst}};
    return FixupConstant(Fixups, 512, 1);
  }
  }

  auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
    unsigned OpBcst32 = 0, OpBcst64 = 0;
    unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
    if (OpSrc32) {
      if (const X86FoldTableEntry *Mem2Bcst =
              lookupBroadcastFoldTableBySize(OpSrc32, 32)) {
        OpBcst32 = Mem2Bcst->DstOp;
        OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    if (OpSrc64) {
      if (const X86FoldTableEntry *Mem2Bcst =
              lookupBroadcastFoldTableBySize(OpSrc64, 64)) {
        OpBcst64 = Mem2Bcst->DstOp;
        OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    assert(((OpBcst32 == 0) || (OpBcst64 == 0) || (OpNoBcst32 == OpNoBcst64)) &&
           "OperandNo mismatch");

    if (OpBcst32 || OpBcst64) {
      unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
      FixupEntry Fixups[] = {{(int)OpBcst32, 32, 32, rebuildSplatCst},
                             {(int)OpBcst64, 64, 64, rebuildSplatCst}};
      // TODO: Add support for RegBitWidth, but currently rebuildSplatCst
      // doesn't require it (defaults to Constant::getPrimitiveSizeInBits).
      return FixupConstant(Fixups, 0, OpNo);
    }
    return false;
  };
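
  // For example, an EVEX arithmetic instruction such as VPADDDZrm whose folded
  // constant is a dword splat can typically be switched to its broadcast-fold
  // form (VPADDDZrmb) via the broadcast fold tables, shrinking the pool entry
  // from 64 bytes to 4. (Illustrative opcode pair; the actual mapping comes
  // from lookupBroadcastFoldTableBySize.)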

  // Attempt to find an AVX512 mapping from a full width memory-fold
  // instruction to a broadcast-fold instruction variant.
  if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX)
    return ConvertToBroadcastAVX512(Opc, Opc);

  // Reverse the X86InstrInfo::setExecutionDomainCustom EVEX->VEX logic
  // conversion to see if we can convert to a broadcasted (integer) logic op.
  if (HasVLX && !HasDQI) {
    unsigned OpSrc32 = 0, OpSrc64 = 0;
    switch (Opc) {
    case X86::VANDPDrm:
    case X86::VANDPSrm:
    case X86::VPANDrm:
      OpSrc32 = X86::VPANDDZ128rm;
      OpSrc64 = X86::VPANDQZ128rm;
      break;
    case X86::VANDPDYrm:
    case X86::VANDPSYrm:
    case X86::VPANDYrm:
      OpSrc32 = X86::VPANDDZ256rm;
      OpSrc64 = X86::VPANDQZ256rm;
      break;
    case X86::VANDNPDrm:
    case X86::VANDNPSrm:
    case X86::VPANDNrm:
      OpSrc32 = X86::VPANDNDZ128rm;
      OpSrc64 = X86::VPANDNQZ128rm;
      break;
    case X86::VANDNPDYrm:
    case X86::VANDNPSYrm:
    case X86::VPANDNYrm:
      OpSrc32 = X86::VPANDNDZ256rm;
      OpSrc64 = X86::VPANDNQZ256rm;
      break;
    case X86::VORPDrm:
    case X86::VORPSrm:
    case X86::VPORrm:
      OpSrc32 = X86::VPORDZ128rm;
      OpSrc64 = X86::VPORQZ128rm;
      break;
    case X86::VORPDYrm:
    case X86::VORPSYrm:
    case X86::VPORYrm:
      OpSrc32 = X86::VPORDZ256rm;
      OpSrc64 = X86::VPORQZ256rm;
      break;
    case X86::VXORPDrm:
    case X86::VXORPSrm:
    case X86::VPXORrm:
      OpSrc32 = X86::VPXORDZ128rm;
      OpSrc64 = X86::VPXORQZ128rm;
      break;
    case X86::VXORPDYrm:
    case X86::VXORPSYrm:
    case X86::VPXORYrm:
      OpSrc32 = X86::VPXORDZ256rm;
      OpSrc64 = X86::VPXORQZ256rm;
      break;
    }
    if (OpSrc32 || OpSrc64)
      return ConvertToBroadcastAVX512(OpSrc32, OpSrc64);
  }

  return false;
}

bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) {
  LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";);
  bool Changed = false;
  ST = &MF.getSubtarget<X86Subtarget>();
  TII = ST->getInstrInfo();
  SM = &ST->getSchedModel();

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (processInstruction(MF, MBB, MI)) {
        ++NumInstChanges;
        Changed = true;
      }
    }
  }
  LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";);
  return Changed;
}