LLVM  13.0.0git
AArch64PBQPRegAlloc.cpp
Go to the documentation of this file.
1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file contains the AArch64 / Cortex-A57 specific register allocation
9 // constraints for use by the PBQP register allocator.
10 //
11 // It is essentially a transcription of what is contained in
12 // AArch64A57FPLoadBalancing, which tries to use a balanced
13 // mix of odd and even D-registers when performing a critical sequence of
14 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15 //===----------------------------------------------------------------------===//
16 
17 #define DEBUG_TYPE "aarch64-pbqp"
18 
#include "AArch64PBQPRegAlloc.h"
#include "AArch64.h"
#include "AArch64RegisterInfo.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegAllocPBQP.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
30 
31 using namespace llvm;
32 
33 namespace {
34 
35 #ifndef NDEBUG
36 bool isFPReg(unsigned reg) {
37  return AArch64::FPR32RegClass.contains(reg) ||
38  AArch64::FPR64RegClass.contains(reg) ||
39  AArch64::FPR128RegClass.contains(reg);
40 }
41 #endif
42 
43 bool isOdd(unsigned reg) {
44  switch (reg) {
45  default:
46  llvm_unreachable("Register is not from the expected class !");
47  case AArch64::S1:
48  case AArch64::S3:
49  case AArch64::S5:
50  case AArch64::S7:
51  case AArch64::S9:
52  case AArch64::S11:
53  case AArch64::S13:
54  case AArch64::S15:
55  case AArch64::S17:
56  case AArch64::S19:
57  case AArch64::S21:
58  case AArch64::S23:
59  case AArch64::S25:
60  case AArch64::S27:
61  case AArch64::S29:
62  case AArch64::S31:
63  case AArch64::D1:
64  case AArch64::D3:
65  case AArch64::D5:
66  case AArch64::D7:
67  case AArch64::D9:
68  case AArch64::D11:
69  case AArch64::D13:
70  case AArch64::D15:
71  case AArch64::D17:
72  case AArch64::D19:
73  case AArch64::D21:
74  case AArch64::D23:
75  case AArch64::D25:
76  case AArch64::D27:
77  case AArch64::D29:
78  case AArch64::D31:
79  case AArch64::Q1:
80  case AArch64::Q3:
81  case AArch64::Q5:
82  case AArch64::Q7:
83  case AArch64::Q9:
84  case AArch64::Q11:
85  case AArch64::Q13:
86  case AArch64::Q15:
87  case AArch64::Q17:
88  case AArch64::Q19:
89  case AArch64::Q21:
90  case AArch64::Q23:
91  case AArch64::Q25:
92  case AArch64::Q27:
93  case AArch64::Q29:
94  case AArch64::Q31:
95  return true;
96  case AArch64::S0:
97  case AArch64::S2:
98  case AArch64::S4:
99  case AArch64::S6:
100  case AArch64::S8:
101  case AArch64::S10:
102  case AArch64::S12:
103  case AArch64::S14:
104  case AArch64::S16:
105  case AArch64::S18:
106  case AArch64::S20:
107  case AArch64::S22:
108  case AArch64::S24:
109  case AArch64::S26:
110  case AArch64::S28:
111  case AArch64::S30:
112  case AArch64::D0:
113  case AArch64::D2:
114  case AArch64::D4:
115  case AArch64::D6:
116  case AArch64::D8:
117  case AArch64::D10:
118  case AArch64::D12:
119  case AArch64::D14:
120  case AArch64::D16:
121  case AArch64::D18:
122  case AArch64::D20:
123  case AArch64::D22:
124  case AArch64::D24:
125  case AArch64::D26:
126  case AArch64::D28:
127  case AArch64::D30:
128  case AArch64::Q0:
129  case AArch64::Q2:
130  case AArch64::Q4:
131  case AArch64::Q6:
132  case AArch64::Q8:
133  case AArch64::Q10:
134  case AArch64::Q12:
135  case AArch64::Q14:
136  case AArch64::Q16:
137  case AArch64::Q18:
138  case AArch64::Q20:
139  case AArch64::Q22:
140  case AArch64::Q24:
141  case AArch64::Q26:
142  case AArch64::Q28:
143  case AArch64::Q30:
144  return false;
145 
146  }
147 }
148 
149 bool haveSameParity(unsigned reg1, unsigned reg2) {
150  assert(isFPReg(reg1) && "Expecting an FP register for reg1");
151  assert(isFPReg(reg2) && "Expecting an FP register for reg2");
152 
153  return isOdd(reg1) == isOdd(reg2);
154 }
155 
156 }
157 
158 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
159  unsigned Ra) {
160  if (Rd == Ra)
161  return false;
162 
163  LiveIntervals &LIs = G.getMetadata().LIS;
164 
166  LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
167  << Register::isPhysicalRegister(Rd) << '\n');
168  LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
169  << Register::isPhysicalRegister(Ra) << '\n');
170  return false;
171  }
172 
173  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
174  PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
175 
176  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
177  &G.getNodeMetadata(node1).getAllowedRegs();
178  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
179  &G.getNodeMetadata(node2).getAllowedRegs();
180 
181  PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
182 
183  // The edge does not exist. Create one with the appropriate interference
184  // costs.
185  if (edge == G.invalidEdgeId()) {
186  const LiveInterval &ld = LIs.getInterval(Rd);
187  const LiveInterval &la = LIs.getInterval(Ra);
188  bool livesOverlap = ld.overlaps(la);
189 
190  PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
191  vRaAllowed->size() + 1, 0);
192  for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
193  unsigned pRd = (*vRdAllowed)[i];
194  for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
195  unsigned pRa = (*vRaAllowed)[j];
196  if (livesOverlap && TRI->regsOverlap(pRd, pRa))
197  costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
198  else
199  costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
200  }
201  }
202  G.addEdge(node1, node2, std::move(costs));
203  return true;
204  }
205 
206  if (G.getEdgeNode1Id(edge) == node2) {
207  std::swap(node1, node2);
208  std::swap(vRdAllowed, vRaAllowed);
209  }
210 
211  // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
212  PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
213  for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
214  unsigned pRd = (*vRdAllowed)[i];
215 
216  // Get the maximum cost (excluding unallocatable reg) for same parity
217  // registers
219  for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
220  unsigned pRa = (*vRaAllowed)[j];
221  if (haveSameParity(pRd, pRa))
222  if (costs[i + 1][j + 1] !=
223  std::numeric_limits<PBQP::PBQPNum>::infinity() &&
224  costs[i + 1][j + 1] > sameParityMax)
225  sameParityMax = costs[i + 1][j + 1];
226  }
227 
228  // Ensure all registers with a different parity have a higher cost
229  // than sameParityMax
230  for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
231  unsigned pRa = (*vRaAllowed)[j];
232  if (!haveSameParity(pRd, pRa))
233  if (sameParityMax > costs[i + 1][j + 1])
234  costs[i + 1][j + 1] = sameParityMax + 1.0;
235  }
236  }
237  G.updateEdgeCosts(edge, std::move(costs));
238 
239  return true;
240 }
241 
242 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
243  unsigned Ra) {
244  LiveIntervals &LIs = G.getMetadata().LIS;
245 
246  // Do some Chain management
247  if (Chains.count(Ra)) {
248  if (Rd != Ra) {
249  LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
250  << " to " << printReg(Rd, TRI) << '\n';);
251  Chains.remove(Ra);
252  Chains.insert(Rd);
253  }
254  } else {
255  LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
256  << '\n';);
257  Chains.insert(Rd);
258  }
259 
260  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
261 
262  const LiveInterval &ld = LIs.getInterval(Rd);
263  for (auto r : Chains) {
264  // Skip self
265  if (r == Rd)
266  continue;
267 
268  const LiveInterval &lr = LIs.getInterval(r);
269  if (ld.overlaps(lr)) {
270  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
271  &G.getNodeMetadata(node1).getAllowedRegs();
272 
273  PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
274  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
275  &G.getNodeMetadata(node2).getAllowedRegs();
276 
277  PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
278  assert(edge != G.invalidEdgeId() &&
279  "PBQP error ! The edge should exist !");
280 
281  LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
282 
283  if (G.getEdgeNode1Id(edge) == node2) {
284  std::swap(node1, node2);
285  std::swap(vRdAllowed, vRrAllowed);
286  }
287 
288  // Enforce that cost is higher with all other Chains of the same parity
289  PBQP::Matrix costs(G.getEdgeCosts(edge));
290  for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
291  unsigned pRd = (*vRdAllowed)[i];
292 
293  // Get the maximum cost (excluding unallocatable reg) for all other
294  // parity registers
296  for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
297  unsigned pRa = (*vRrAllowed)[j];
298  if (!haveSameParity(pRd, pRa))
299  if (costs[i + 1][j + 1] !=
300  std::numeric_limits<PBQP::PBQPNum>::infinity() &&
301  costs[i + 1][j + 1] > sameParityMax)
302  sameParityMax = costs[i + 1][j + 1];
303  }
304 
305  // Ensure all registers with same parity have a higher cost
306  // than sameParityMax
307  for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
308  unsigned pRa = (*vRrAllowed)[j];
309  if (haveSameParity(pRd, pRa))
310  if (sameParityMax > costs[i + 1][j + 1])
311  costs[i + 1][j + 1] = sameParityMax + 1.0;
312  }
313  }
314  G.updateEdgeCosts(edge, std::move(costs));
315  }
316  }
317 }
318 
319 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
320  const MachineInstr &MI) {
321  const LiveInterval &LI = LIs.getInterval(reg);
323  return LI.expiredAt(SI);
324 }
325 
327  const MachineFunction &MF = G.getMetadata().MF;
328  LiveIntervals &LIs = G.getMetadata().LIS;
329 
330  TRI = MF.getSubtarget().getRegisterInfo();
331  LLVM_DEBUG(MF.dump());
332 
333  for (const auto &MBB: MF) {
334  Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
335 
336  for (const auto &MI: MBB) {
337 
338  // Forget Chains which have expired
339  for (auto r : Chains) {
341  if(regJustKilledBefore(LIs, r, MI)) {
342  LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
343  MI.print(dbgs()););
344  toDel.push_back(r);
345  }
346 
347  while (!toDel.empty()) {
348  Chains.remove(toDel.back());
349  toDel.pop_back();
350  }
351  }
352 
353  switch (MI.getOpcode()) {
354  case AArch64::FMSUBSrrr:
355  case AArch64::FMADDSrrr:
356  case AArch64::FNMSUBSrrr:
357  case AArch64::FNMADDSrrr:
358  case AArch64::FMSUBDrrr:
359  case AArch64::FMADDDrrr:
360  case AArch64::FNMSUBDrrr:
361  case AArch64::FNMADDDrrr: {
362  Register Rd = MI.getOperand(0).getReg();
363  Register Ra = MI.getOperand(3).getReg();
364 
365  if (addIntraChainConstraint(G, Rd, Ra))
366  addInterChainConstraint(G, Rd, Ra);
367  break;
368  }
369 
370  case AArch64::FMLAv2f32:
371  case AArch64::FMLSv2f32: {
372  Register Rd = MI.getOperand(0).getReg();
373  addInterChainConstraint(G, Rd, Rd);
374  break;
375  }
376 
377  default:
378  break;
379  }
380  }
381  }
382 }
i
i
Definition: README.txt:29
AArch64RegisterInfo.h
MI
IRTranslator LLVM IR MI
Definition: IRTranslator.cpp:100
llvm
Definition: AllocatorList.h:23
AArch64.h
llvm::SmallVector< unsigned, 8 >
ErrorHandling.h
llvm::PBQP::GraphBase::EdgeId
unsigned EdgeId
Definition: Graph.h:29
MachineBasicBlock.h
llvm::TargetSubtargetInfo::getRegisterInfo
virtual const TargetRegisterInfo * getRegisterInfo() const
getRegisterInfo - If register information is available, return it.
Definition: TargetSubtargetInfo.h:124
RegAllocPBQP.h
llvm::PBQP::PBQPNum
float PBQPNum
Definition: Math.h:22
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:122
llvm::LiveIntervals::getInstructionIndex
SlotIndex getInstructionIndex(const MachineInstr &Instr) const
Returns the base index of the given instruction.
Definition: LiveIntervals.h:226
MachineRegisterInfo.h
AArch64PBQPRegAlloc.h
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:132
llvm::SetVector::remove
bool remove(const value_type &X)
Remove an item from the set vector.
Definition: SetVector.h:157
llvm::Register::isPhysicalRegister
static bool isPhysicalRegister(unsigned Reg)
Return true if the specified register number is in the physical register namespace.
Definition: Register.h:65
SI
@ SI
Definition: SIInstrInfo.cpp:7463
llvm::TargetRegisterInfo::regsOverlap
bool regsOverlap(Register regA, Register regB) const
Returns true if the two registers are equal or alias each other.
Definition: TargetRegisterInfo.h:416
lr
Common register allocation spilling lr str lr
Definition: README.txt:6
llvm::PBQP::RegAlloc::PBQPRAGraph
Definition: RegAllocPBQP.h:502
llvm::LiveInterval
LiveInterval - This class represents the liveness of a register, or stack slot.
Definition: LiveInterval.h:680
llvm::SlotIndex
SlotIndex - An opaque wrapper around machine indexes.
Definition: SlotIndexes.h:83
llvm::Pass::print
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Definition: Pass.cpp:125
G
const DataFlowGraph & G
Definition: RDFGraph.cpp:202
llvm::MachineFunction::getSubtarget
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
Definition: MachineFunction.h:558
llvm::MachineInstr
Representation of each machine instruction.
Definition: MachineInstr.h:64
LiveIntervals.h
move
compiles ldr LCPI1_0 ldr ldr mov lsr tst moveq r1 ldr LCPI1_1 and r0 bx lr It would be better to do something like to fold the shift into the conditional move
Definition: README.txt:546
llvm::MachineFunction::dump
void dump() const
dump - Print the current MachineFunction to cerr, useful for debugger use.
Definition: MachineFunction.cpp:517
llvm::LiveRange::expiredAt
bool expiredAt(SlotIndex index) const
Definition: LiveInterval.h:389
llvm::LiveRange::overlaps
bool overlaps(const LiveRange &other) const
overlaps - Return true if the intersection of the two live ranges is not empty.
Definition: LiveInterval.h:440
llvm::PBQP::Graph< RegAllocSolverImpl >::RawMatrix
typename RegAllocSolverImpl ::RawMatrix RawMatrix
Definition: Graph.h:52
llvm::PBQP::GraphBase::NodeId
unsigned NodeId
Definition: Graph.h:28
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
std::swap
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:840
llvm::A57ChainingConstraint::apply
void apply(PBQPRAGraph &G) override
Definition: AArch64PBQPRegAlloc.cpp:326
llvm::LiveIntervals::getInterval
LiveInterval & getInterval(Register Reg)
Definition: LiveIntervals.h:114
llvm::MachineFunction
Definition: MachineFunction.h:230
llvm::SetVector::insert
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:141
llvm::min
Expected< ExpressionValue > min(const ExpressionValue &Lhs, const ExpressionValue &Rhs)
Definition: FileCheck.cpp:357
llvm_unreachable
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Definition: ErrorHandling.h:136
llvm::Register
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
MBB
MachineBasicBlock & MBB
Definition: AArch64SLSHardening.cpp:74
j
return j(j<< 16)
llvm::SetVector::count
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:215
llvm::LiveIntervals
Definition: LiveIntervals.h:54
raw_ostream.h
MachineFunction.h
llvm::printReg
Printable printReg(Register Reg, const TargetRegisterInfo *TRI=nullptr, unsigned SubIdx=0, const MachineRegisterInfo *MRI=nullptr)
Prints virtual and physical registers with or without a TRI instance.
Definition: TargetRegisterInfo.cpp:110
Debug.h
regJustKilledBefore
static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, const MachineInstr &MI)
Definition: AArch64PBQPRegAlloc.cpp:319
llvm::PBQP::Matrix
PBQP Matrix class.
Definition: Math.h:121