1 /* SPDX-License-Identifier: GPL-2.0-only */ 2 3 #ifndef __NET_PSP_HELPERS_H 4 #define __NET_PSP_HELPERS_H 5 6 #include <linux/skbuff.h> 7 #include <linux/rcupdate.h> 8 #include <linux/udp.h> 9 #include <net/sock.h> 10 #include <net/tcp.h> 11 #include <net/psp/types.h> 12 13 struct inet_timewait_sock; 14 15 /* Driver-facing API */ 16 struct psp_dev * 17 psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops, 18 struct psp_dev_caps *psd_caps, void *priv_ptr); 19 void psp_dev_unregister(struct psp_dev *psd); 20 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 21 u8 ver, __be16 sport); 22 int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv); 23 24 /* Kernel-facing API */ 25 void psp_assoc_put(struct psp_assoc *pas); 26 27 static inline void *psp_assoc_drv_data(struct psp_assoc *pas) 28 { 29 return pas->drv_data; 30 } 31 32 #if IS_ENABLED(CONFIG_INET_PSP) 33 unsigned int psp_key_size(u32 version); 34 void psp_sk_assoc_free(struct sock *sk); 35 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk); 36 void psp_twsk_assoc_free(struct inet_timewait_sock *tw); 37 void psp_reply_set_decrypted(struct sk_buff *skb); 38 39 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 40 { 41 return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk)); 42 } 43 44 static inline void 45 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) 46 { 47 struct psp_assoc *pas; 48 49 pas = psp_sk_assoc(sk); 50 if (pas && pas->tx.spi) 51 skb->decrypted = 1; 52 } 53 54 static inline unsigned long 55 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 56 unsigned long diffs) 57 { 58 struct psp_skb_ext *a, *b; 59 60 a = skb_ext_find(one, SKB_EXT_PSP); 61 b = skb_ext_find(two, SKB_EXT_PSP); 62 63 diffs |= (!!a) ^ (!!b); 64 if (!diffs && unlikely(a)) 65 diffs |= memcmp(a, b, sizeof(*a)); 66 return diffs; 67 } 68 69 static inline bool 70 psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas) 71 { 72 bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN); 73 u32 end_seq = TCP_SKB_CB(skb)->end_seq; 74 u32 seq = TCP_SKB_CB(skb)->seq; 75 bool pure_fin; 76 77 pure_fin = fin && end_seq - seq == 1; 78 79 return seq == end_seq || (pure_fin && seq == pas->upgrade_seq); 80 } 81 82 static inline bool 83 psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas) 84 { 85 return pse && pas->rx.spi == pse->spi && 86 pas->generation == pse->generation && 87 pas->version == pse->version && 88 pas->dev_id == pse->dev_id; 89 } 90 91 static inline enum skb_drop_reason 92 __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas) 93 { 94 struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP); 95 96 if (!pas) 97 return pse ? SKB_DROP_REASON_PSP_INPUT : 0; 98 99 if (likely(psp_pse_matches_pas(pse, pas))) { 100 if (unlikely(!pas->peer_tx)) 101 pas->peer_tx = 1; 102 103 return 0; 104 } 105 106 if (!pse) { 107 if (!pas->tx.spi || 108 (!pas->peer_tx && psp_is_allowed_nondata(skb, pas))) 109 return 0; 110 } 111 112 return SKB_DROP_REASON_PSP_INPUT; 113 } 114 115 static inline enum skb_drop_reason 116 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 117 { 118 return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk)); 119 } 120 121 static inline enum skb_drop_reason 122 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 123 { 124 return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc)); 125 } 126 127 static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk) 128 { 129 struct inet_timewait_sock *tw; 130 struct psp_assoc *pas; 131 int state; 132 133 state = 1 << READ_ONCE(sk->sk_state); 134 if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV) 135 return NULL; 136 137 tw = inet_twsk(sk); 138 pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) : 139 rcu_dereference(sk->psp_assoc); 140 return pas; 141 } 142 143 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 144 { 145 if (!skb->decrypted || !skb->sk) 146 return NULL; 147 148 return psp_sk_get_assoc_rcu(skb->sk); 149 } 150 151 static inline unsigned int psp_sk_overhead(const struct sock *sk) 152 { 153 int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE; 154 bool has_psp = rcu_access_pointer(sk->psp_assoc); 155 156 return has_psp ? psp_encap : 0; 157 } 158 #else 159 static inline void psp_sk_assoc_free(struct sock *sk) { } 160 static inline void 161 psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { } 162 static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { } 163 static inline void 164 psp_reply_set_decrypted(struct sk_buff *skb) { } 165 166 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 167 { 168 return NULL; 169 } 170 171 static inline void 172 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { } 173 174 static inline unsigned long 175 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 176 unsigned long diffs) 177 { 178 return diffs; 179 } 180 181 static inline enum skb_drop_reason 182 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 183 { 184 return 0; 185 } 186 187 static inline enum skb_drop_reason 188 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 189 { 190 return 0; 191 } 192 193 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 194 { 195 return NULL; 196 } 197 198 static inline unsigned int psp_sk_overhead(const struct sock *sk) 199 { 200 return 0; 201 } 202 #endif 203 204 static inline unsigned long 205 psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two) 206 { 207 return __psp_skb_coalesce_diff(one, two, 0); 208 } 209 210 #endif /* __NET_PSP_HELPERS_H */ 211