LLVM 23.0.0git
MachineSMEABIPass.cpp
Go to the documentation of this file.
1//===- MachineSMEABIPass.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements the SME ABI requirements for ZA state. This includes
10// implementing the lazy (and agnostic) ZA state save schemes around calls.
11//
12//===----------------------------------------------------------------------===//
13//
14// This pass works by collecting instructions that require ZA to be in a
15// specific state (e.g., "ACTIVE" or "SAVED") and inserting the necessary state
16// transitions to ensure ZA is in the required state before instructions. State
17// transitions represent actions such as setting up or restoring a lazy save.
18// Certain points within a function may also have predefined states independent
19// of any instructions, for example, a "shared_za" function is always entered
20// and exited in the "ACTIVE" state.
21//
22// To handle ZA state across control flow, we make use of edge bundling. This
23// assigns each block an "incoming" and "outgoing" edge bundle (representing
24// incoming and outgoing edges). Initially, these are unique to each block;
25// then, in the process of forming bundles, the outgoing bundle of a block is
26// joined with the incoming bundle of all successors. The result is that each
27// bundle can be assigned a single ZA state, which ensures the state required by
28// all of a block's successors is the same, and that each basic block will always
29// be entered with the same ZA state. This eliminates the need for splitting
30// edges to insert state transitions or "phi" nodes for ZA states.
31//
32// See below for a simple example of edge bundling.
33//
34// The following shows a conditionally executed basic block (BB1):
35//
36//   if (cond)
37//     BB1
38//   BB2
39//
40//    Initial Bundles         Joined Bundles
41//
42//    ┌──0──┐                 ┌──0──┐
43//    │ BB0 │                 │ BB0 │
44//    └──1──┘                 └──1──┘
45//       ├───────┐               ├───────┐
46//       ▼       │               ▼       │
47//    ┌──2──┐    │   ─────►   ┌──1──┐    │
48//    │ BB1 │    ▼            │ BB1 │    ▼
49//    └──3──┘ ┌──4──┐         └──1──┘ ┌──1──┐
50//       └───►4 BB2 │            └───►1 BB2 │
51//            └──5──┘                 └──2──┘
52//
53// On the left are the initial per-block bundles, and on the right are the
54// joined bundles (which are the result of the EdgeBundles analysis).
55
56#include "AArch64InstrInfo.h"
58#include "AArch64Subtarget.h"
69
70using namespace llvm;
71
72#define DEBUG_TYPE "aarch64-machine-sme-abi"
73
74namespace {
75
// Note: For agnostic ZA, we assume the function is always entered/exited in the
// "ACTIVE" state. This _may_ not be the case (OFF is also a possibility), but
// for the purpose of placing ZA saves/restores that does not matter.
enum ZAState : uint8_t {
  /// Any/unknown state (not a valid state).
  ANY = 0,

  /// ZA is in use and active (i.e. within the accumulator).
  ACTIVE,

  /// ZA is active, but ZT0 has been saved.
  /// This handles the edge case of sharedZA && !sharesZT0.
  ACTIVE_ZT0_SAVED,

  /// A ZA save has been set up or committed (i.e. ZA is dormant or off).
  /// If the function uses ZT0 it must also be saved.
  LOCAL_SAVED,

  /// ZA has been committed to the lazy save buffer of the current function,
  /// and ZA is off. If the function uses ZT0 it must also be saved.
  LOCAL_COMMITTED,

  /// The ZA/ZT0 state on entry to the function.
  ENTRY,

  /// ZA is off.
  OFF,

  /// The number of ZA states (not a valid state).
  NUM_ZA_STATE
};
109
110/// A bitmask enum to record live physical registers that the "emit*" routines
111/// may need to preserve. Note: This only tracks registers we may clobber.
112enum LiveRegs : uint8_t {
113 None = 0,
114 NZCV = 1 << 0,
115 W0 = 1 << 1,
116 W0_HI = 1 << 2,
117 X0 = W0 | W0_HI,
118 LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI)
119};
120
121/// Holds the virtual registers live physical registers have been saved to.
122struct PhysRegSave {
123 LiveRegs PhysLiveRegs;
124 Register StatusFlags = AArch64::NoRegister;
125 Register X0Save = AArch64::NoRegister;
126};
127
128/// Contains the needed ZA state (and live registers) at an instruction. That is
129/// the state ZA must be in _before_ "InsertPt".
130struct InstInfo {
131 ZAState NeededState{ZAState::ANY};
133 LiveRegs PhysLiveRegs = LiveRegs::None;
134};
135
136/// Contains the needed ZA state for each instruction in a block. Instructions
137/// that do not require a ZA state are not recorded.
138struct BlockInfo {
140 ZAState FixedEntryState{ZAState::ANY};
141 ZAState DesiredIncomingState{ZAState::ANY};
142 ZAState DesiredOutgoingState{ZAState::ANY};
143 LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
144 LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
145};
146
147/// Contains the needed ZA state information for all blocks within a function.
148struct FunctionInfo {
150 std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
151 LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
152};
153
154/// State/helpers that is only needed when emitting code to handle
155/// saving/restoring ZA.
156class EmitContext {
157public:
158 EmitContext() = default;
159
160 /// Get or create a TPIDR2 block in \p MF.
161 int getTPIDR2Block(MachineFunction &MF) {
162 if (TPIDR2BlockFI)
163 return *TPIDR2BlockFI;
164 MachineFrameInfo &MFI = MF.getFrameInfo();
165 TPIDR2BlockFI = MFI.CreateStackObject(16, Align(16), false);
166 return *TPIDR2BlockFI;
167 }
168
169 /// Get or create agnostic ZA buffer pointer in \p MF.
170 Register getAgnosticZABufferPtr(MachineFunction &MF) {
171 if (AgnosticZABufferPtr != AArch64::NoRegister)
172 return AgnosticZABufferPtr;
173 Register BufferPtr =
174 MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
175 AgnosticZABufferPtr =
176 BufferPtr != AArch64::NoRegister
177 ? BufferPtr
178 : MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
179 return AgnosticZABufferPtr;
180 }
181
182 int getZT0SaveSlot(MachineFunction &MF) {
183 if (ZT0SaveFI)
184 return *ZT0SaveFI;
185 MachineFrameInfo &MFI = MF.getFrameInfo();
186 ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16));
187 return *ZT0SaveFI;
188 }
189
190 /// Returns true if the function must allocate a ZA save buffer on entry. This
191 /// will be the case if, at any point in the function, a ZA save was emitted.
192 bool needsSaveBuffer() const {
193 assert(!(TPIDR2BlockFI && AgnosticZABufferPtr) &&
194 "Cannot have both a TPIDR2 block and agnostic ZA buffer");
195 return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
196 }
197
198private:
199 std::optional<int> ZT0SaveFI;
200 std::optional<int> TPIDR2BlockFI;
201 Register AgnosticZABufferPtr = AArch64::NoRegister;
202};
203
204StringRef getZAStateString(ZAState State) {
205#define MAKE_CASE(V) \
206 case V: \
207 return #V;
208 switch (State) {
209 MAKE_CASE(ZAState::ANY)
210 MAKE_CASE(ZAState::ACTIVE)
211 MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
212 MAKE_CASE(ZAState::LOCAL_SAVED)
213 MAKE_CASE(ZAState::LOCAL_COMMITTED)
214 MAKE_CASE(ZAState::ENTRY)
215 MAKE_CASE(ZAState::OFF)
216 default:
217 llvm_unreachable("Unexpected ZAState");
218 }
219#undef MAKE_CASE
220}
221
222static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
223 const MachineOperand &MO) {
224 if (!MO.isReg() || !MO.getReg().isPhysical())
225 return false;
226 return any_of(TRI.subregs_inclusive(MO.getReg()), [](const MCPhysReg &SR) {
227 return AArch64::MPR128RegClass.contains(SR) ||
228 AArch64::ZTRRegClass.contains(SR);
229 });
230}
231
232/// Returns the required ZA state needed before \p MI and an iterator pointing
233/// to where any code required to change the ZA state should be inserted.
234static std::pair<ZAState, MachineBasicBlock::iterator>
235getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
236 SMEAttrs SMEFnAttrs) {
238
239 // Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo are
240 // intended to mark the position immediately before a call. Due to
241 // SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN,
242 // so we use std::prev(InsertPt) to get the position before the call.
243
244 if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
245 return {ZAState::ACTIVE, std::prev(InsertPt)};
246
247 // Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
248 if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
249 return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};
250
251 // If we only need to save ZT0 there are two cases to consider:
252 // 1. The function has ZA state (that we don't need to save).
253 // - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
254 // This only saves ZT0.
255 // 2. The function does not have ZA state
256 // - In this case we switch to "LOCAL_COMMITTED" state.
257 // This saves ZT0 and turns ZA off.
258 if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
259 return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
260 : ZAState::LOCAL_COMMITTED,
261 std::prev(InsertPt)};
262 }
263
264 if (MI.isReturn()) {
265 bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
266 return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
267 }
268
269 for (auto &MO : MI.operands()) {
270 if (isZAorZTRegOp(TRI, MO))
271 return {ZAState::ACTIVE, InsertPt};
272 }
273
274 return {ZAState::ANY, InsertPt};
275}
276
277struct MachineSMEABI : public MachineFunctionPass {
278 inline static char ID = 0;
279
280 MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
281 : MachineFunctionPass(ID), OptLevel(OptLevel) {}
282
283 bool runOnMachineFunction(MachineFunction &MF) override;
284
285 StringRef getPassName() const override { return "Machine SME ABI pass"; }
286
287 void getAnalysisUsage(AnalysisUsage &AU) const override {
288 AU.setPreservesCFG();
295 }
296
297 /// Collects the needed ZA state (and live registers) before each instruction
298 /// within the machine function.
299 FunctionInfo collectNeededZAStates(SMEAttrs SMEFnAttrs);
300
301 /// Assigns each edge bundle a ZA state based on the desired states of
302 /// incoming and outgoing blocks in the bundle.
303 SmallVector<ZAState> assignBundleZAStates(const EdgeBundles &Bundles,
304 const FunctionInfo &FnInfo);
305
306 /// Inserts code to handle changes between ZA states within the function.
307 /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
308 void insertStateChanges(EmitContext &, const FunctionInfo &FnInfo,
309 const EdgeBundles &Bundles,
310 ArrayRef<ZAState> BundleStates);
311
312 void addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
313 CallingConv::ID ExpectedCC);
314
315 void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
316 MachineBasicBlock::iterator MBBI, bool IsSave);
317
318 // Emission routines for private and shared ZA functions (using lazy saves).
319 void emitSMEPrologue(MachineBasicBlock &MBB,
321 void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
323 LiveRegs PhysLiveRegs);
324 void emitSetupLazySave(EmitContext &, MachineBasicBlock &MBB,
326 void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
329 bool ClearTPIDR2, bool On);
330
331 // Emission routines for agnostic ZA functions.
332 void emitSetupFullZASave(MachineBasicBlock &MBB,
334 LiveRegs PhysLiveRegs);
335 // Emit a "full" ZA save or restore. It is "full" in the sense that this
336 // function will emit a call to __arm_sme_save or __arm_sme_restore, which
337 // handles saving and restoring both ZA and ZT0.
338 void emitFullZASaveRestore(EmitContext &, MachineBasicBlock &MBB,
340 LiveRegs PhysLiveRegs, bool IsSave);
341 void emitAllocateFullZASaveBuffer(EmitContext &, MachineBasicBlock &MBB,
343 LiveRegs PhysLiveRegs);
344
345 /// Attempts to find an insertion point before \p Inst where the status flags
346 /// are not live. If \p Inst is `Block.Insts.end()` a point before the end of
347 /// the block is found.
348 std::pair<MachineBasicBlock::iterator, LiveRegs>
349 findStateChangeInsertionPoint(MachineBasicBlock &MBB, const BlockInfo &Block,
351 void emitStateChange(EmitContext &, MachineBasicBlock &MBB,
352 MachineBasicBlock::iterator MBBI, ZAState From,
353 ZAState To, LiveRegs PhysLiveRegs);
354
355 // Helpers for switching between lazy/full ZA save/restore routines.
356 void emitZASave(EmitContext &Context, MachineBasicBlock &MBB,
358 if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
359 return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
360 /*IsSave=*/true);
361 return emitSetupLazySave(Context, MBB, MBBI);
362 }
363 void emitZARestore(EmitContext &Context, MachineBasicBlock &MBB,
365 if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
366 return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
367 /*IsSave=*/false);
368 return emitRestoreLazySave(Context, MBB, MBBI, PhysLiveRegs);
369 }
370 void emitAllocateZASaveBuffer(EmitContext &Context, MachineBasicBlock &MBB,
372 LiveRegs PhysLiveRegs) {
373 if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
374 return emitAllocateFullZASaveBuffer(Context, MBB, MBBI, PhysLiveRegs);
375 return emitAllocateLazySaveBuffer(Context, MBB, MBBI);
376 }
377
378 /// Collects the reachable calls from \p MBBI marked with \p Marker. This is
379 /// intended to be used to emit lazy save remarks. Note: This stops at the
380 /// first marked call along any path.
381 void collectReachableMarkedCalls(const MachineBasicBlock &MBB,
384 unsigned Marker) const;
385
386 void emitCallSaveRemarks(const MachineBasicBlock &MBB,
388 unsigned Marker, StringRef RemarkName,
389 StringRef SaveName) const;
390
391 void emitError(const Twine &Message) {
392 LLVMContext &Context = MF->getFunction().getContext();
393 Context.emitError(MF->getName() + ": " + Message);
394 }
395
396 /// Save live physical registers to virtual registers.
397 PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
399 /// Restore physical registers from a save of their previous values.
400 void restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB,
402
403private:
405
406 MachineFunction *MF = nullptr;
407 const AArch64Subtarget *Subtarget = nullptr;
408 const AArch64RegisterInfo *TRI = nullptr;
409 const AArch64FunctionInfo *AFI = nullptr;
410 const AArch64InstrInfo *TII = nullptr;
411 const LibcallLoweringInfo *LLI = nullptr;
412
414 MachineRegisterInfo *MRI = nullptr;
415 MachineLoopInfo *MLI = nullptr;
416};
417
418static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) {
419 LiveRegs PhysLiveRegs = LiveRegs::None;
420 if (!LiveUnits.available(AArch64::NZCV))
421 PhysLiveRegs |= LiveRegs::NZCV;
422 // We have to track W0 and X0 separately as otherwise things can get
423 // confused if we attempt to preserve X0 but only W0 was defined.
424 if (!LiveUnits.available(AArch64::W0))
425 PhysLiveRegs |= LiveRegs::W0;
426 if (!LiveUnits.available(AArch64::W0_HI))
427 PhysLiveRegs |= LiveRegs::W0_HI;
428 return PhysLiveRegs;
429}
430
431static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) {
432 if (PhysLiveRegs & LiveRegs::NZCV)
433 LiveUnits.addReg(AArch64::NZCV);
434 if (PhysLiveRegs & LiveRegs::W0)
435 LiveUnits.addReg(AArch64::W0);
436 if (PhysLiveRegs & LiveRegs::W0_HI)
437 LiveUnits.addReg(AArch64::W0_HI);
438}
439
440[[maybe_unused]] bool isCallStartOpcode(unsigned Opc) {
441 switch (Opc) {
442 case AArch64::TLSDESC_CALLSEQ:
443 case AArch64::TLSDESC_AUTH_CALLSEQ:
444 case AArch64::ADJCALLSTACKDOWN:
445 return true;
446 default:
447 return false;
448 }
449}
450
451FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
452 assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
453 SMEFnAttrs.hasZAState()) &&
454 "Expected function to have ZA/ZT0 state!");
455
457 LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
458 std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
459
460 for (MachineBasicBlock &MBB : *MF) {
461 BlockInfo &Block = Blocks[MBB.getNumber()];
462
463 if (MBB.isEntryBlock()) {
464 // Entry block:
465 Block.FixedEntryState = ZAState::ENTRY;
466 } else if (MBB.isEHPad()) {
467 // EH entry block:
468 Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
469 }
470
471 LiveRegUnits LiveUnits(*TRI);
472 LiveUnits.addLiveOuts(MBB);
473
474 Block.PhysLiveRegsAtExit = getPhysLiveRegs(LiveUnits);
475 auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
476 auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
477 for (MachineInstr &MI : reverse(MBB)) {
478 if (MI.isDebugInstr())
479 continue;
480
482 LiveUnits.stepBackward(MI);
483 LiveRegs PhysLiveRegs = getPhysLiveRegs(LiveUnits);
484 // The SMEStateAllocPseudo marker is added to a function if the save
485 // buffer was allocated in SelectionDAG. It marks the end of the
486 // allocation -- which is a safe point for this pass to insert any TPIDR2
487 // block setup.
488 if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
489 AfterSMEProloguePt = MBBI;
490 PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
491 }
492 // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
493 auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs);
494 assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
495 "Unexpected state change insertion point!");
496 if (MBBI == FirstTerminatorInsertPt)
497 Block.PhysLiveRegsAtExit = PhysLiveRegs;
498 if (MBBI == FirstNonPhiInsertPt)
499 Block.PhysLiveRegsAtEntry = PhysLiveRegs;
500 if (NeededState != ZAState::ANY)
501 Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
502 }
503
504 // Reverse vector (as we had to iterate backwards for liveness).
505 std::reverse(Block.Insts.begin(), Block.Insts.end());
506
507 // Record the desired states on entry/exit of this block. These are the
508 // states that would not incur a state transition.
509 if (!Block.Insts.empty()) {
510 Block.DesiredIncomingState = Block.Insts.front().NeededState;
511 Block.DesiredOutgoingState = Block.Insts.back().NeededState;
512 }
513 }
514
515 return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
516 PhysLiveRegsAfterSMEPrologue};
517}
518
519/// Assigns each edge bundle a ZA state based on the desired states of incoming
520/// and outgoing blocks in the bundle.
522MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
523 const FunctionInfo &FnInfo) {
524 SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
525 for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
526 std::optional<ZAState> BundleState;
527 for (unsigned BlockID : Bundles.getBlocks(I)) {
528 const BlockInfo &Block = FnInfo.Blocks[BlockID];
529 // Check if the block is an incoming block in the bundle. Note: We skip
530 // Block.FixedEntryState != ANY to ignore EH pads (which are only
531 // reachable via exceptions).
532 if (Block.FixedEntryState != ZAState::ANY ||
533 Bundles.getBundle(BlockID, /*Out=*/false) != I)
534 continue;
535
536 // Pick a state that matches all incoming blocks. Fall back to "ACTIVE" if
537 // any incoming state doesn't match. This will hoist the state from
538 // incoming blocks to outgoing blocks.
539 if (!BundleState)
540 BundleState = Block.DesiredIncomingState;
541 else if (BundleState != Block.DesiredIncomingState)
542 BundleState = ZAState::ACTIVE;
543 }
544
545 if (!BundleState || BundleState == ZAState::ANY)
546 BundleState = ZAState::ACTIVE;
547
548 BundleStates[I] = *BundleState;
549 }
550
551 return BundleStates;
552}
553
554std::pair<MachineBasicBlock::iterator, LiveRegs>
555MachineSMEABI::findStateChangeInsertionPoint(
556 MachineBasicBlock &MBB, const BlockInfo &Block,
558 LiveRegs PhysLiveRegs;
560 if (Inst != Block.Insts.end()) {
561 InsertPt = Inst->InsertPt;
562 PhysLiveRegs = Inst->PhysLiveRegs;
563 } else {
564 InsertPt = MBB.getFirstTerminator();
565 PhysLiveRegs = Block.PhysLiveRegsAtExit;
566 }
567
568 if (PhysLiveRegs == LiveRegs::None)
569 return {InsertPt, PhysLiveRegs}; // Nothing to do (no live regs).
570
571 // Find the previous state change. We can not move before this point.
572 MachineBasicBlock::iterator PrevStateChangeI;
573 if (Inst == Block.Insts.begin()) {
574 PrevStateChangeI = MBB.begin();
575 } else {
576 // Note: `std::prev(Inst)` is the previous InstInfo. We only create an
577 // InstInfo object for instructions that require a specific ZA state, so the
578 // InstInfo is the site of the previous state change in the block (which can
579 // be several MIs earlier).
580 PrevStateChangeI = std::prev(Inst)->InsertPt;
581 }
582
583 // Note: LiveUnits will only accurately track X0 and NZCV.
584 LiveRegUnits LiveUnits(*TRI);
585 setPhysLiveRegs(LiveUnits, PhysLiveRegs);
586 auto BestCandidate = std::make_pair(InsertPt, PhysLiveRegs);
587 for (MachineBasicBlock::iterator I = InsertPt; I != PrevStateChangeI; --I) {
588 if (I->isDebugInstr())
589 continue;
590
591 // Don't move before/into a call (which may have a state change before it).
592 if (I->getOpcode() == TII->getCallFrameDestroyOpcode() || I->isCall())
593 break;
594 LiveUnits.stepBackward(*I);
595 LiveRegs CurrentPhysLiveRegs = getPhysLiveRegs(LiveUnits);
596 // Find places where NZCV is available, but keep looking for locations where
597 // both NZCV and X0 are available, which can avoid some copies.
598 if (!(CurrentPhysLiveRegs & LiveRegs::NZCV))
599 BestCandidate = {I, CurrentPhysLiveRegs};
600 if (CurrentPhysLiveRegs == LiveRegs::None)
601 break;
602 }
603 return BestCandidate;
604}
605
606void MachineSMEABI::insertStateChanges(EmitContext &Context,
607 const FunctionInfo &FnInfo,
608 const EdgeBundles &Bundles,
609 ArrayRef<ZAState> BundleStates) {
610 for (MachineBasicBlock &MBB : *MF) {
611 const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
612 ZAState InState = BundleStates[Bundles.getBundle(MBB.getNumber(),
613 /*Out=*/false)];
614
615 ZAState CurrentState = Block.FixedEntryState;
616 if (CurrentState == ZAState::ANY)
617 CurrentState = InState;
618
619 for (auto &Inst : Block.Insts) {
620 if (CurrentState != Inst.NeededState) {
621 auto [InsertPt, PhysLiveRegs] =
622 findStateChangeInsertionPoint(MBB, Block, &Inst);
623 emitStateChange(Context, MBB, InsertPt, CurrentState, Inst.NeededState,
624 PhysLiveRegs);
625 CurrentState = Inst.NeededState;
626 }
627 }
628
629 if (MBB.succ_empty())
630 continue;
631
632 ZAState OutState =
633 BundleStates[Bundles.getBundle(MBB.getNumber(), /*Out=*/true)];
634 if (CurrentState != OutState) {
635 auto [InsertPt, PhysLiveRegs] =
636 findStateChangeInsertionPoint(MBB, Block, Block.Insts.end());
637 emitStateChange(Context, MBB, InsertPt, CurrentState, OutState,
638 PhysLiveRegs);
639 }
640 }
641}
642
645 if (MBB.empty())
646 return DebugLoc();
647 return MBBI != MBB.end() ? MBBI->getDebugLoc() : MBB.back().getDebugLoc();
648}
649
650/// Finds the first call (as determined by MachineInstr::isCall()) starting from
651/// \p MBBI in \p MBB marked with \p Marker (which is a marker opcode such as
652/// RequiresZASavePseudo). If a marked call is found, it is pushed to \p Calls
653/// and the function returns true.
654static bool findMarkedCall(const MachineBasicBlock &MBB,
657 unsigned Marker, unsigned CallDestroyOpcode) {
658 auto IsMarker = [&](auto &MI) { return MI.getOpcode() == Marker; };
659 auto MarkerInst = std::find_if(MBBI, MBB.end(), IsMarker);
660 if (MarkerInst == MBB.end())
661 return false;
663 while (++I != MBB.end()) {
664 if (I->isCall() || I->getOpcode() == CallDestroyOpcode)
665 break;
666 }
667 if (I != MBB.end() && I->isCall())
668 Calls.push_back(&*I);
669 // Note: This function always returns true if a "Marker" was found.
670 return true;
671}
672
673void MachineSMEABI::collectReachableMarkedCalls(
674 const MachineBasicBlock &StartMBB,
676 SmallVectorImpl<const MachineInstr *> &Calls, unsigned Marker) const {
677 assert(Marker == AArch64::InOutZAUsePseudo ||
678 Marker == AArch64::RequiresZASavePseudo ||
679 Marker == AArch64::RequiresZT0SavePseudo);
680 unsigned CallDestroyOpcode = TII->getCallFrameDestroyOpcode();
681 if (findMarkedCall(StartMBB, StartInst, Calls, Marker, CallDestroyOpcode))
682 return;
683
686 StartMBB.succ_rend());
687 while (!Worklist.empty()) {
688 const MachineBasicBlock *MBB = Worklist.pop_back_val();
689 auto [_, Inserted] = Visited.insert(MBB);
690 if (!Inserted)
691 continue;
692
693 if (!findMarkedCall(*MBB, MBB->begin(), Calls, Marker, CallDestroyOpcode))
694 Worklist.append(MBB->succ_rbegin(), MBB->succ_rend());
695 }
696}
697
698static StringRef getCalleeName(const MachineInstr &CallInst) {
699 assert(CallInst.isCall() && "expected a call");
700 for (const MachineOperand &MO : CallInst.operands()) {
701 if (MO.isSymbol())
702 return MO.getSymbolName();
703 if (MO.isGlobal())
704 return MO.getGlobal()->getName();
705 }
706 return {};
707}
708
709void MachineSMEABI::emitCallSaveRemarks(const MachineBasicBlock &MBB,
711 DebugLoc DL, unsigned Marker,
712 StringRef RemarkName,
713 StringRef SaveName) const {
714 auto SaveRemark = [&](DebugLoc DL, const MachineBasicBlock &MBB) {
715 return MachineOptimizationRemarkAnalysis("sme", RemarkName, DL, &MBB);
716 };
717 StringRef StateName = Marker == AArch64::RequiresZT0SavePseudo ? "ZT0" : "ZA";
718 ORE->emit([&] {
719 return SaveRemark(DL, MBB) << SaveName << " of " << StateName
720 << " emitted in '" << MF->getName() << "'";
721 });
722 if (!ORE->allowExtraAnalysis("sme"))
723 return;
724 SmallVector<const MachineInstr *> CallsRequiringSaves;
725 collectReachableMarkedCalls(MBB, MBBI, CallsRequiringSaves, Marker);
726 for (const MachineInstr *CallInst : CallsRequiringSaves) {
727 auto R = SaveRemark(CallInst->getDebugLoc(), *CallInst->getParent());
728 R << "call";
729 if (StringRef CalleeName = getCalleeName(*CallInst); !CalleeName.empty())
730 R << " to '" << CalleeName << "'";
731 R << " requires " << StateName << " save";
732 ORE->emit(R);
733 }
734}
735
736void MachineSMEABI::emitSetupLazySave(EmitContext &Context,
740
741 emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZASavePseudo,
742 "SMELazySaveZA", "lazy save");
743
744 // Get pointer to TPIDR2 block.
745 Register TPIDR2 = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
746 Register TPIDR2Ptr = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
747 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
748 .addFrameIndex(Context.getTPIDR2Block(*MF))
749 .addImm(0)
750 .addImm(0);
751 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), TPIDR2Ptr)
752 .addReg(TPIDR2);
753 // Set TPIDR2_EL0 to point to TPIDR2 block.
754 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
755 .addImm(AArch64SysReg::TPIDR2_EL0)
756 .addReg(TPIDR2Ptr);
757}
758
759PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs,
762 DebugLoc DL) {
763 PhysRegSave RegSave{PhysLiveRegs};
764 if (PhysLiveRegs & LiveRegs::NZCV) {
765 RegSave.StatusFlags = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
766 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), RegSave.StatusFlags)
767 .addImm(AArch64SysReg::NZCV)
768 .addReg(AArch64::NZCV, RegState::Implicit);
769 }
770 // Note: Preserving X0 is "free" as this is before register allocation, so
771 // the register allocator is still able to optimize these copies.
772 if (PhysLiveRegs & LiveRegs::W0) {
773 RegSave.X0Save = MRI->createVirtualRegister(PhysLiveRegs & LiveRegs::W0_HI
774 ? &AArch64::GPR64RegClass
775 : &AArch64::GPR32RegClass);
776 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), RegSave.X0Save)
777 .addReg(PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0);
778 }
779 return RegSave;
780}
781
782void MachineSMEABI::restorePhyRegSave(const PhysRegSave &RegSave,
785 DebugLoc DL) {
786 if (RegSave.StatusFlags != AArch64::NoRegister)
787 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
788 .addImm(AArch64SysReg::NZCV)
789 .addReg(RegSave.StatusFlags)
790 .addReg(AArch64::NZCV, RegState::ImplicitDefine);
791
792 if (RegSave.X0Save != AArch64::NoRegister)
793 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY),
794 RegSave.PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0)
795 .addReg(RegSave.X0Save);
796}
797
798void MachineSMEABI::addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
799 CallingConv::ID ExpectedCC) {
800 RTLIB::LibcallImpl LCImpl = LLI->getLibcallImpl(LC);
801 if (LCImpl == RTLIB::Unsupported)
802 emitError("cannot lower SME ABI (SME routines unsupported)");
805 if (CC != ExpectedCC)
806 emitError("invalid calling convention for SME routine: '" + ImplName + "'");
807 // FIXME: This assumes the ImplName StringRef is null-terminated.
808 MIB.addExternalSymbol(ImplName.data());
809 MIB.addRegMask(TRI->getCallPreservedMask(*MF, CC));
810}
811
812void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
815 LiveRegs PhysLiveRegs) {
817 Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
818 Register TPIDR2 = AArch64::X0;
819
820 // TODO: Emit these within the restore MBB to prevent unnecessary saves.
821 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
822
823 // Enable ZA.
824 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
825 .addImm(AArch64SVCR::SVCRZA)
826 .addImm(1);
827 // Get current TPIDR2_EL0.
828 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), TPIDR2EL0)
829 .addImm(AArch64SysReg::TPIDR2_EL0);
830 // Get pointer to TPIDR2 block.
831 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
832 .addFrameIndex(Context.getTPIDR2Block(*MF))
833 .addImm(0)
834 .addImm(0);
835 // (Conditionally) restore ZA state.
836 auto RestoreZA = BuildMI(MBB, MBBI, DL, TII->get(AArch64::RestoreZAPseudo))
837 .addReg(TPIDR2EL0)
838 .addReg(TPIDR2);
839 addSMELibCall(
840 RestoreZA, RTLIB::SMEABI_TPIDR2_RESTORE,
842 // Zero TPIDR2_EL0.
843 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
844 .addImm(AArch64SysReg::TPIDR2_EL0)
845 .addReg(AArch64::XZR);
846
847 restorePhyRegSave(RegSave, MBB, MBBI, DL);
848}
849
850void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
852 bool ClearTPIDR2, bool On) {
854
855 if (ClearTPIDR2)
856 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
857 .addImm(AArch64SysReg::TPIDR2_EL0)
858 .addReg(AArch64::XZR);
859
860 // Enable or disable ZA, depending on \p On.
861 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
862 .addImm(AArch64SVCR::SVCRZA)
863 .addImm(On ? 1 : 0);
864}
865
/// Emits, before MBBI, the code that allocates the lazy-save buffer (unless
/// AFI already reports an early-allocated one) and fills in the TPIDR2 block.
///
/// Emitted sequence:
///  1. RDSVLI_XI with immediate 1 into a fresh virtual register (SVL).
///  2. Only when no early buffer exists: copy SP to a virtual register, compute
///     the buffer address with MSUBXrrr (operands SVL, SVL, SP), copy it back
///     into SP, and record a variable sized object in the frame info.
///  3. STPXi of (Buffer, SVL) at offset 0 of the TPIDR2 block frame index
///     obtained from Context.
///
/// NOTE(review): source lines 868, 870 and 908 are missing from this listing
/// (the numbering skips them), so the rest of the parameter list, the DL
/// definition and the statement that owns the big-endian message string are
/// not visible -- confirm against the upstream file.
866void MachineSMEABI::emitAllocateLazySaveBuffer(
867 EmitContext &Context, MachineBasicBlock &MBB,
869 MachineFrameInfo &MFI = MF->getFrameInfo();
871 Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
872 Register SVL = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
873 Register Buffer = AFI->getEarlyAllocSMESaveBuffer();
874
875 // Calculate SVL.
876 BuildMI(MBB, MBBI, DL, TII->get(AArch64::RDSVLI_XI), SVL).addImm(1);
877
878 // 1. Allocate the lazy save buffer.
879 if (Buffer == AArch64::NoRegister) {
880 // TODO: On Windows, we allocate the lazy save buffer in SelectionDAG (so
881 // Buffer != AArch64::NoRegister). This is done to reuse the existing
882 // expansions (which can insert stack checks). This works, but it means we
883 // will always allocate the lazy save buffer (even if the function contains
884 // no lazy saves). If we want to handle Windows here, we'll need to
885 // implement something similar to LowerWindowsDYNAMIC_STACKALLOC.
886 assert(!Subtarget->isTargetWindows() &&
887 "Lazy ZA save is not yet supported on Windows");
888 Buffer = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
889 // Get original stack pointer.
890 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), SP)
891 .addReg(AArch64::SP);
892 // Allocate the lazy-save buffer below the original stack pointer. The size
// is SVL * SVL bytes, i.e. Buffer = SP - SVL * SVL.
893 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSUBXrrr), Buffer)
894 .addReg(SVL)
895 .addReg(SVL)
896 .addReg(SP);
897 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), AArch64::SP)
898 .addReg(Buffer);
899 // We have just allocated a variable sized object, tell this to PEI.
900 MFI.CreateVariableSizedObject(Align(16), nullptr);
901 }
902
903 // 2. Setup the TPIDR2 block.
904 {
905 // Note: This case just needs to do `SVL << 48`. It is not implemented as we
906 // generally don't support big-endian SVE/SME.
907 if (!Subtarget->isLittleEndian())
909 "TPIDR2 block initialization is not supported on big-endian targets");
910
911 // Store buffer pointer and num_za_save_slices.
912 // Bytes 10-15 are implicitly zeroed.
913 BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi))
914 .addReg(Buffer)
915 .addReg(SVL)
916 .addFrameIndex(Context.getTPIDR2Block(*MF))
917 .addImm(0);
918 }
919}
920
// Mask with the low eight bits set (0xFF).
// NOTE(review): presumably the tile-mask operand of ZERO_M that selects all of
// ZA; the use site is not visible in this listing -- confirm.
921static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
922
/// Emits the SME prologue before MBBI in MBB.
///
/// ZeroZA / ZeroZT0 are taken from the function's "new ZA" / "new ZT0"
/// attributes. Two cases are visible:
///  * First branch: read TPIDR2_EL0 with MRS, emit CommitZASavePseudo taking
///    that value plus the ZeroZA and ZeroZT0 flags, attach the
///    SMEABI_TPIDR2_SAVE libcall to it, mark ZAB0 / ZT0 as implicitly defined
///    when they are zeroed, then enable ZA.
///  * Shared-ZA interface: emit ZERO_M and/or ZERO_T as requested by the flags.
///
/// NOTE(review): source lines 924, 925, 929, 944 and 956 are missing from this
/// listing. The remaining parameters, the DL definition, the `if` that opens
/// the first branch, the tail of the addSMELibCall call and one ZERO_M operand
/// are therefore not visible -- confirm against the upstream file.
923void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
926
927 bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
928 bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
930 // Get current TPIDR2_EL0.
931 Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
932 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
933 .addReg(TPIDR2EL0, RegState::Define)
934 .addImm(AArch64SysReg::TPIDR2_EL0);
935 // If TPIDR2_EL0 is non-zero, commit the lazy save.
936 // NOTE: Functions that only use ZT0 don't need to zero ZA.
937 auto CommitZASave =
938 BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
939 .addReg(TPIDR2EL0)
940 .addImm(ZeroZA)
941 .addImm(ZeroZT0);
942 addSMELibCall(
943 CommitZASave, RTLIB::SMEABI_TPIDR2_SAVE,
// The pseudo zeroes ZA / ZT0 when requested, so record them as implicit defs.
945 if (ZeroZA)
946 CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
947 if (ZeroZT0)
948 CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine);
949 // Enable ZA (as ZA could have previously been in the OFF state).
950 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
951 .addImm(AArch64SVCR::SVCRZA)
952 .addImm(1);
953 } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
954 if (ZeroZA)
955 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M))
957 .addDef(AArch64::ZAB0, RegState::ImplicitDefine);
958 if (ZeroZT0)
959 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0);
960 }
961}
962
/// Emits a call to the SME full save or restore routine before MBBI.
///
/// The agnostic-ZA buffer pointer from Context is copied into X0 and a BL is
/// emitted with X0 as an implicit use. The libcall attached to the BL is
/// SMEABI_SME_SAVE when IsSave is true and SMEABI_SME_RESTORE otherwise.
/// Registers described by PhysLiveRegs are preserved around the sequence via
/// createPhysRegSave / restorePhyRegSave. When saving, the "SMEFullZASave"
/// call-save remarks are emitted as well.
///
/// NOTE(review): source lines 964, 965, 967 and 986 are missing from this
/// listing, so the MBB / MBBI parameters, the DL definition and the tail of
/// the addSMELibCall call are not visible -- confirm against the upstream file.
963void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
966 LiveRegs PhysLiveRegs, bool IsSave) {
968
969 if (IsSave)
970 emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZASavePseudo,
971 "SMEFullZASave", "full save");
972
973 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
974
975 // Copy the buffer pointer into X0.
976 Register BufferPtr = AArch64::X0;
977 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
978 .addReg(Context.getAgnosticZABufferPtr(*MF));
979
980 // Call __arm_sme_save/__arm_sme_restore.
981 auto SaveRestoreZA = BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
982 .addReg(BufferPtr, RegState::Implicit);
983 addSMELibCall(
984 SaveRestoreZA,
985 IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE,
987
988 restorePhyRegSave(RegSave, MBB, MBBI, DL);
989}
990
/// Emits a spill of ZT0 to, or a reload of ZT0 from, the ZT0 save slot
/// provided by Context.
///
/// The slot address is materialized with ADDXri on the slot's frame index.
/// STR_TX is then emitted when IsSave is true, LDR_TX otherwise. When saving,
/// the "SMEZT0Save" call-save remarks are emitted first.
///
/// NOTE(review): source lines 992, 993 and 995 are missing from this listing,
/// so the MBB / MBBI parameters and the DL definition are not visible --
/// confirm against the upstream file.
991void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
994 bool IsSave) {
996
997 // Note: This will report calls that _only_ need ZT0 saved. Calls that save
998 // both ZA and ZT0 will be under the SMELazySaveZA remark. This prevents
999 // reporting the same calls twice.
1000 if (IsSave)
1001 emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZT0SavePseudo,
1002 "SMEZT0Save", "spill");
1003
1004 Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
1005
// ZT0Save = address of the ZT0 save slot (frame index + 0).
1006 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save)
1007 .addFrameIndex(Context.getZT0SaveSlot(*MF))
1008 .addImm(0)
1009 .addImm(0);
1010
1011 if (IsSave) {
1012 BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX))
1013 .addReg(AArch64::ZT0)
1014 .addReg(ZT0Save);
1015 } else {
1016 BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0)
1017 .addReg(ZT0Save);
1018 }
1019}
1020
/// Emits, before MBBI, the allocation of the stack buffer used for full ZA
/// saves. Does nothing when AFI reports an early-allocated SME save buffer.
///
/// Emitted sequence: a BL carrying the SMEABI_SME_STATE_SIZE libcall (X0 is an
/// implicit def), a copy of X0 into BufferSize, a SUBXrx64 that lowers SP by
/// BufferSize, and a copy of the new SP into the agnostic-ZA buffer pointer
/// register from Context. A variable sized object is then recorded in the
/// frame info. Registers described by PhysLiveRegs are preserved around the
/// sequence via createPhysRegSave / restorePhyRegSave.
///
/// NOTE(review): source lines 1023, 1028, 1040 and 1051 are missing from this
/// listing, so the remaining parameters, the DL definition, the tail of the
/// addSMELibCall call and the last SUBXrx64 operand are not visible -- confirm
/// against the upstream file.
1021void MachineSMEABI::emitAllocateFullZASaveBuffer(
1022 EmitContext &Context, MachineBasicBlock &MBB,
1024 // Buffer already allocated in SelectionDAG.
1025 if (AFI->getEarlyAllocSMESaveBuffer())
1026 return;
1027
1029 Register BufferPtr = Context.getAgnosticZABufferPtr(*MF);
1030 Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
1031
1032 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
1033
1034 // Calculate the SME state size.
1035 {
1036 auto SMEStateSize = BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
1037 .addReg(AArch64::X0, RegState::ImplicitDefine);
1038 addSMELibCall(
1039 SMEStateSize, RTLIB::SMEABI_SME_STATE_SIZE,
1041 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferSize)
1042 .addReg(AArch64::X0);
1043 }
1044
1045 // Allocate a buffer object of the size given by __arm_sme_state_size.
1046 {
1047 MachineFrameInfo &MFI = MF->getFrameInfo();
1048 BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBXrx64), AArch64::SP)
1049 .addReg(AArch64::SP)
1050 .addReg(BufferSize)
1052 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
1053 .addReg(AArch64::SP);
1054
1055 // We have just allocated a variable sized object, tell this to PEI.
1056 MFI.CreateVariableSizedObject(Align(16), nullptr);
1057 }
1058
1059 restorePhyRegSave(RegSave, MBB, MBBI, DL);
1060}
1061
1062struct FromState {
1063 ZAState From;
1064
1065 constexpr uint8_t to(ZAState To) const {
1066 static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
1067 return uint8_t(From) << 4 | uint8_t(To);
1068 }
1069};
1070
1071constexpr FromState transitionFrom(ZAState From) { return FromState{From}; }
1072
/// Emits the code required to move ZA from state From to state To at InsertPt
/// in MBB.
///
///  * Nothing is emitted when either state is ANY, or for ENTRY -> OFF.
///  * Leaving ENTRY emits the SME prologue at the first non-PHI instruction of
///    MBB (asserted to be the function's entry block). Unless To is ACTIVE,
///    handling then continues as if From were ACTIVE.
///  * Every other pair is dispatched through transitionFrom(From).to(To). A
///    pair without a case prints the transition to dbgs() and reaches
///    llvm_unreachable.
///
/// PhysLiveRegs is forwarded to the ZA save / restore helpers.
///
/// NOTE(review): source lines 1074-1075 are missing from this listing, so the
/// declarations of the MBB and InsertPt parameters are not visible -- confirm
/// against the upstream file.
1073void MachineSMEABI::emitStateChange(EmitContext &Context,
1076 ZAState From, ZAState To,
1077 LiveRegs PhysLiveRegs) {
1078 // ZA not used.
1079 if (From == ZAState::ANY || To == ZAState::ANY)
1080 return;
1081
1082 // If we're exiting from the ENTRY state that means that the function has not
1083 // used ZA, so in the case of private ZA/ZT0 functions we can omit any set up.
1084 if (From == ZAState::ENTRY && To == ZAState::OFF)
1085 return;
1086
1087 // TODO: Avoid setting up the save buffer if there's no transition to
1088 // LOCAL_SAVED.
1089 if (From == ZAState::ENTRY) {
1090 assert(&MBB == &MBB.getParent()->front() &&
1091 "ENTRY state only valid in entry block");
1092 emitSMEPrologue(MBB, MBB.getFirstNonPHI());
1093 if (To == ZAState::ACTIVE)
1094 return; // Nothing more to do (ZA is active after the prologue).
1095
1096 // Note: "emitSMEPrologue" may zero ZA, so we may need to setup a lazy save
1097 // if "To" is "ZAState::LOCAL_SAVED". It may be possible to improve this
1098 // case by changing the placement of the zero instruction.
1099 From = ZAState::ACTIVE;
1100 }
1101
1102 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
1103 bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
1104 bool HasZT0State = SMEFnAttrs.hasZT0State();
1105 bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();
1106
1107 switch (transitionFrom(From).to(To)) {
1108 // This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
1109 case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED):
1110 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1111 break;
1112 case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE):
1113 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1114 break;
1115
1116 // This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED
1117 case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED):
1118 case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::LOCAL_SAVED):
// ZT0 only needs saving here when coming from ACTIVE; in ACTIVE_ZT0_SAVED it
// has already been saved.
1119 if (HasZT0State && From == ZAState::ACTIVE)
1120 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1121 if (HasZAState)
1122 emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
1123 break;
1124
1125 // This section handles: ACTIVE -> LOCAL_COMMITTED
1126 case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED):
1127 // TODO: We could support ZA state here, but this transition is currently
1128 // only possible when we _don't_ have ZA state.
1129 assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
1130 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1131 emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
1132 break;
1133
1134 // This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
1135 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF):
1136 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED):
1137 // These transitions are a no-op.
1138 break;
1139
1140 // This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
1141 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE):
1142 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED):
1143 case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE):
1144 case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE_ZT0_SAVED):
// With ZA state, restore it; otherwise just turn ZA back on. ZT0 is reloaded
// only when the target state is ACTIVE (not ACTIVE_ZT0_SAVED).
1145 if (HasZAState)
1146 emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
1147 else
1148 emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
1149 if (HasZT0State && To == ZAState::ACTIVE)
1150 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1151 break;
1152
1153 // This section handles transitions to OFF (not previously covered)
1154 case transitionFrom(ZAState::ACTIVE).to(ZAState::OFF):
1155 case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::OFF):
1156 case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::OFF):
1157 assert(SMEFnAttrs.hasPrivateZAInterface() &&
1158 "Did not expect to turn ZA off in shared/agnostic ZA function");
// TPIDR2_EL0 is cleared only when leaving LOCAL_SAVED.
1159 emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
1160 /*On=*/false);
1161 break;
1162
1163 default:
1164 dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
1165 << getZAStateString(To) << '\n';
1166 llvm_unreachable("Unimplemented state transition");
1167 }
1168}
1169
1170/// Returns true if private ZA setup can be elided. This occurs when there is
1171/// no instruction within the function that requires ZA to be active.
1172static bool canElidePrivateZASetup(const FunctionInfo &FnInfo) {
1173 for (const BlockInfo &BlockInfo : FnInfo.Blocks) {
1174 for (const InstInfo &InstInfo : BlockInfo.Insts) {
1175 if (InstInfo.NeededState == ZAState::ACTIVE ||
1176 InstInfo.NeededState == ZAState::ACTIVE_ZT0_SAVED)
1177 return false;
1178 }
1179 }
1180 return true;
1181}
1182
1183} // end anonymous namespace
1184
// Registers the MachineSMEABI pass via INITIALIZE_PASS with the argument
// string "aarch64-machine-sme-abi" and the name "Machine SME ABI". The two
// trailing `false` values are the macro's `cfg` and `analysis` parameters.
1185INITIALIZE_PASS(MachineSMEABI, "aarch64-machine-sme-abi", "Machine SME ABI",
1186 false, false)
1187
1188bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
1189 AFI = MF.getInfo<AArch64FunctionInfo>();
1190 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
1191 if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
1192 !SMEFnAttrs.hasAgnosticZAInterface())
1193 return false;
1194
1195 Subtarget = &MF.getSubtarget<AArch64Subtarget>();
1196 if (!Subtarget->hasSME() && !SMEFnAttrs.hasAgnosticZAInterface())
1197 return false;
1198
1199 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
1200
1201 this->MF = &MF;
1202 ORE = &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE();
1203 LLI = &getAnalysis<LibcallLoweringInfoWrapper>().getLibcallLowering(
1204 *MF.getFunction().getParent(), *Subtarget);
1205 TII = Subtarget->getInstrInfo();
1206 TRI = Subtarget->getRegisterInfo();
1207 MRI = &MF.getRegInfo();
1208
1209 const EdgeBundles &Bundles =
1210 getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
1211
1212 FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
1213
1214 if (SMEFnAttrs.hasPrivateZAInterface() && canElidePrivateZASetup(FnInfo))
1215 return false;
1216
1217 SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);
1218
1219 EmitContext Context;
1220 insertStateChanges(Context, FnInfo, Bundles, BundleStates);
1221
1222 if (Context.needsSaveBuffer()) {
1223 if (FnInfo.AfterSMEProloguePt) {
1224 // Note: With inline stack probes the AfterSMEProloguePt may not be in the
1225 // entry block (due to the probing loop).
1226 MachineBasicBlock::iterator MBBI = *FnInfo.AfterSMEProloguePt;
1227 emitAllocateZASaveBuffer(Context, *MBBI->getParent(), MBBI,
1228 FnInfo.PhysLiveRegsAfterSMEPrologue);
1229 } else {
1230 MachineBasicBlock &EntryBlock = MF.front();
1231 emitAllocateZASaveBuffer(
1232 Context, EntryBlock, EntryBlock.getFirstNonPHI(),
1233 FnInfo.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
1234 }
1235 }
1236
1237 return true;
1238}
1239
// Factory for the pass: returns a newly allocated MachineSMEABI constructed
// with OptLevel.
// NOTE(review): the signature (source line 1240) is missing from this listing;
// the cross-reference index further down declares it as
// `FunctionPass *createMachineSMEABIPass(CodeGenOptLevel)` -- confirm against
// the upstream file.
1241 return new MachineSMEABI(OptLevel);
1242}
static constexpr unsigned ZERO_ALL_ZA_MASK
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
MachineBasicBlock MachineBasicBlock::iterator MBBI
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
const HexagonInstrInfo * TII
#define _
IRTranslator LLVM IR MI
This file implements the LivePhysRegs utility for tracking liveness of physical registers.
#define ENTRY(ASMNAME, ENUM)
#define I(x, y, z)
Definition MD5.cpp:57
static DebugLoc getDebugLoc(MachineBasicBlock::instr_iterator FirstMI, MachineBasicBlock::instr_iterator LastMI)
Return the first DebugLoc that has line number information, given a range of instructions.
===- MachineOptimizationRemarkEmitter.h - Opt Diagnostics -*- C++ -*-—===//
#define MAKE_CASE(V)
Register const TargetRegisterInfo * TRI
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
This file defines the SmallVector class.
AArch64FunctionInfo - This class is derived from MachineFunctionInfo and contains private AArch64-spe...
Represent the analysis usage information of a pass.
AnalysisUsage & addPreservedID(const void *ID)
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
This class represents a function call, abstracting a target machine's calling convention.
A debug info location.
Definition DebugLoc.h:123
ArrayRef< unsigned > getBlocks(unsigned Bundle) const
getBlocks - Return an array of blocks that are connected to Bundle.
Definition EdgeBundles.h:53
unsigned getBundle(unsigned N, bool Out) const
getBundle - Return the ingoing (Out = false) or outgoing (Out = true) bundle number for basic block N
Definition EdgeBundles.h:47
unsigned getNumBundles() const
getNumBundles - Return the total number of bundles in the CFG.
Definition EdgeBundles.h:50
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:358
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location inf...
Tracks which library functions to use for a particular subtarget.
LLVM_ABI CallingConv::ID getLibcallImplCallingConv(RTLIB::LibcallImpl Call) const
Get the CallingConv that should be used for the specified libcall.
LLVM_ABI RTLIB::LibcallImpl getLibcallImpl(RTLIB::Libcall Call) const
Return the lowering's selection of implementation call for Call.
A set of register units used to track register liveness.
bool available(MCRegister Reg) const
Returns true if no part of physical register Reg is live.
void addReg(MCRegister Reg)
Adds register units covered by physical register Reg.
LLVM_ABI void stepBackward(const MachineInstr &MI)
Updates liveness when stepping backwards over the instruction MI.
LLVM_ABI void addLiveOuts(const MachineBasicBlock &MBB)
Adds registers living out of block MBB.
MachineInstrBundleIterator< const MachineInstr > const_iterator
int getNumber() const
MachineBasicBlocks are uniquely numbered at the function level, unless they're not in a MachineFuncti...
LLVM_ABI iterator getFirstNonPHI()
Returns a pointer to the first instruction in this block that is not a PHINode instruction.
succ_reverse_iterator succ_rbegin()
MachineInstrBundleIterator< MachineInstr > iterator
succ_reverse_iterator succ_rend()
The MachineFrameInfo class represents an abstract stack frame until prolog/epilog code is inserted.
LLVM_ABI int CreateStackObject(uint64_t Size, Align Alignment, bool isSpillSlot, const AllocaInst *Alloca=nullptr, uint8_t ID=0)
Create a new statically sized stack object, returning a nonnegative identifier to represent it.
LLVM_ABI int CreateSpillStackObject(uint64_t Size, Align Alignment)
Create a new statically sized stack object that represents a spill slot, returning a nonnegative iden...
LLVM_ABI int CreateVariableSizedObject(Align Alignment, const AllocaInst *Alloca)
Notify the MachineFrameInfo object that a variable sized object has been created.
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
StringRef getName() const
getName - Return the name of the corresponding LLVM function.
MachineFrameInfo & getFrameInfo()
getFrameInfo - Return the frame info object for the current function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
unsigned getNumBlockIDs() const
getNumBlockIDs - Return the number of MBB ID's allocated.
Ty * getInfo()
getInfo - Keep track of various per-function pieces of information for backends that would like to do...
const MachineInstrBuilder & addExternalSymbol(const char *FnName, unsigned TargetFlags=0) const
const MachineInstrBuilder & addReg(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a new virtual register operand.
const MachineInstrBuilder & addImm(int64_t Val) const
Add a new immediate operand.
const MachineInstrBuilder & addFrameIndex(int Idx) const
const MachineInstrBuilder & addRegMask(const uint32_t *Mask) const
const MachineInstrBuilder & addDef(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
MachineOperand class - Representation of each machine instruction operand.
const GlobalValue * getGlobal() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
bool isSymbol() const
isSymbol - Tests if this is a MO_ExternalSymbol operand.
bool isGlobal() const
isGlobal - Tests if this is a MO_GlobalAddress operand.
const char * getSymbolName() const
Register getReg() const
getReg - Returns the register number.
Diagnostic information for optimization analysis remarks.
LLVM_ABI void emit(DiagnosticInfoOptimizationBase &OptDiag)
Emit an optimization remark.
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to be more informative.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
LLVM_ABI Register createVirtualRegister(const TargetRegisterClass *RegClass, StringRef Name="")
createVirtualRegister - Create and return a new virtual register in the function with the specified r...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
Definition Register.h:83
SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
bool hasAgnosticZAInterface() const
bool hasPrivateZAInterface() const
bool hasSharedZAInterface() const
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
typename SuperClass::const_iterator const_iterator
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Represent a constant reference to a string, i.e.
Definition StringRef.h:56
constexpr bool empty() const
Check if the string is empty.
Definition StringRef.h:141
constexpr const char * data() const
Get a pointer to the start of the string (which may not be null terminated).
Definition StringRef.h:138
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition Twine.h:82
op_range operands()
Definition User.h:267
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
const ParentTy * getParent() const
Definition ilist_node.h:34
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
static unsigned getArithExtendImm(AArch64_AM::ShiftExtendType ET, unsigned Imm)
getArithExtendImm - Encode the extend type and shift amount for an arithmetic instruction: imm: 3-bit...
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0
Preserve X0-X13, X19-X29, SP, Z0-Z31, P0-P15.
@ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1
Preserve X1-X15, X19-X29, SP, Z0-Z31, P0-P15.
This is an optimization pass for GlobalISel generic memory operations.
MachineInstrBuilder BuildMI(MachineFunction &MF, const MIMetadata &MIMD, const MCInstrDesc &MCID)
Builder interface. Specify how to create the initial instruction itself.
@ Implicit
Not emitted register (e.g. carry, or temporary result).
@ Define
Register definition.
FunctionPass * createMachineSMEABIPass(CodeGenOptLevel)
LLVM_ABI char & MachineDominatorsID
MachineDominators - This pass is a machine dominators analysis pass.
LLVM_ABI void reportFatalInternalError(Error Err)
Report a fatal error that indicates a bug in LLVM.
Definition Error.cpp:173
LLVM_ABI char & MachineLoopInfoID
MachineLoopInfo - This pass is a loop analysis pass.
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
CodeGenOptLevel
Code generation optimization level.
Definition CodeGen.h:82
@ Default
-O2, -Os, -Oz
Definition CodeGen.h:85
uint16_t MCPhysReg
An unsigned integer type large enough to represent all physical registers, but not necessarily virtua...
Definition MCRegister.h:21
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
static StringRef getLibcallImplName(RTLIB::LibcallImpl CallImpl)
Get the libcall routine name for the specified libcall implementation.