1 /* SPDX-License-Identifier: GPL-2.0-only */ 2 3 #ifndef __NET_PSP_HELPERS_H 4 #define __NET_PSP_HELPERS_H 5 6 #include <linux/skbuff.h> 7 #include <linux/rcupdate.h> 8 #include <linux/udp.h> 9 #include <net/sock.h> 10 #include <net/tcp.h> 11 #include <net/psp/types.h> 12 13 struct inet_timewait_sock; 14 15 /* Driver-facing API */ 16 struct psp_dev * 17 psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops, 18 struct psp_dev_caps *psd_caps, void *priv_ptr); 19 void psp_dev_unregister(struct psp_dev *psd); 20 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 21 u8 ver, __be16 sport); 22 int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv); 23 24 /* Kernel-facing API */ 25 void psp_assoc_put(struct psp_assoc *pas); 26 27 static inline void *psp_assoc_drv_data(struct psp_assoc *pas) 28 { 29 return pas->drv_data; 30 } 31 32 #if IS_ENABLED(CONFIG_INET_PSP) 33 unsigned int psp_key_size(u32 version); 34 void psp_sk_assoc_free(struct sock *sk); 35 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk); 36 void psp_twsk_assoc_free(struct inet_timewait_sock *tw); 37 void psp_reply_set_decrypted(struct sk_buff *skb); 38 39 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 40 { 41 return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk)); 42 } 43 44 static inline void 45 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) 46 { 47 struct psp_assoc *pas; 48 49 pas = psp_sk_assoc(sk); 50 if (pas && pas->tx.spi) 51 skb->decrypted = 1; 52 } 53 54 static inline unsigned long 55 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 56 unsigned long diffs) 57 { 58 struct psp_skb_ext *a, *b; 59 60 a = skb_ext_find(one, SKB_EXT_PSP); 61 b = skb_ext_find(two, SKB_EXT_PSP); 62 63 diffs |= (!!a) ^ (!!b); 64 if (!diffs && unlikely(a)) 65 diffs |= memcmp(a, b, sizeof(*a)); 66 return diffs; 67 } 68 69 static inline bool 70 psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas) 71 { 72 bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN); 73 u32 end_seq = TCP_SKB_CB(skb)->end_seq; 74 u32 seq = TCP_SKB_CB(skb)->seq; 75 bool pure_fin; 76 77 pure_fin = fin && end_seq - seq == 1; 78 79 return seq == end_seq || (pure_fin && seq == pas->upgrade_seq); 80 } 81 82 static inline bool 83 psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas) 84 { 85 return pse && pas->rx.spi == pse->spi && 86 pas->generation == pse->generation && 87 pas->version == pse->version && 88 pas->dev_id == pse->dev_id; 89 } 90 91 static inline enum skb_drop_reason 92 __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas) 93 { 94 struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP); 95 96 if (!pas) 97 return pse ? SKB_DROP_REASON_PSP_INPUT : 0; 98 99 if (likely(psp_pse_matches_pas(pse, pas))) { 100 if (unlikely(!pas->peer_tx)) 101 pas->peer_tx = 1; 102 103 return 0; 104 } 105 106 if (!pse) { 107 if (!pas->tx.spi || 108 (!pas->peer_tx && psp_is_allowed_nondata(skb, pas))) 109 return 0; 110 } 111 112 return SKB_DROP_REASON_PSP_INPUT; 113 } 114 115 static inline enum skb_drop_reason 116 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 117 { 118 return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk)); 119 } 120 121 static inline enum skb_drop_reason 122 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 123 { 124 return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc)); 125 } 126 127 static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk) 128 { 129 struct psp_assoc *pas; 130 int state; 131 132 state = READ_ONCE(sk->sk_state); 133 if (!sk_is_inet(sk) || state == TCP_NEW_SYN_RECV) 134 return NULL; 135 136 pas = state == TCP_TIME_WAIT ? 137 rcu_dereference(inet_twsk(sk)->psp_assoc) : 138 rcu_dereference(sk->psp_assoc); 139 return pas; 140 } 141 142 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 143 { 144 if (!skb->decrypted || !skb->sk) 145 return NULL; 146 147 return psp_sk_get_assoc_rcu(skb->sk); 148 } 149 150 static inline unsigned int psp_sk_overhead(const struct sock *sk) 151 { 152 int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE; 153 bool has_psp = rcu_access_pointer(sk->psp_assoc); 154 155 return has_psp ? psp_encap : 0; 156 } 157 #else 158 static inline void psp_sk_assoc_free(struct sock *sk) { } 159 static inline void 160 psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { } 161 static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { } 162 static inline void 163 psp_reply_set_decrypted(struct sk_buff *skb) { } 164 165 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 166 { 167 return NULL; 168 } 169 170 static inline void 171 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { } 172 173 static inline unsigned long 174 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 175 unsigned long diffs) 176 { 177 return diffs; 178 } 179 180 static inline enum skb_drop_reason 181 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 182 { 183 return 0; 184 } 185 186 static inline enum skb_drop_reason 187 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 188 { 189 return 0; 190 } 191 192 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 193 { 194 return NULL; 195 } 196 197 static inline unsigned int psp_sk_overhead(const struct sock *sk) 198 { 199 return 0; 200 } 201 #endif 202 203 static inline unsigned long 204 psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two) 205 { 206 return __psp_skb_coalesce_diff(one, two, 0); 207 } 208 209 #endif /* __NET_PSP_HELPERS_H */ 210