xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- AArch64SMEAttributes.h - Helper for interpreting SME attributes -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
10 #define LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
11 
12 #include "llvm/IR/Function.h"
13 
14 namespace llvm {
15 
16 class Function;
17 class CallBase;
18 class AttributeList;
19 
20 /// SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
21 /// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM.
22 class SMEAttrs {
23   unsigned Bitmask = Normal;
24 
25 public:
26   enum class StateValue {
27     None = 0,
28     In = 1,        // aarch64_in_zt0
29     Out = 2,       // aarch64_out_zt0
30     InOut = 3,     // aarch64_inout_zt0
31     Preserved = 4, // aarch64_preserves_zt0
32     New = 5        // aarch64_new_zt0
33   };
34 
35   // Enum with bitmasks for each individual SME feature.
36   enum Mask {
37     Normal = 0,
38     SM_Enabled = 1 << 0,      // aarch64_pstate_sm_enabled
39     SM_Compatible = 1 << 1,   // aarch64_pstate_sm_compatible
40     SM_Body = 1 << 2,         // aarch64_pstate_sm_body
41     SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
42     ZA_State_Agnostic = 1 << 4,
43     ZT0_Undef = 1 << 5, // Use to mark ZT0 as undef to avoid spills
44     ZA_Shift = 6,
45     ZA_Mask = 0b111 << ZA_Shift,
46     ZT0_Shift = 9,
47     ZT0_Mask = 0b111 << ZT0_Shift,
48     CallSiteFlags_Mask = ZT0_Undef
49   };
50 
51   enum class InferAttrsFromName { No, Yes };
52 
53   SMEAttrs() = default;
SMEAttrs(unsigned Mask)54   SMEAttrs(unsigned Mask) { set(Mask); }
55   SMEAttrs(const Function &F, InferAttrsFromName Infer = InferAttrsFromName::No)
56       : SMEAttrs(F.getAttributes()) {
57     if (Infer == InferAttrsFromName::Yes)
58       addKnownFunctionAttrs(F.getName());
59   }
60   SMEAttrs(const AttributeList &L);
SMEAttrs(StringRef FuncName)61   SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); };
62 
63   void set(unsigned M, bool Enable = true);
64 
65   // Interfaces to query PSTATE.SM
hasStreamingBody()66   bool hasStreamingBody() const { return Bitmask & SM_Body; }
hasStreamingInterface()67   bool hasStreamingInterface() const { return Bitmask & SM_Enabled; }
hasStreamingInterfaceOrBody()68   bool hasStreamingInterfaceOrBody() const {
69     return hasStreamingBody() || hasStreamingInterface();
70   }
hasStreamingCompatibleInterface()71   bool hasStreamingCompatibleInterface() const {
72     return Bitmask & SM_Compatible;
73   }
hasNonStreamingInterface()74   bool hasNonStreamingInterface() const {
75     return !hasStreamingInterface() && !hasStreamingCompatibleInterface();
76   }
hasNonStreamingInterfaceAndBody()77   bool hasNonStreamingInterfaceAndBody() const {
78     return hasNonStreamingInterface() && !hasStreamingBody();
79   }
80 
81   // Interfaces to query ZA
decodeZAState(unsigned Bitmask)82   static StateValue decodeZAState(unsigned Bitmask) {
83     return static_cast<StateValue>((Bitmask & ZA_Mask) >> ZA_Shift);
84   }
encodeZAState(StateValue S)85   static unsigned encodeZAState(StateValue S) {
86     return static_cast<unsigned>(S) << ZA_Shift;
87   }
88 
isNewZA()89   bool isNewZA() const { return decodeZAState(Bitmask) == StateValue::New; }
isInZA()90   bool isInZA() const { return decodeZAState(Bitmask) == StateValue::In; }
isOutZA()91   bool isOutZA() const { return decodeZAState(Bitmask) == StateValue::Out; }
isInOutZA()92   bool isInOutZA() const { return decodeZAState(Bitmask) == StateValue::InOut; }
isPreservesZA()93   bool isPreservesZA() const {
94     return decodeZAState(Bitmask) == StateValue::Preserved;
95   }
sharesZA()96   bool sharesZA() const {
97     StateValue State = decodeZAState(Bitmask);
98     return State == StateValue::In || State == StateValue::Out ||
99            State == StateValue::InOut || State == StateValue::Preserved;
100   }
hasAgnosticZAInterface()101   bool hasAgnosticZAInterface() const { return Bitmask & ZA_State_Agnostic; }
hasSharedZAInterface()102   bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); }
hasPrivateZAInterface()103   bool hasPrivateZAInterface() const {
104     return !hasSharedZAInterface() && !hasAgnosticZAInterface();
105   }
hasZAState()106   bool hasZAState() const { return isNewZA() || sharesZA(); }
isSMEABIRoutine()107   bool isSMEABIRoutine() const { return Bitmask & SME_ABI_Routine; }
108 
109   // Interfaces to query ZT0 State
decodeZT0State(unsigned Bitmask)110   static StateValue decodeZT0State(unsigned Bitmask) {
111     return static_cast<StateValue>((Bitmask & ZT0_Mask) >> ZT0_Shift);
112   }
encodeZT0State(StateValue S)113   static unsigned encodeZT0State(StateValue S) {
114     return static_cast<unsigned>(S) << ZT0_Shift;
115   }
116 
isNewZT0()117   bool isNewZT0() const { return decodeZT0State(Bitmask) == StateValue::New; }
isInZT0()118   bool isInZT0() const { return decodeZT0State(Bitmask) == StateValue::In; }
isOutZT0()119   bool isOutZT0() const { return decodeZT0State(Bitmask) == StateValue::Out; }
isInOutZT0()120   bool isInOutZT0() const {
121     return decodeZT0State(Bitmask) == StateValue::InOut;
122   }
isPreservesZT0()123   bool isPreservesZT0() const {
124     return decodeZT0State(Bitmask) == StateValue::Preserved;
125   }
hasUndefZT0()126   bool hasUndefZT0() const { return Bitmask & ZT0_Undef; }
sharesZT0()127   bool sharesZT0() const {
128     StateValue State = decodeZT0State(Bitmask);
129     return State == StateValue::In || State == StateValue::Out ||
130            State == StateValue::InOut || State == StateValue::Preserved;
131   }
hasZT0State()132   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
133 
134   SMEAttrs operator|(SMEAttrs Other) const {
135     SMEAttrs Merged(*this);
136     Merged.set(Other.Bitmask);
137     return Merged;
138   }
139 
withoutPerCallsiteFlags()140   SMEAttrs withoutPerCallsiteFlags() const {
141     return (Bitmask & ~CallSiteFlags_Mask);
142   }
143 
144   bool operator==(SMEAttrs const &Other) const {
145     return Bitmask == Other.Bitmask;
146   }
147 
148 private:
149   void addKnownFunctionAttrs(StringRef FuncName);
150 };
151 
152 /// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has
153 /// interfaces to query whether a streaming mode change or lazy-save mechanism
154 /// is required when going from one function to another (e.g. through a call).
155 class SMECallAttrs {
156   SMEAttrs CallerFn;
157   SMEAttrs CalledFn;
158   SMEAttrs Callsite;
159   bool IsIndirect = false;
160 
161 public:
162   SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee,
163                SMEAttrs Callsite = SMEAttrs::Normal)
CallerFn(Caller)164       : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {}
165 
166   SMECallAttrs(const CallBase &CB);
167 
caller()168   SMEAttrs &caller() { return CallerFn; }
callee()169   SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; }
callsite()170   SMEAttrs &callsite() { return Callsite; }
caller()171   SMEAttrs const &caller() const { return CallerFn; }
callee()172   SMEAttrs const &callee() const {
173     return const_cast<SMECallAttrs *>(this)->callee();
174   }
callsite()175   SMEAttrs const &callsite() const { return Callsite; }
176 
177   /// \return true if a call from Caller -> Callee requires a change in
178   /// streaming mode.
179   bool requiresSMChange() const;
180 
requiresLazySave()181   bool requiresLazySave() const {
182     return caller().hasZAState() && callee().hasPrivateZAInterface() &&
183            !callee().isSMEABIRoutine();
184   }
185 
requiresPreservingZT0()186   bool requiresPreservingZT0() const {
187     return caller().hasZT0State() && !callsite().hasUndefZT0() &&
188            !callee().sharesZT0() && !callee().hasAgnosticZAInterface();
189   }
190 
requiresDisablingZABeforeCall()191   bool requiresDisablingZABeforeCall() const {
192     return caller().hasZT0State() && !caller().hasZAState() &&
193            callee().hasPrivateZAInterface() && !callee().isSMEABIRoutine();
194   }
195 
requiresEnablingZAAfterCall()196   bool requiresEnablingZAAfterCall() const {
197     return requiresLazySave() || requiresDisablingZABeforeCall();
198   }
199 
requiresPreservingAllZAState()200   bool requiresPreservingAllZAState() const {
201     return caller().hasAgnosticZAInterface() &&
202            !callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine();
203   }
204 };
205 
206 } // namespace llvm
207 
208 #endif // LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
209