LLVM 23.0.0git
MachineSMEABIPass.cpp
Go to the documentation of this file.
1//===- MachineSMEABIPass.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements the SME ABI requirements for ZA state. This includes
10// implementing the lazy (and agnostic) ZA state save schemes around calls.
11//
12//===----------------------------------------------------------------------===//
13//
14// This pass works by collecting instructions that require ZA to be in a
15// specific state (e.g., "ACTIVE" or "SAVED") and inserting the necessary state
16// transitions to ensure ZA is in the required state before instructions. State
17// transitions represent actions such as setting up or restoring a lazy save.
18// Certain points within a function may also have predefined states independent
19// of any instructions, for example, a "shared_za" function is always entered
20// and exited in the "ACTIVE" state.
21//
22// To handle ZA state across control flow, we make use of edge bundling. This
23// assigns each block an "incoming" and "outgoing" edge bundle (representing
24// incoming and outgoing edges). Initially, these are unique to each block;
25// then, in the process of forming bundles, the outgoing bundle of a block is
26// joined with the incoming bundle of all successors. The result is that each
27// bundle can be assigned a single ZA state, which ensures the state required by
// 28// all of a block's successors is the same, and that each basic block will always
29// be entered with the same ZA state. This eliminates the need for splitting
30// edges to insert state transitions or "phi" nodes for ZA states.
31//
32// See below for a simple example of edge bundling.
33//
34// The following shows a conditionally executed basic block (BB1):
35//
36// if (cond)
37// BB1
38// BB2
39//
40// Initial Bundles Joined Bundles
41//
42// ┌──0──┐ ┌──0──┐
43// │ BB0 │ │ BB0 │
44// └──1──┘ └──1──┘
45// ├───────┐ ├───────┐
46// ▼ │ ▼ │
47// ┌──2──┐ │ ─────► ┌──1──┐ │
48// │ BB1 │ ▼ │ BB1 │ ▼
49// └──3──┘ ┌──4──┐ └──1──┘ ┌──1──┐
50// └───►4 BB2 │ └───►1 BB2 │
51// └──5──┘ └──2──┘
52//
53// On the left are the initial per-block bundles, and on the right are the
54// joined bundles (which are the result of the EdgeBundles analysis).
55
56#include "AArch64InstrInfo.h"
58#include "AArch64Subtarget.h"
69
70using namespace llvm;
71
72#define DEBUG_TYPE "aarch64-machine-sme-abi"
73
74namespace {
75
// Note: For agnostic ZA, we assume the function is always entered/exited in the
// "ACTIVE" state -- this _may_ not be the case (since OFF is also a
// possibility, but for the purpose of placing ZA saves/restores, that does not
// matter).
/// The possible ZA/ZT0 states at a program point. This pass computes the state
/// required before each relevant instruction and emits transitions between
/// these states.
enum ZAState : uint8_t {
  // Any/unknown state (not valid)
  ANY = 0,

  // ZA is in use and active (i.e. within the accumulator)
  ACTIVE,

  // ZA is active, but ZT0 has been saved.
  // This handles the edge case of sharedZA && !sharesZT0.
  ACTIVE_ZT0_SAVED,

  // A ZA save has been set up or committed (i.e. ZA is dormant or off)
  // If the function uses ZT0 it must also be saved.
  LOCAL_SAVED,

  // ZA has been committed to the lazy save buffer of the current function.
  // If the function uses ZT0 it must also be saved.
  // ZA is off.
  LOCAL_COMMITTED,

  // The ZA/ZT0 state on entry to the function.
  ENTRY,

  // ZA is off.
  OFF,

  // The number of ZA states (not a valid state)
  NUM_ZA_STATE
};
109
/// A bitmask enum to record live physical registers that the "emit*" routines
/// may need to preserve. Note: This only tracks registers we may clobber.
enum LiveRegs : uint8_t {
  None = 0,
  NZCV = 1 << 0,
  W0 = 1 << 1,
  W0_HI = 1 << 2,
  // X0 is the composite of both halves: it is considered live only when both
  // W0 and W0_HI are live.
  X0 = W0 | W0_HI,
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI)
};
120
/// Holds the virtual registers live physical registers have been saved to.
struct PhysRegSave {
  // Which of the tracked physical registers were live (and thus saved).
  LiveRegs PhysLiveRegs;
  // Virtual register holding the saved NZCV flags (if NZCV was live).
  Register StatusFlags = AArch64::NoRegister;
  // Virtual register holding the saved W0/X0 value (if W0/X0 was live).
  Register X0Save = AArch64::NoRegister;
};
127
/// Contains the needed ZA state (and live registers) at an instruction. That is
/// the state ZA must be in _before_ "InsertPt".
// NOTE(review): callers (e.g. findStateChangeInsertionPoint) also read an
// "InsertPt" member of this struct; its declaration is not visible in this
// extract — confirm against the full source.
struct InstInfo {
  // The ZA state required immediately before this instruction.
  ZAState NeededState{ZAState::ANY};
  // Tracked physical registers (NZCV/W0/W0_HI) live at the insertion point.
  LiveRegs PhysLiveRegs = LiveRegs::None;
};
135
/// Contains the needed ZA state for each instruction in a block. Instructions
/// that do not require a ZA state are not recorded.
// NOTE(review): an "Insts" list (of InstInfo) is also part of this struct in
// the full source (it is used by collectNeededZAStates and
// findStateChangeInsertionPoint); its declaration is not visible in this
// extract.
struct BlockInfo {
  // Fixed ZA state on entry to this block (ENTRY for the entry block,
  // LOCAL_COMMITTED for EH pads); ANY when the entry state is unconstrained.
  ZAState FixedEntryState{ZAState::ANY};
  // State that avoids a transition at the start of the block (the first
  // recorded instruction's needed state).
  ZAState DesiredIncomingState{ZAState::ANY};
  // State that avoids a transition at the end of the block (the last
  // recorded instruction's needed state).
  ZAState DesiredOutgoingState{ZAState::ANY};
  // Tracked live physical registers at block entry/exit.
  LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
  LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
};
146
/// Contains the needed ZA state information for all blocks within a function.
// NOTE(review): a per-block "Blocks" vector is also part of this struct in the
// full source (see the FunctionInfo{...} construction in
// collectNeededZAStates); its declaration is not visible in this extract.
struct FunctionInfo {
  // Point just after the SME prologue emitted in SelectionDAG (if any) — a
  // safe location for this pass to insert TPIDR2 block setup.
  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
  // Tracked live physical registers at AfterSMEProloguePt.
  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
};
153
154/// State/helpers that is only needed when emitting code to handle
155/// saving/restoring ZA.
156class EmitContext {
157public:
158 EmitContext() = default;
159
160 /// Get or create a TPIDR2 block in \p MF.
161 int getTPIDR2Block(MachineFunction &MF) {
162 if (TPIDR2BlockFI)
163 return *TPIDR2BlockFI;
164 MachineFrameInfo &MFI = MF.getFrameInfo();
165 TPIDR2BlockFI = MFI.CreateStackObject(16, Align(16), false);
166 return *TPIDR2BlockFI;
167 }
168
169 /// Get or create agnostic ZA buffer pointer in \p MF.
170 Register getAgnosticZABufferPtr(MachineFunction &MF) {
171 if (AgnosticZABufferPtr != AArch64::NoRegister)
172 return AgnosticZABufferPtr;
173 Register BufferPtr =
174 MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
175 AgnosticZABufferPtr =
176 BufferPtr != AArch64::NoRegister
177 ? BufferPtr
178 : MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
179 return AgnosticZABufferPtr;
180 }
181
182 int getZT0SaveSlot(MachineFunction &MF) {
183 if (ZT0SaveFI)
184 return *ZT0SaveFI;
185 MachineFrameInfo &MFI = MF.getFrameInfo();
186 ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16));
187 return *ZT0SaveFI;
188 }
189
190 /// Returns true if the function must allocate a ZA save buffer on entry. This
191 /// will be the case if, at any point in the function, a ZA save was emitted.
192 bool needsSaveBuffer() const {
193 assert(!(TPIDR2BlockFI && AgnosticZABufferPtr) &&
194 "Cannot have both a TPIDR2 block and agnostic ZA buffer");
195 return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
196 }
197
198private:
199 std::optional<int> ZT0SaveFI;
200 std::optional<int> TPIDR2BlockFI;
201 Register AgnosticZABufferPtr = AArch64::NoRegister;
202};
203
204StringRef getZAStateString(ZAState State) {
205#define MAKE_CASE(V) \
206 case V: \
207 return #V;
208 switch (State) {
209 MAKE_CASE(ZAState::ANY)
210 MAKE_CASE(ZAState::ACTIVE)
211 MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
212 MAKE_CASE(ZAState::LOCAL_SAVED)
213 MAKE_CASE(ZAState::LOCAL_COMMITTED)
214 MAKE_CASE(ZAState::ENTRY)
215 MAKE_CASE(ZAState::OFF)
216 default:
217 llvm_unreachable("Unexpected ZAState");
218 }
219#undef MAKE_CASE
220}
221
222static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
223 const MachineOperand &MO) {
224 if (!MO.isReg() || !MO.getReg().isPhysical())
225 return false;
226 return any_of(TRI.subregs_inclusive(MO.getReg()), [](const MCPhysReg &SR) {
227 return AArch64::MPR128RegClass.contains(SR) ||
228 AArch64::ZTRRegClass.contains(SR);
229 });
230}
231
/// Returns the required ZA state needed before \p MI and an iterator pointing
/// to where any code required to change the ZA state should be inserted.
// NOTE(review): the declaration of "InsertPt" (presumably an iterator at
// \p MI) is not visible in this extract — confirm against the full source.
static std::pair<ZAState, MachineBasicBlock::iterator>
getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
                     SMEAttrs SMEFnAttrs) {

  // Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo are
  // intended to mark the position immediately before a call. Due to
  // SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN,
  // so we use std::prev(InsertPt) to get the position before the call.

  if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
    return {ZAState::ACTIVE, std::prev(InsertPt)};

  // Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
  if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
    return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};

  // If we only need to save ZT0 there's two cases to consider:
  // 1. The function has ZA state (that we don't need to save).
  //    - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
  //      This only saves ZT0.
  // 2. The function does not have ZA state
  //    - In this case we switch to "LOCAL_COMMITTED" state.
  //      This saves ZT0 and turns ZA off.
  if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
    return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
                                    : ZAState::LOCAL_COMMITTED,
            std::prev(InsertPt)};
  }

  // Returns must leave ZA in the state the function's interface requires:
  // OFF for a private-ZA interface, ACTIVE otherwise.
  if (MI.isReturn()) {
    bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
    return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
  }

  // Any direct use of a ZA/ZT0 register requires ZA to be ACTIVE.
  for (auto &MO : MI.operands()) {
    if (isZAorZTRegOp(TRI, MO))
      return {ZAState::ACTIVE, InsertPt};
  }

  return {ZAState::ANY, InsertPt};
}
276
/// The Machine SME ABI pass: computes the ZA state required at each point of a
/// function and inserts the saves/restores/mode-switches that implement the
/// SME ABI (lazy saves for private/shared ZA, full saves for agnostic ZA).
// NOTE(review): this extract is missing several lines of this struct (parts of
// method signatures and some member declarations, e.g. "OptLevel" referenced
// by the constructor and "ORE" used by emitCallSaveRemarks) — confirm against
// the full source.
struct MachineSMEABI : public MachineFunctionPass {
  inline static char ID = 0;

  MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
      : MachineFunctionPass(ID), OptLevel(OptLevel) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override { return "Machine SME ABI pass"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    // NOTE(review): the addRequired<>/addPreserved<> lines for this pass's
    // analyses are not visible in this extract.
  }

  /// Collects the needed ZA state (and live registers) before each instruction
  /// within the machine function.
  FunctionInfo collectNeededZAStates(SMEAttrs SMEFnAttrs);

  /// Assigns each edge bundle a ZA state based on the desired states of
  /// incoming and outgoing blocks in the bundle.
  SmallVector<ZAState> assignBundleZAStates(const EdgeBundles &Bundles,
                                            const FunctionInfo &FnInfo);

  /// Inserts code to handle changes between ZA states within the function.
  /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
  void insertStateChanges(EmitContext &, const FunctionInfo &FnInfo,
                          const EdgeBundles &Bundles,
                          ArrayRef<ZAState> BundleStates);

  /// Appends the symbol and register mask of SME support routine \p LC to the
  /// call being built by \p MIB, diagnosing unexpected calling conventions.
  void addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
                     CallingConv::ID ExpectedCC);

  void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI, bool IsSave);

  // Emission routines for private and shared ZA functions (using lazy saves).
  // NOTE(review): several signature lines below are truncated in this extract.
  void emitSMEPrologue(MachineBasicBlock &MBB,
  void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
                           LiveRegs PhysLiveRegs);
  void emitSetupLazySave(EmitContext &, MachineBasicBlock &MBB,
  void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                  bool ClearTPIDR2, bool On);

  // Emission routines for agnostic ZA functions.
  void emitSetupFullZASave(MachineBasicBlock &MBB,
                           LiveRegs PhysLiveRegs);
  // Emit a "full" ZA save or restore. It is "full" in the sense that this
  // function will emit a call to __arm_sme_save or __arm_sme_restore, which
  // handles saving and restoring both ZA and ZT0.
  void emitFullZASaveRestore(EmitContext &, MachineBasicBlock &MBB,
                             LiveRegs PhysLiveRegs, bool IsSave);
  void emitAllocateFullZASaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                    LiveRegs PhysLiveRegs);

  /// Attempts to find an insertion point before \p Inst where the status flags
  /// are not live. If \p Inst is `Block.Insts.end()` a point before the end of
  /// the block is found.
  std::pair<MachineBasicBlock::iterator, LiveRegs>
  findStateChangeInsertionPoint(MachineBasicBlock &MBB, const BlockInfo &Block,
  void emitStateChange(EmitContext &, MachineBasicBlock &MBB,
                       MachineBasicBlock::iterator MBBI, ZAState From,
                       ZAState To, LiveRegs PhysLiveRegs);

  // Helpers for switching between lazy/full ZA save/restore routines.
  void emitZASave(EmitContext &Context, MachineBasicBlock &MBB,
    // Agnostic-ZA functions use the full (__arm_sme_save) scheme; everything
    // else sets up a lazy save via TPIDR2_EL0.
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
                                   /*IsSave=*/true);
    return emitSetupLazySave(Context, MBB, MBBI);
  }
  void emitZARestore(EmitContext &Context, MachineBasicBlock &MBB,
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
                                   /*IsSave=*/false);
    return emitRestoreLazySave(Context, MBB, MBBI, PhysLiveRegs);
  }
  void emitAllocateZASaveBuffer(EmitContext &Context, MachineBasicBlock &MBB,
                                LiveRegs PhysLiveRegs) {
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitAllocateFullZASaveBuffer(Context, MBB, MBBI, PhysLiveRegs);
    return emitAllocateLazySaveBuffer(Context, MBB, MBBI);
  }

  /// Collects the reachable calls from \p MBBI marked with \p Marker. This is
  /// intended to be used to emit lazy save remarks. Note: This stops at the
  /// first marked call along any path.
  void collectReachableMarkedCalls(const MachineBasicBlock &MBB,
                                   unsigned Marker) const;

  void emitCallSaveRemarks(const MachineBasicBlock &MBB,
                           unsigned Marker, StringRef RemarkName,
                           StringRef SaveName) const;

  /// Emits a function-level error diagnostic prefixed with the function name.
  void emitError(const Twine &Message) {
    LLVMContext &Context = MF->getFunction().getContext();
    Context.emitError(MF->getName() + ": " + Message);
  }

  /// Save live physical registers to virtual registers.
  PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
  /// Restore physical registers from a save of their previous values.
  void restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB,

