1 /* SPDX-License-Identifier: GPL-2.0-only */ 2 3 #ifndef __NET_PSP_HELPERS_H 4 #define __NET_PSP_HELPERS_H 5 6 #include <linux/skbuff.h> 7 #include <linux/rcupdate.h> 8 #include <linux/udp.h> 9 #include <net/sock.h> 10 #include <net/tcp.h> 11 #include <net/psp/types.h> 12 13 struct inet_timewait_sock; 14 15 /* Driver-facing API */ 16 struct psp_dev * 17 psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops, 18 struct psp_dev_caps *psd_caps, void *priv_ptr); 19 void psp_dev_unregister(struct psp_dev *psd); 20 21 /* Kernel-facing API */ 22 void psp_assoc_put(struct psp_assoc *pas); 23 24 static inline void *psp_assoc_drv_data(struct psp_assoc *pas) 25 { 26 return pas->drv_data; 27 } 28 29 #if IS_ENABLED(CONFIG_INET_PSP) 30 unsigned int psp_key_size(u32 version); 31 void psp_sk_assoc_free(struct sock *sk); 32 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk); 33 void psp_twsk_assoc_free(struct inet_timewait_sock *tw); 34 void psp_reply_set_decrypted(struct sk_buff *skb); 35 36 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 37 { 38 return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk)); 39 } 40 41 static inline void 42 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) 43 { 44 struct psp_assoc *pas; 45 46 pas = psp_sk_assoc(sk); 47 if (pas && pas->tx.spi) 48 skb->decrypted = 1; 49 } 50 51 static inline unsigned long 52 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 53 unsigned long diffs) 54 { 55 struct psp_skb_ext *a, *b; 56 57 a = skb_ext_find(one, SKB_EXT_PSP); 58 b = skb_ext_find(two, SKB_EXT_PSP); 59 60 diffs |= (!!a) ^ (!!b); 61 if (!diffs && unlikely(a)) 62 diffs |= memcmp(a, b, sizeof(*a)); 63 return diffs; 64 } 65 66 static inline bool 67 psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas) 68 { 69 bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN); 70 u32 end_seq = TCP_SKB_CB(skb)->end_seq; 71 u32 seq = TCP_SKB_CB(skb)->seq; 72 bool pure_fin; 73 74 pure_fin = fin && end_seq - seq == 1; 75 76 return seq == end_seq || (pure_fin && seq == pas->upgrade_seq); 77 } 78 79 static inline bool 80 psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas) 81 { 82 return pse && pas->rx.spi == pse->spi && 83 pas->generation == pse->generation && 84 pas->version == pse->version && 85 pas->dev_id == pse->dev_id; 86 } 87 88 static inline enum skb_drop_reason 89 __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas) 90 { 91 struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP); 92 93 if (!pas) 94 return pse ? SKB_DROP_REASON_PSP_INPUT : 0; 95 96 if (likely(psp_pse_matches_pas(pse, pas))) { 97 if (unlikely(!pas->peer_tx)) 98 pas->peer_tx = 1; 99 100 return 0; 101 } 102 103 if (!pse) { 104 if (!pas->tx.spi || 105 (!pas->peer_tx && psp_is_allowed_nondata(skb, pas))) 106 return 0; 107 } 108 109 return SKB_DROP_REASON_PSP_INPUT; 110 } 111 112 static inline enum skb_drop_reason 113 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 114 { 115 return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk)); 116 } 117 118 static inline enum skb_drop_reason 119 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 120 { 121 return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc)); 122 } 123 124 static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk) 125 { 126 struct inet_timewait_sock *tw; 127 struct psp_assoc *pas; 128 int state; 129 130 state = 1 << READ_ONCE(sk->sk_state); 131 if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV) 132 return NULL; 133 134 tw = inet_twsk(sk); 135 pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) : 136 rcu_dereference(sk->psp_assoc); 137 return pas; 138 } 139 140 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 141 { 142 if (!skb->decrypted || !skb->sk) 143 return NULL; 144 145 return psp_sk_get_assoc_rcu(skb->sk); 146 } 147 148 static inline unsigned int psp_sk_overhead(const struct sock *sk) 149 { 150 int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE; 151 bool has_psp = rcu_access_pointer(sk->psp_assoc); 152 153 return has_psp ? psp_encap : 0; 154 } 155 #else 156 static inline void psp_sk_assoc_free(struct sock *sk) { } 157 static inline void 158 psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { } 159 static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { } 160 static inline void 161 psp_reply_set_decrypted(struct sk_buff *skb) { } 162 163 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 164 { 165 return NULL; 166 } 167 168 static inline void 169 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { } 170 171 static inline unsigned long 172 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 173 unsigned long diffs) 174 { 175 return diffs; 176 } 177 178 static inline enum skb_drop_reason 179 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 180 { 181 return 0; 182 } 183 184 static inline enum skb_drop_reason 185 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 186 { 187 return 0; 188 } 189 190 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 191 { 192 return NULL; 193 } 194 195 static inline unsigned int psp_sk_overhead(const struct sock *sk) 196 { 197 return 0; 198 } 199 #endif 200 201 static inline unsigned long 202 psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two) 203 { 204 return __psp_skb_coalesce_diff(one, two, 0); 205 } 206 207 #endif /* __NET_PSP_HELPERS_H */ 208