1 /* SPDX-License-Identifier: GPL-2.0-only */ 2 3 #ifndef __NET_PSP_HELPERS_H 4 #define __NET_PSP_HELPERS_H 5 6 #include <linux/skbuff.h> 7 #include <linux/rcupdate.h> 8 #include <linux/udp.h> 9 #include <net/sock.h> 10 #include <net/tcp.h> 11 #include <net/psp/types.h> 12 13 struct inet_timewait_sock; 14 15 /* Driver-facing API */ 16 struct psp_dev * 17 psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops, 18 struct psp_dev_caps *psd_caps, void *priv_ptr); 19 void psp_dev_unregister(struct psp_dev *psd); 20 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 21 u8 ver, __be16 sport); 22 23 /* Kernel-facing API */ 24 void psp_assoc_put(struct psp_assoc *pas); 25 26 static inline void *psp_assoc_drv_data(struct psp_assoc *pas) 27 { 28 return pas->drv_data; 29 } 30 31 #if IS_ENABLED(CONFIG_INET_PSP) 32 unsigned int psp_key_size(u32 version); 33 void psp_sk_assoc_free(struct sock *sk); 34 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk); 35 void psp_twsk_assoc_free(struct inet_timewait_sock *tw); 36 void psp_reply_set_decrypted(struct sk_buff *skb); 37 38 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 39 { 40 return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk)); 41 } 42 43 static inline void 44 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) 45 { 46 struct psp_assoc *pas; 47 48 pas = psp_sk_assoc(sk); 49 if (pas && pas->tx.spi) 50 skb->decrypted = 1; 51 } 52 53 static inline unsigned long 54 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 55 unsigned long diffs) 56 { 57 struct psp_skb_ext *a, *b; 58 59 a = skb_ext_find(one, SKB_EXT_PSP); 60 b = skb_ext_find(two, SKB_EXT_PSP); 61 62 diffs |= (!!a) ^ (!!b); 63 if (!diffs && unlikely(a)) 64 diffs |= memcmp(a, b, sizeof(*a)); 65 return diffs; 66 } 67 68 static inline bool 69 psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas) 70 { 71 bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN); 72 u32 end_seq = TCP_SKB_CB(skb)->end_seq; 73 u32 seq = TCP_SKB_CB(skb)->seq; 74 bool pure_fin; 75 76 pure_fin = fin && end_seq - seq == 1; 77 78 return seq == end_seq || (pure_fin && seq == pas->upgrade_seq); 79 } 80 81 static inline bool 82 psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas) 83 { 84 return pse && pas->rx.spi == pse->spi && 85 pas->generation == pse->generation && 86 pas->version == pse->version && 87 pas->dev_id == pse->dev_id; 88 } 89 90 static inline enum skb_drop_reason 91 __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas) 92 { 93 struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP); 94 95 if (!pas) 96 return pse ? SKB_DROP_REASON_PSP_INPUT : 0; 97 98 if (likely(psp_pse_matches_pas(pse, pas))) { 99 if (unlikely(!pas->peer_tx)) 100 pas->peer_tx = 1; 101 102 return 0; 103 } 104 105 if (!pse) { 106 if (!pas->tx.spi || 107 (!pas->peer_tx && psp_is_allowed_nondata(skb, pas))) 108 return 0; 109 } 110 111 return SKB_DROP_REASON_PSP_INPUT; 112 } 113 114 static inline enum skb_drop_reason 115 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 116 { 117 return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk)); 118 } 119 120 static inline enum skb_drop_reason 121 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 122 { 123 return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc)); 124 } 125 126 static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk) 127 { 128 struct inet_timewait_sock *tw; 129 struct psp_assoc *pas; 130 int state; 131 132 state = 1 << READ_ONCE(sk->sk_state); 133 if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV) 134 return NULL; 135 136 tw = inet_twsk(sk); 137 pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) : 138 rcu_dereference(sk->psp_assoc); 139 return pas; 140 } 141 142 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 143 { 144 if (!skb->decrypted || !skb->sk) 145 return NULL; 146 147 return psp_sk_get_assoc_rcu(skb->sk); 148 } 149 150 static inline unsigned int psp_sk_overhead(const struct sock *sk) 151 { 152 int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE; 153 bool has_psp = rcu_access_pointer(sk->psp_assoc); 154 155 return has_psp ? psp_encap : 0; 156 } 157 #else 158 static inline void psp_sk_assoc_free(struct sock *sk) { } 159 static inline void 160 psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { } 161 static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { } 162 static inline void 163 psp_reply_set_decrypted(struct sk_buff *skb) { } 164 165 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 166 { 167 return NULL; 168 } 169 170 static inline void 171 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { } 172 173 static inline unsigned long 174 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 175 unsigned long diffs) 176 { 177 return diffs; 178 } 179 180 static inline enum skb_drop_reason 181 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 182 { 183 return 0; 184 } 185 186 static inline enum skb_drop_reason 187 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 188 { 189 return 0; 190 } 191 192 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 193 { 194 return NULL; 195 } 196 197 static inline unsigned int psp_sk_overhead(const struct sock *sk) 198 { 199 return 0; 200 } 201 #endif 202 203 static inline unsigned long 204 psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two) 205 { 206 return __psp_skb_coalesce_diff(one, two, 0); 207 } 208 209 #endif /* __NET_PSP_HELPERS_H */ 210