private:

  // Cached per-function pointers, set up in runOnMachineFunction.
  MachineFunction *MF = nullptr;
  const AArch64Subtarget *Subtarget = nullptr;
  const AArch64RegisterInfo *TRI = nullptr;
  const AArch64FunctionInfo *AFI = nullptr;
  const AArch64InstrInfo *TII = nullptr;
  const LibcallLoweringInfo *LLI = nullptr;

  MachineRegisterInfo *MRI = nullptr;
  MachineLoopInfo *MLI = nullptr;
};
417
418static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) {
419 LiveRegs PhysLiveRegs = LiveRegs::None;
420 if (!LiveUnits.available(AArch64::NZCV))
421 PhysLiveRegs |= LiveRegs::NZCV;
422 // We have to track W0 and X0 separately as otherwise things can get
423 // confused if we attempt to preserve X0 but only W0 was defined.
424 if (!LiveUnits.available(AArch64::W0))
425 PhysLiveRegs |= LiveRegs::W0;
426 if (!LiveUnits.available(AArch64::W0_HI))
427 PhysLiveRegs |= LiveRegs::W0_HI;
428 return PhysLiveRegs;
429}
430
431static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) {
432 if (PhysLiveRegs & LiveRegs::NZCV)
433 LiveUnits.addReg(AArch64::NZCV);
434 if (PhysLiveRegs & LiveRegs::W0)
435 LiveUnits.addReg(AArch64::W0);
436 if (PhysLiveRegs & LiveRegs::W0_HI)
437 LiveUnits.addReg(AArch64::W0_HI);
438}
439
440[[maybe_unused]] bool isCallStartOpcode(unsigned Opc) {
441 switch (Opc) {
442 case AArch64::TLSDESC_CALLSEQ:
443 case AArch64::TLSDESC_AUTH_CALLSEQ:
444 case AArch64::ADJCALLSTACKDOWN:
445 return true;
446 default:
447 return false;
448 }
449}
450
// Walks every block backwards (for liveness), recording the ZA state needed
// before each relevant instruction along with the tracked live registers.
FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
  assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
          SMEFnAttrs.hasZAState()) &&
         "Expected function to have ZA/ZT0 state!");

  // NOTE(review): the declaration of the per-block "Blocks" vector (moved into
  // the returned FunctionInfo below) is not visible in this extract.
  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;

  for (MachineBasicBlock &MBB : *MF) {
    BlockInfo &Block = Blocks[MBB.getNumber()];

    if (MBB.isEntryBlock()) {
      // Entry block:
      Block.FixedEntryState = ZAState::ENTRY;
    } else if (MBB.isEHPad()) {
      // EH entry block:
      Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
    }

    // Track liveness backwards from the block's live-outs.
    LiveRegUnits LiveUnits(*TRI);
    LiveUnits.addLiveOuts(MBB);

    Block.PhysLiveRegsAtExit = getPhysLiveRegs(LiveUnits);
    auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
    auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
    // NOTE(review): the declaration of "MBBI" (iterator at MI) inside this
    // loop is not visible in this extract.
    for (MachineInstr &MI : reverse(MBB)) {
      LiveUnits.stepBackward(MI);
      LiveRegs PhysLiveRegs = getPhysLiveRegs(LiveUnits);
      // The SMEStateAllocPseudo marker is added to a function if the save
      // buffer was allocated in SelectionDAG. It marks the end of the
      // allocation -- which is a safe point for this pass to insert any TPIDR2
      // block setup.
      if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
        AfterSMEProloguePt = MBBI;
        PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
      }
      // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
      auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs);
      assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
             "Unexpected state change insertion point!");
      if (MBBI == FirstTerminatorInsertPt)
        Block.PhysLiveRegsAtExit = PhysLiveRegs;
      if (MBBI == FirstNonPhiInsertPt)
        Block.PhysLiveRegsAtEntry = PhysLiveRegs;
      if (NeededState != ZAState::ANY)
        Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
    }

    // Reverse vector (as we had to iterate backwards for liveness).
    std::reverse(Block.Insts.begin(), Block.Insts.end());

    // Record the desired states on entry/exit of this block. These are the
    // states that would not incur a state transition.
    if (!Block.Insts.empty()) {
      Block.DesiredIncomingState = Block.Insts.front().NeededState;
      Block.DesiredOutgoingState = Block.Insts.back().NeededState;
    }
  }

  return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
                      PhysLiveRegsAfterSMEPrologue};
}
515
/// Assigns each edge bundle a ZA state based on the desired states of incoming
/// and outgoing blocks in the bundle.
// NOTE(review): the return-type line (SmallVector<ZAState>, per the in-class
// declaration) is not visible in this extract.
MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
                                    const FunctionInfo &FnInfo) {
  SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
  for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
    std::optional<ZAState> BundleState;
    for (unsigned BlockID : Bundles.getBlocks(I)) {
      const BlockInfo &Block = FnInfo.Blocks[BlockID];
      // Check if the block is an incoming block in the bundle. Note: We skip
      // Block.FixedEntryState != ANY to ignore EH pads (which are only
      // reachable via exceptions).
      if (Block.FixedEntryState != ZAState::ANY ||
          Bundles.getBundle(BlockID, /*Out=*/false) != I)
        continue;

      // Pick a state that matches all incoming blocks. Fall back to "ACTIVE" if
      // any incoming state doesn't match. This will hoist the state from
      // incoming blocks to outgoing blocks.
      if (!BundleState)
        BundleState = Block.DesiredIncomingState;
      else if (BundleState != Block.DesiredIncomingState)
        BundleState = ZAState::ACTIVE;
    }

    // No block constrained the bundle (or the only desire was ANY): default
    // to ACTIVE.
    if (!BundleState || BundleState == ZAState::ANY)
      BundleState = ZAState::ACTIVE;

    BundleStates[I] = *BundleState;
  }

  return BundleStates;
}
550
// Finds the best point to insert a state change before \p Inst (or before the
// block's terminators when Inst == Block.Insts.end()), scanning backwards for
// a spot where NZCV — and ideally X0 — is not live, to minimize register
// saves around the emitted code.
// NOTE(review): the final parameter line (the "Inst" iterator into
// Block.Insts) and the declaration of "InsertPt" are not visible in this
// extract.
std::pair<MachineBasicBlock::iterator, LiveRegs>
MachineSMEABI::findStateChangeInsertionPoint(
    MachineBasicBlock &MBB, const BlockInfo &Block,
  LiveRegs PhysLiveRegs;
  if (Inst != Block.Insts.end()) {
    // A specific instruction needs the new state: start from its insert point.
    InsertPt = Inst->InsertPt;
    PhysLiveRegs = Inst->PhysLiveRegs;
  } else {
    // Otherwise the change applies to the outgoing edge: start before the
    // terminators.
    InsertPt = MBB.getFirstTerminator();
    PhysLiveRegs = Block.PhysLiveRegsAtExit;
  }

  if (PhysLiveRegs == LiveRegs::None)
    return {InsertPt, PhysLiveRegs}; // Nothing to do (no live regs).

  // Find the previous state change. We can not move before this point.
  MachineBasicBlock::iterator PrevStateChangeI;
  if (Inst == Block.Insts.begin()) {
    PrevStateChangeI = MBB.begin();
  } else {
    // Note: `std::prev(Inst)` is the previous InstInfo. We only create an
    // InstInfo object for instructions that require a specific ZA state, so the
    // InstInfo is the site of the previous state change in the block (which can
    // be several MIs earlier).
    PrevStateChangeI = std::prev(Inst)->InsertPt;
  }

  // Note: LiveUnits will only accurately track X0 and NZCV.
  LiveRegUnits LiveUnits(*TRI);
  setPhysLiveRegs(LiveUnits, PhysLiveRegs);
  auto BestCandidate = std::make_pair(InsertPt, PhysLiveRegs);
  for (MachineBasicBlock::iterator I = InsertPt; I != PrevStateChangeI; --I) {
    // Don't move before/into a call (which may have a state change before it).
    if (I->getOpcode() == TII->getCallFrameDestroyOpcode() || I->isCall())
      break;
    LiveUnits.stepBackward(*I);
    LiveRegs CurrentPhysLiveRegs = getPhysLiveRegs(LiveUnits);
    // Find places where NZCV is available, but keep looking for locations where
    // both NZCV and X0 are available, which can avoid some copies.
    if (!(CurrentPhysLiveRegs & LiveRegs::NZCV))
      BestCandidate = {I, CurrentPhysLiveRegs};
    if (CurrentPhysLiveRegs == LiveRegs::None)
      break;
  }
  return BestCandidate;
}
599
600void MachineSMEABI::insertStateChanges(EmitContext &Context,
601 const FunctionInfo &FnInfo,
602 const EdgeBundles &Bundles,
603 ArrayRef<ZAState> BundleStates) {
604 for (MachineBasicBlock &MBB : *MF) {
605 const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
606 ZAState InState = BundleStates[Bundles.getBundle(MBB.getNumber(),
607 /*Out=*/false)];
608
609 ZAState CurrentState = Block.FixedEntryState;
610 if (CurrentState == ZAState::ANY)
611 CurrentState = InState;
612
613 for (auto &Inst : Block.Insts) {
614 if (CurrentState != Inst.NeededState) {
615 auto [InsertPt, PhysLiveRegs] =
616 findStateChangeInsertionPoint(MBB, Block, &Inst);
617 emitStateChange(Context, MBB, InsertPt, CurrentState, Inst.NeededState,
618 PhysLiveRegs);
619 CurrentState = Inst.NeededState;
620 }
621 }
622
623 if (MBB.succ_empty())
624 continue;
625
626 ZAState OutState =
627 BundleStates[Bundles.getBundle(MBB.getNumber(), /*Out=*/true)];
628 if (CurrentState != OutState) {
629 auto [InsertPt, PhysLiveRegs] =
630 findStateChangeInsertionPoint(MBB, Block, Block.Insts.end());
631 emitStateChange(Context, MBB, InsertPt, CurrentState, OutState,
632 PhysLiveRegs);
633 }
634 }
635}
636
  // NOTE(review): this is the tail of a helper whose signature is not visible
  // in this extract; it selects a DebugLoc for code inserted at MBBI.
  if (MBB.empty())
    return DebugLoc();
  // Use MBBI's location when it is a real instruction; otherwise fall back to
  // the block's last instruction.
  return MBBI != MBB.end() ? MBBI->getDebugLoc() : MBB.back().getDebugLoc();
}
643
/// Finds the first call (as determined by MachineInstr::isCall()) starting from
/// \p MBBI in \p MBB marked with \p Marker (which is a marker opcode such as
/// RequiresZASavePseudo). If a marked call is found, it is pushed to \p Calls
/// and the function returns true.
// NOTE(review): two parameter lines (the MBBI iterator and the Calls output
// vector) and the declaration of "I" are not visible in this extract.
static bool findMarkedCall(const MachineBasicBlock &MBB,
                           unsigned Marker, unsigned CallDestroyOpcode) {
  auto IsMarker = [&](auto &MI) { return MI.getOpcode() == Marker; };
  auto MarkerInst = std::find_if(MBBI, MBB.end(), IsMarker);
  if (MarkerInst == MBB.end())
    return false;
  // Scan forward from the marker for the call it marks, stopping at the end
  // of the call sequence (the call frame destroy).
  while (++I != MBB.end()) {
    if (I->isCall() || I->getOpcode() == CallDestroyOpcode)
      break;
  }
  if (I != MBB.end() && I->isCall())
    Calls.push_back(&*I);
  // Note: This function always returns true if a "Marker" was found.
  return true;
}
666
// DFS over successor blocks collecting marked calls; each path stops at its
// first marked call.
// NOTE(review): a parameter line (the start iterator "StartInst") and the
// declarations of "Visited" and "Worklist" are not visible in this extract.
void MachineSMEABI::collectReachableMarkedCalls(
    const MachineBasicBlock &StartMBB,
    SmallVectorImpl<const MachineInstr *> &Calls, unsigned Marker) const {
  assert(Marker == AArch64::InOutZAUsePseudo ||
         Marker == AArch64::RequiresZASavePseudo ||
         Marker == AArch64::RequiresZT0SavePseudo);
  unsigned CallDestroyOpcode = TII->getCallFrameDestroyOpcode();
  // If the start block already contains a marked call we are done.
  if (findMarkedCall(StartMBB, StartInst, Calls, Marker, CallDestroyOpcode))
    return;

                                            StartMBB.succ_rend());
  while (!Worklist.empty()) {
    const MachineBasicBlock *MBB = Worklist.pop_back_val();
    auto [_, Inserted] = Visited.insert(MBB);
    if (!Inserted)
      continue;

    // Only keep exploring successors of blocks with no marked call.
    if (!findMarkedCall(*MBB, MBB->begin(), Calls, Marker, CallDestroyOpcode))
      Worklist.append(MBB->succ_rbegin(), MBB->succ_rend());
  }
}
691
692static StringRef getCalleeName(const MachineInstr &CallInst) {
693 assert(CallInst.isCall() && "expected a call");
694 for (const MachineOperand &MO : CallInst.operands()) {
695 if (MO.isSymbol())
696 return MO.getSymbolName();
697 if (MO.isGlobal())
698 return MO.getGlobal()->getName();
699 }
700 return {};
701}
702
// Emits an optimization remark that a ZA/ZT0 save was emitted and, when extra
// analysis is enabled, per-call remarks pointing at each reachable call that
// required the save.
// NOTE(review): a parameter line (the MBBI iterator) is not visible in this
// extract.
void MachineSMEABI::emitCallSaveRemarks(const MachineBasicBlock &MBB,
                                        DebugLoc DL, unsigned Marker,
                                        StringRef RemarkName,
                                        StringRef SaveName) const {
  // Helper to build a remark of the given name at a specific location.
  auto SaveRemark = [&](DebugLoc DL, const MachineBasicBlock &MBB) {
    return MachineOptimizationRemarkAnalysis("sme", RemarkName, DL, &MBB);
  };
  StringRef StateName = Marker == AArch64::RequiresZT0SavePseudo ? "ZT0" : "ZA";
  // Always emit the top-level "save emitted" remark.
  ORE->emit([&] {
    return SaveRemark(DL, MBB) << SaveName << " of " << StateName
                               << " emitted in '" << MF->getName() << "'";
  });
  if (!ORE->allowExtraAnalysis("sme"))
    return;
  // Extra analysis: attribute the save to the calls that made it necessary.
  SmallVector<const MachineInstr *> CallsRequiringSaves;
  collectReachableMarkedCalls(MBB, MBBI, CallsRequiringSaves, Marker);
  for (const MachineInstr *CallInst : CallsRequiringSaves) {
    auto R = SaveRemark(CallInst->getDebugLoc(), *CallInst->getParent());
    R << "call";
    if (StringRef CalleeName = getCalleeName(*CallInst); !CalleeName.empty())
      R << " to '" << CalleeName << "'";
    R << " requires " << StateName << " save";
    ORE->emit(R);
  }
}
729
// Sets up a lazy save: points TPIDR2_EL0 at this function's TPIDR2 block so a
// callee needing ZA can commit the save.
// NOTE(review): the remaining parameter lines (MBB/MBBI) and the DebugLoc
// setup are not visible in this extract.
void MachineSMEABI::emitSetupLazySave(EmitContext &Context,

  emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZASavePseudo,
                      "SMELazySaveZA", "lazy save");

  // Get pointer to TPIDR2 block.
  Register TPIDR2 = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
  Register TPIDR2Ptr = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
      .addFrameIndex(Context.getTPIDR2Block(*MF))
      .addImm(0)
      .addImm(0);
  BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), TPIDR2Ptr)
      .addReg(TPIDR2);
  // Set TPIDR2_EL0 to point to TPIDR2 block.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
      .addImm(AArch64SysReg::TPIDR2_EL0)
      .addReg(TPIDR2Ptr);
}
752
// Saves the live tracked physical registers (NZCV and/or W0/X0) into fresh
// virtual registers so the "emit*" routines can clobber them.
// NOTE(review): the MBB/MBBI parameter lines are not visible in this extract.
PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs,
                                             DebugLoc DL) {
  PhysRegSave RegSave{PhysLiveRegs};
  // Save NZCV via MRS when the flags are live.
  if (PhysLiveRegs & LiveRegs::NZCV) {
    RegSave.StatusFlags = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), RegSave.StatusFlags)
        .addImm(AArch64SysReg::NZCV)
        .addReg(AArch64::NZCV, RegState::Implicit);
  }
  // Note: Preserving X0 is "free" as this is before register allocation, so
  // the register allocator is still able to optimize these copies.
  if (PhysLiveRegs & LiveRegs::W0) {
    // Copy the full X0 only when its high half is live too; otherwise just W0.
    RegSave.X0Save = MRI->createVirtualRegister(PhysLiveRegs & LiveRegs::W0_HI
                                                    ? &AArch64::GPR64RegClass
                                                    : &AArch64::GPR32RegClass);
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), RegSave.X0Save)
        .addReg(PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0);
  }
  return RegSave;
}
775
// Restores the physical registers saved by createPhysRegSave.
// NOTE(review): the MBB/MBBI parameter lines are not visible in this extract.
void MachineSMEABI::restorePhyRegSave(const PhysRegSave &RegSave,
                                      DebugLoc DL) {
  // Restore NZCV via MSR when it was saved.
  if (RegSave.StatusFlags != AArch64::NoRegister)
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
        .addImm(AArch64SysReg::NZCV)
        .addReg(RegSave.StatusFlags)
        .addReg(AArch64::NZCV, RegState::ImplicitDefine);

  // Restore X0 (or just W0) when it was saved.
  if (RegSave.X0Save != AArch64::NoRegister)
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY),
            RegSave.PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0)
        .addReg(RegSave.X0Save);
}
791
// Appends the symbol and call-preserved register mask of SME support routine
// \p LC to the call being built by \p MIB, emitting diagnostics (rather than
// crashing) when the routine is unsupported or has an unexpected calling
// convention.
void MachineSMEABI::addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
                                  CallingConv::ID ExpectedCC) {
  RTLIB::LibcallImpl LCImpl = LLI->getLibcallImpl(LC);
  if (LCImpl == RTLIB::Unsupported)
    emitError("cannot lower SME ABI (SME routines unsupported)");
  // NOTE(review): the lines deriving "ImplName" and "CC" from LCImpl are not
  // visible in this extract.
  if (CC != ExpectedCC)
    emitError("invalid calling convention for SME routine: '" + ImplName + "'");
  // FIXME: This assumes the ImplName StringRef is null-terminated.
  MIB.addExternalSymbol(ImplName.data());
  MIB.addRegMask(TRI->getCallPreservedMask(*MF, CC));
}
805
// Restores ZA from this function's lazy save: enables ZA, conditionally calls
// the TPIDR2 restore routine (via RestoreZAPseudo), then zeroes TPIDR2_EL0.
// NOTE(review): some parameter lines (MBB/MBBI) and the DebugLoc setup are
// not visible in this extract.
void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
                                        LiveRegs PhysLiveRegs) {
  Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  Register TPIDR2 = AArch64::X0;

  // TODO: Emit these within the restore MBB to prevent unnecessary saves.
  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

  // Enable ZA.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
      .addImm(AArch64SVCR::SVCRZA)
      .addImm(1);
  // Get current TPIDR2_EL0.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), TPIDR2EL0)
      .addImm(AArch64SysReg::TPIDR2_EL0);
  // Get pointer to TPIDR2 block.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
      .addFrameIndex(Context.getTPIDR2Block(*MF))
      .addImm(0)
      .addImm(0);
  // (Conditionally) restore ZA state.
  auto RestoreZA = BuildMI(MBB, MBBI, DL, TII->get(AArch64::RestoreZAPseudo))
                       .addReg(TPIDR2EL0)
                       .addReg(TPIDR2);
  // NOTE(review): the calling-convention argument of this call is not visible
  // in this extract.
  addSMELibCall(
      RestoreZA, RTLIB::SMEABI_TPIDR2_RESTORE,
  // Zero TPIDR2_EL0.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
      .addImm(AArch64SysReg::TPIDR2_EL0)
      .addReg(AArch64::XZR);

  restorePhyRegSave(RegSave, MBB, MBBI, DL);
}
843
// Switches PSTATE.ZA on or off (per \p On), optionally zeroing TPIDR2_EL0
// first (per \p ClearTPIDR2).
// NOTE(review): the MBBI parameter line and the DebugLoc setup are not
// visible in this extract.
void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
                               bool ClearTPIDR2, bool On) {

  if (ClearTPIDR2)
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
        .addImm(AArch64SysReg::TPIDR2_EL0)
        .addReg(AArch64::XZR);

  // Enable or disable ZA (PSTATE.ZA = On).
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
      .addImm(AArch64SVCR::SVCRZA)
      .addImm(On ? 1 : 0);
}
859
// Allocates the lazy-save buffer (SVL x SVL bytes carved off the stack) and
// initializes this function's TPIDR2 block with the buffer pointer and
// num_za_save_slices.
// NOTE(review): the MBBI parameter line and the DebugLoc setup are not
// visible in this extract.
void MachineSMEABI::emitAllocateLazySaveBuffer(
    EmitContext &Context, MachineBasicBlock &MBB,
  MachineFrameInfo &MFI = MF->getFrameInfo();
  Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  Register SVL = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  Register Buffer = AFI->getEarlyAllocSMESaveBuffer();

  // Calculate SVL.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::RDSVLI_XI), SVL).addImm(1);

  // 1. Allocate the lazy save buffer.
  if (Buffer == AArch64::NoRegister) {
    // TODO: On Windows, we allocate the lazy save buffer in SelectionDAG (so
    // Buffer != AArch64::NoRegister). This is done to reuse the existing
    // expansions (which can insert stack checks). This works, but it means we
    // will always allocate the lazy save buffer (even if the function contains
    // no lazy saves). If we want to handle Windows here, we'll need to
    // implement something similar to LowerWindowsDYNAMIC_STACKALLOC.
    assert(!Subtarget->isTargetWindows() &&
           "Lazy ZA save is not yet supported on Windows");
    Buffer = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
    // Get original stack pointer.
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), SP)
        .addReg(AArch64::SP);
    // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSUBXrrr), Buffer)
        .addReg(SVL)
        .addReg(SVL)
        .addReg(SP);
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), AArch64::SP)
        .addReg(Buffer);
    // We have just allocated a variable sized object, tell this to PEI.
    MFI.CreateVariableSizedObject(Align(16), nullptr);
  }

  // 2. Setup the TPIDR2 block.
  {
    // Note: This case just needs to do `SVL << 48`. It is not implemented as we
    // generally don't support big-endian SVE/SME.
    // NOTE(review): the error-reporting call preceding this message string is
    // not visible in this extract.
    if (!Subtarget->isLittleEndian())
        "TPIDR2 block initialization is not supported on big-endian targets");

    // Store buffer pointer and num_za_save_slices.
    // Bytes 10-15 are implicitly zeroed.
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi))
        .addReg(Buffer)
        .addReg(SVL)
        .addFrameIndex(Context.getTPIDR2Block(*MF))
        .addImm(0);
  }
}
914
/// Immediate mask for the ZERO (ZA) instruction with all eight mask bits set
/// (0b11111111), i.e. zero the whole ZA array.
static constexpr unsigned ZERO_ALL_ZA_MASK = 0xFF;
916
// Emits the SME prologue at function entry: reads TPIDR2_EL0 and, on the
// commit path, commits any in-progress lazy save via __arm_tpidr2_save
// (optionally zeroing "new" ZA/ZT0 state) and re-enables ZA; functions with a
// shared-ZA interface only zero their "new" ZA/ZT0 state.
// NOTE(review): damaged extraction -- original lines 918-919 (remainder of
// the signature), 923 (presumably the `if (...) {` opening the commit branch
// closed at line 947), 938 (presumably the libcall calling-convention
// argument) and 950 (presumably the ZERO_M mask immediate,
// ZERO_ALL_ZA_MASK) are missing; confirm against upstream.
917void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
920
// "New" ZA/ZT0 state must be zeroed on entry per the SME ACLE attributes.
921 bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
922 bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
924 // Get current TPIDR2_EL0.
925 Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
926 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
927 .addReg(TPIDR2EL0, RegState::Define)
928 .addImm(AArch64SysReg::TPIDR2_EL0);
929 // If TPIDR2_EL0 is non-zero, commit the lazy save.
930 // NOTE: Functions that only use ZT0 don't need to zero ZA.
931 auto CommitZASave =
932 BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
933 .addReg(TPIDR2EL0)
934 .addImm(ZeroZA)
935 .addImm(ZeroZT0);
936 addSMELibCall(
937 CommitZASave, RTLIB::SMEABI_TPIDR2_SAVE,
// Record implicit defs so later passes know ZA/ZT0 are written here.
939 if (ZeroZA)
940 CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
941 if (ZeroZT0)
942 CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine);
943 // Enable ZA (as ZA could have previously been in the OFF state).
944 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
945 .addImm(AArch64SVCR::SVCRZA)
946 .addImm(1);
947 } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
// Shared-ZA entry: ZA is already on; only zero the "new" state.
948 if (ZeroZA)
949 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M))
951 .addDef(AArch64::ZAB0, RegState::ImplicitDefine);
952 if (ZeroZT0)
953 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0);
954 }
955}
956
// Saves or restores the full (agnostic) ZA state by calling
// __arm_sme_save / __arm_sme_restore with the save-buffer pointer in X0,
// preserving the live physical registers around the call.
// NOTE(review): damaged extraction -- original lines 958-959 (remaining
// signature parameters), 961 (presumably DebugLoc setup) and 980 (presumably
// the libcall calling-convention argument) are missing; confirm against
// upstream.
957void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
960 LiveRegs PhysLiveRegs, bool IsSave) {
962
// Only emit a remark for saves; restores are implied by a prior save.
963 if (IsSave)
964 emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZASavePseudo,
965 "SMEFullZASave", "full save");
966
967 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
968
969 // Copy the buffer pointer into X0.
970 Register BufferPtr = AArch64::X0;
971 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
972 .addReg(Context.getAgnosticZABufferPtr(*MF));
973
974 // Call __arm_sme_save/__arm_sme_restore.
975 auto SaveRestoreZA = BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
976 .addReg(BufferPtr, RegState::Implicit);
977 addSMELibCall(
978 SaveRestoreZA,
979 IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE,
981
982 restorePhyRegSave(RegSave, MBB, MBBI, DL);
983}
984
// Saves or restores ZT0 to/from its dedicated stack slot: materializes the
// slot address with ADDXri, then STR_TX (save) or LDR_TX (restore).
// NOTE(review): damaged extraction -- original lines 986-987 (remaining
// signature parameters) and 989 (presumably DebugLoc setup) are missing;
// confirm against upstream.
985void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
988 bool IsSave) {
990
991 // Note: This will report calls that _only_ need ZT0 saved. Call that save
992 // both ZA and ZT0 will be under the SMELazySaveZA remark. This prevents
993 // reporting the same calls twice.
994 if (IsSave)
995 emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZT0SavePseudo,
996 "SMEZT0Save", "spill");
997
998 Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
999
// ZT0Save = address of the ZT0 spill slot (frame index + 0).
1000 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save)
1001 .addFrameIndex(Context.getZT0SaveSlot(*MF))
1002 .addImm(0)
1003 .addImm(0);
1004
1005 if (IsSave) {
1006 BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX))
1007 .addReg(AArch64::ZT0)
1008 .addReg(ZT0Save);
1009 } else {
1010 BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0)
1011 .addReg(ZT0Save);
1012 }
1013}
1014
// Allocates the buffer used for agnostic-ZA full saves: calls
// __arm_sme_state_size to obtain the size, bumps SP down by that amount,
// records the buffer pointer, and tells the frame info about the
// variable-sized allocation. Skipped when SelectionDAG already allocated the
// buffer.
// NOTE(review): damaged extraction -- original lines 1017 (remainder of the
// signature), 1022 (presumably DebugLoc setup) and 1045 (presumably the
// extend/shift operand of SUBXrx64) are missing; confirm against upstream.
1015void MachineSMEABI::emitAllocateFullZASaveBuffer(
1016 EmitContext &Context, MachineBasicBlock &MBB,
1018 // Buffer already allocated in SelectionDAG.
1019 if (AFI->getEarlyAllocSMESaveBuffer())
1020 return;
1021
1023 Register BufferPtr = Context.getAgnosticZABufferPtr(*MF);
1024 Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
1025
// The runtime calls below clobber registers; preserve what is live here.
1026 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
1027
1028 // Calculate the SME state size.
1029 {
1030 auto SMEStateSize = BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
1031 .addReg(AArch64::X0, RegState::ImplicitDefine);
1032 addSMELibCall(
1033 SMEStateSize, RTLIB::SMEABI_SME_STATE_SIZE,
// Result comes back in X0; move it into a virtual register.
1035 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferSize)
1036 .addReg(AArch64::X0);
1037 }
1038
1039 // Allocate a buffer object of the size given __arm_sme_state_size.
1040 {
1041 MachineFrameInfo &MFI = MF->getFrameInfo();
1042 BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBXrx64), AArch64::SP)
1043 .addReg(AArch64::SP)
1044 .addReg(BufferSize)
1046 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
1047 .addReg(AArch64::SP);
1048
1049 // We have just allocated a variable sized object, tell this to PEI.
1050 MFI.CreateVariableSizedObject(Align(16), nullptr);
1051 }
1052
1053 restorePhyRegSave(RegSave, MBB, MBBI, DL);
1054}
1055
1056struct FromState {
1057 ZAState From;
1058
1059 constexpr uint8_t to(ZAState To) const {
1060 static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
1061 return uint8_t(From) << 4 | uint8_t(To);
1062 }
1063};
1064
// Builds a FromState wrapper so a transition key reads naturally as
// transitionFrom(A).to(B) when forming switch cases in emitStateChange.
1065constexpr FromState transitionFrom(ZAState From) { return FromState{From}; }
1066
// Emits the instructions that move ZA (and ZT0) from state `From` to state
// `To` at the given insertion point. The (From, To) pair is packed into a
// single byte via transitionFrom(...).to(...) and dispatched with one switch;
// unimplemented transitions are a fatal internal error.
// NOTE(review): damaged extraction -- original lines 1068-1069 (presumably
// the MachineBasicBlock and insertion-point parameters referenced below as
// `MBB` and `InsertPt`) are missing from the signature; confirm against
// upstream.
1067void MachineSMEABI::emitStateChange(EmitContext &Context,
1070 ZAState From, ZAState To,
1071 LiveRegs PhysLiveRegs) {
1072 // ZA not used.
1073 if (From == ZAState::ANY || To == ZAState::ANY)
1074 return;
1075
1076 // If we're exiting from the ENTRY state that means that the function has not
1077 // used ZA, so in the case of private ZA/ZT0 functions we can omit any set up.
1078 if (From == ZAState::ENTRY && To == ZAState::OFF)
1079 return;
1080
1081 // TODO: Avoid setting up the save buffer if there's no transition to
1082 // LOCAL_SAVED.
1083 if (From == ZAState::ENTRY) {
1084 assert(&MBB == &MBB.getParent()->front() &&
1085 "ENTRY state only valid in entry block")
1086 emitSMEPrologue(MBB, MBB.getFirstNonPHI());
1087 if (To == ZAState::ACTIVE)
1088 return; // Nothing more to do (ZA is active after the prologue).
1089
1090 // Note: "emitNewZAPrologue" zeros ZA, so we may need to setup a lazy save
1091 // if "To" is "ZAState::LOCAL_SAVED". It may be possible to improve this
1092 // case by changing the placement of the zero instruction.
1093 From = ZAState::ACTIVE;
1094 }
1095
1096 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
1097 bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
1098 bool HasZT0State = SMEFnAttrs.hasZT0State();
1099 bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();
1100
1101 switch (transitionFrom(From).to(To)) {
1102 // This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
1103 case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED):
1104 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1105 break;
1106 case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE):
1107 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1108 break;
1109
1110 // This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED
1111 case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED):
1112 case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::LOCAL_SAVED):
// If ZT0 was already saved (ACTIVE_ZT0_SAVED) don't save it a second time.
1113 if (HasZT0State && From == ZAState::ACTIVE)
1114 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1115 if (HasZAState)
1116 emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
1117 break;
1118
1119 // This section handles: ACTIVE -> LOCAL_COMMITTED
1120 case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED):
1121 // TODO: We could support ZA state here, but this transition is currently
1122 // only possible when we _don't_ have ZA state.
1123 assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
1124 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1125 emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
1126 break;
1127
1128 // This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
1129 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF):
1130 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED):
1131 // These transitions are a no-op.
1132 break;
1133
1134 // This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
1135 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE):
1136 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED):
1137 case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE):
1138 case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE_ZT0_SAVED):
1139 if (HasZAState)
1140 emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
1141 else
1142 emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
// Restore ZT0 only when the target state is plain ACTIVE (otherwise it
// stays saved by definition of ACTIVE_ZT0_SAVED).
1143 if (HasZT0State && To == ZAState::ACTIVE)
1144 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1145 break;
1146
1147 // This section handles transitions to OFF (not previously covered)
1148 case transitionFrom(ZAState::ACTIVE).to(ZAState::OFF):
1149 case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::OFF):
1150 case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::OFF):
1151 assert(SMEFnAttrs.hasPrivateZAInterface() &&
1152 "Did not expect to turn ZA off in shared/agnostic ZA function");
// Coming from LOCAL_SAVED, TPIDR2_EL0 still points at the lazy-save block
// and must be cleared as ZA is turned off.
1153 emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
1154 /*On=*/false);
1155 break;
1156
1157 default:
1158 dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
1159 << getZAStateString(To) << '\n';
1160 llvm_unreachable("Unimplemented state transition");
1161 }
1162}
1163
1164/// Returns true if private ZA setup can be elided. This occurs when there is
1165/// no instruction within the function that requires ZA to be active.
1166static bool canElidePrivateZASetup(const FunctionInfo &FnInfo) {
1167 for (const BlockInfo &BlockInfo : FnInfo.Blocks) {
1168 for (const InstInfo &InstInfo : BlockInfo.Insts) {
1169 if (InstInfo.NeededState == ZAState::ACTIVE ||
1170 InstInfo.NeededState == ZAState::ACTIVE_ZT0_SAVED)
1171 return false;
1172 }
1173 }
1174 return true;
1175}
1176
1177} // end anonymous namespace
1178
// Register the pass with the legacy pass manager: command-line name
// "aarch64-machine-sme-abi", display name "Machine SME ABI",
// is-CFG-only = false, is-analysis = false.
1179INITIALIZE_PASS(MachineSMEABI, "aarch64-machine-sme-abi", "Machine SME ABI",
1180 false, false)
1181
1182bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
1183 AFI = MF.getInfo<AArch64FunctionInfo>();
1184 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
1185 if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
1186 !SMEFnAttrs.hasAgnosticZAInterface())
1187 return false;
1188
1189 Subtarget = &MF.getSubtarget<AArch64Subtarget>();
1190 if (!Subtarget->hasSME() && !SMEFnAttrs.hasAgnosticZAInterface())
1191 return false;
1192
1193 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
1194
1195 this->MF = &MF;
1196 ORE = &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE();
1197 LLI = &getAnalysis<LibcallLoweringInfoWrapper>().getLibcallLowering(
1198 *MF.getFunction().getParent(), *Subtarget);
1199 TII = Subtarget->getInstrInfo();
1200 TRI = Subtarget->getRegisterInfo();
1201 MRI = &MF.getRegInfo();
1202
1203 const EdgeBundles &Bundles =
1204 getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
1205
1206 FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
1207
1208 if (SMEFnAttrs.hasPrivateZAInterface() && canElidePrivateZASetup(FnInfo))
1209 return false;
1210
1211 SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);
1212
1213 EmitContext Context;
1214 insertStateChanges(Context, FnInfo, Bundles, BundleStates);
1215
1216 if (Context.needsSaveBuffer()) {
1217 if (FnInfo.AfterSMEProloguePt) {
1218 // Note: With inline stack probes the AfterSMEProloguePt may not be in the
1219 // entry block (due to the probing loop).
1220 MachineBasicBlock::iterator MBBI = *FnInfo.AfterSMEProloguePt;
1221 emitAllocateZASaveBuffer(Context, *MBBI->getParent(), MBBI,
1222 FnInfo.PhysLiveRegsAfterSMEPrologue);
1223 } else {
1224 MachineBasicBlock &EntryBlock = MF.front();
1225 emitAllocateZASaveBuffer(
1226 Context, EntryBlock, EntryBlock.getFirstNonPHI(),
1227 FnInfo.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
1228 }
1229 }
1230
1231 return true;
1232}
1233
// Factory used by the AArch64 target pipeline to create this pass.
// NOTE(review): damaged extraction -- the signature line (original line
// 1234) is missing; per the declaration elsewhere in this listing it is
// "FunctionPass *createMachineSMEABIPass(CodeGenOptLevel)". Confirm against
// upstream.
1235 return new MachineSMEABI(OptLevel);
1236}
static constexpr unsigned ZERO_ALL_ZA_MASK
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
MachineBasicBlock MachineBasicBlock::iterator MBBI
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
const HexagonInstrInfo * TII
#define _
IRTranslator LLVM IR MI
This file implements the LivePhysRegs utility for tracking liveness of physical registers.
#define ENTRY(ASMNAME, ENUM)
#define I(x, y, z)
Definition MD5.cpp:57
static DebugLoc getDebugLoc(MachineBasicBlock::instr_iterator FirstMI, MachineBasicBlock::instr_iterator LastMI)
Return the first DebugLoc that has line number information, given a range of instructions.
===- MachineOptimizationRemarkEmitter.h - Opt Diagnostics -*- C++ -*-—===//
#define MAKE_CASE(V)
Register const TargetRegisterInfo * TRI
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
This file defines the SmallVector class.
AArch64FunctionInfo - This class is derived from MachineFunctionInfo and contains private AArch64-spe...
Represent the analysis usage information of a pass.
AnalysisUsage & addPreservedID(const void *ID)
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
This class represents a function call, abstracting a target machine's calling convention.
A debug info location.
Definition DebugLoc.h:123
ArrayRef< unsigned > getBlocks(unsigned Bundle) const
getBlocks - Return an array of blocks that are connected to Bundle.
Definition EdgeBundles.h:53
unsigned getBundle(unsigned N, bool Out) const
getBundle - Return the ingoing (Out = false) or outgoing (Out = true) bundle number for basic block N
Definition EdgeBundles.h:47
unsigned getNumBundles() const
getNumBundles - Return the total number of bundles in the CFG.
Definition EdgeBundles.h:50
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:358
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location inf...
Tracks which library functions to use for a particular subtarget.
LLVM_ABI CallingConv::ID getLibcallImplCallingConv(RTLIB::LibcallImpl Call) const
Get the CallingConv that should be used for the specified libcall.
LLVM_ABI RTLIB::LibcallImpl getLibcallImpl(RTLIB::Libcall Call) const
Return the lowering's selection of implementation call for Call.
A set of register units used to track register liveness.
bool available(MCRegister Reg) const
Returns true if no part of physical register Reg is live.
void addReg(MCRegister Reg)
Adds register units covered by physical register Reg.
LLVM_ABI void stepBackward(const MachineInstr &MI)
Updates liveness when stepping backwards over the instruction MI.
LLVM_ABI void addLiveOuts(const MachineBasicBlock &MBB)
Adds registers living out of block MBB.
MachineInstrBundleIterator< const MachineInstr > const_iterator
int getNumber() const
MachineBasicBlocks are uniquely numbered at the function level, unless they're not in a MachineFuncti...
LLVM_ABI iterator getFirstNonPHI()
Returns a pointer to the first instruction in this block that is not a PHINode instruction.
succ_reverse_iterator succ_rbegin()
MachineInstrBundleIterator< MachineInstr > iterator
succ_reverse_iterator succ_rend()
The MachineFrameInfo class represents an abstract stack frame until prolog/epilog code is inserted.
LLVM_ABI int CreateStackObject(uint64_t Size, Align Alignment, bool isSpillSlot, const AllocaInst *Alloca=nullptr, uint8_t ID=0)
Create a new statically sized stack object, returning a nonnegative identifier to represent it.
LLVM_ABI int CreateSpillStackObject(uint64_t Size, Align Alignment)
Create a new statically sized stack object that represents a spill slot, returning a nonnegative iden...
LLVM_ABI int CreateVariableSizedObject(Align Alignment, const AllocaInst *Alloca)
Notify the MachineFrameInfo object that a variable sized object has been created.
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
StringRef getName() const
getName - Return the name of the corresponding LLVM function.
MachineFrameInfo & getFrameInfo()
getFrameInfo - Return the frame info object for the current function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
unsigned getNumBlockIDs() const
getNumBlockIDs - Return the number of MBB ID's allocated.
Ty * getInfo()
getInfo - Keep track of various per-function pieces of information for backends that would like to do...
const MachineInstrBuilder & addExternalSymbol(const char *FnName, unsigned TargetFlags=0) const
const MachineInstrBuilder & addReg(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a new virtual register operand.
const MachineInstrBuilder & addImm(int64_t Val) const
Add a new immediate operand.
const MachineInstrBuilder & addFrameIndex(int Idx) const
const MachineInstrBuilder & addRegMask(const uint32_t *Mask) const
const MachineInstrBuilder & addDef(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
MachineOperand class - Representation of each machine instruction operand.
const GlobalValue * getGlobal() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
bool isSymbol() const
isSymbol - Tests if this is a MO_ExternalSymbol operand.
bool isGlobal() const
isGlobal - Tests if this is a MO_GlobalAddress operand.
const char * getSymbolName() const
Register getReg() const
getReg - Returns the register number.
Diagnostic information for optimization analysis remarks.
LLVM_ABI void emit(DiagnosticInfoOptimizationBase &OptDiag)
Emit an optimization remark.
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to be more informative.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
LLVM_ABI Register createVirtualRegister(const TargetRegisterClass *RegClass, StringRef Name="")
createVirtualRegister - Create and return a new virtual register in the function with the specified r...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
Definition Register.h:83
SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
bool hasAgnosticZAInterface() const
bool hasPrivateZAInterface() const
bool hasSharedZAInterface() const
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
typename SuperClass::const_iterator const_iterator
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Represent a constant reference to a string, i.e.
Definition StringRef.h:56
constexpr bool empty() const
Check if the string is empty.
Definition StringRef.h:141
constexpr const char * data() const
Get a pointer to the start of the string (which may not be null terminated).
Definition StringRef.h:138
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition Twine.h:82
op_range operands()
Definition User.h:267
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
const ParentTy * getParent() const
Definition ilist_node.h:34
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
static unsigned getArithExtendImm(AArch64_AM::ShiftExtendType ET, unsigned Imm)
getArithExtendImm - Encode the extend type and shift amount for an arithmetic instruction: imm: 3-bit...
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0
Preserve X0-X13, X19-X29, SP, Z0-Z31, P0-P15.
@ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1
Preserve X1-X15, X19-X29, SP, Z0-Z31, P0-P15.
This is an optimization pass for GlobalISel generic memory operations.
MachineInstrBuilder BuildMI(MachineFunction &MF, const MIMetadata &MIMD, const MCInstrDesc &MCID)
Builder interface. Specify how to create the initial instruction itself.
@ Implicit
Not emitted register (e.g. carry, or temporary result).
@ Define
Register definition.
FunctionPass * createMachineSMEABIPass(CodeGenOptLevel)
LLVM_ABI char & MachineDominatorsID
MachineDominators - This pass is a machine dominators analysis pass.
LLVM_ABI void reportFatalInternalError(Error Err)
Report a fatal error that indicates a bug in LLVM.
Definition Error.cpp:173
LLVM_ABI char & MachineLoopInfoID
MachineLoopInfo - This pass is a loop analysis pass.
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
CodeGenOptLevel
Code generation optimization level.
Definition CodeGen.h:82
@ Default
-O2, -Os, -Oz
Definition CodeGen.h:85
uint16_t MCPhysReg
An unsigned integer type large enough to represent all physical registers, but not necessarily virtua...
Definition MCRegister.h:21
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
static StringRef getLibcallImplName(RTLIB::LibcallImpl CallImpl)
Get the libcall routine name for the specified libcall implementation.