1 /* SPDX-License-Identifier: GPL-2.0-only */ 2 3 #ifndef __NET_PSP_HELPERS_H 4 #define __NET_PSP_HELPERS_H 5 6 #include <linux/skbuff.h> 7 #include <linux/rcupdate.h> 8 #include <net/sock.h> 9 #include <net/tcp.h> 10 #include <net/psp/types.h> 11 12 struct inet_timewait_sock; 13 14 /* Driver-facing API */ 15 struct psp_dev * 16 psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops, 17 struct psp_dev_caps *psd_caps, void *priv_ptr); 18 void psp_dev_unregister(struct psp_dev *psd); 19 20 /* Kernel-facing API */ 21 void psp_assoc_put(struct psp_assoc *pas); 22 23 static inline void *psp_assoc_drv_data(struct psp_assoc *pas) 24 { 25 return pas->drv_data; 26 } 27 28 #if IS_ENABLED(CONFIG_INET_PSP) 29 unsigned int psp_key_size(u32 version); 30 void psp_sk_assoc_free(struct sock *sk); 31 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk); 32 void psp_twsk_assoc_free(struct inet_timewait_sock *tw); 33 void psp_reply_set_decrypted(struct sk_buff *skb); 34 35 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 36 { 37 return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk)); 38 } 39 40 static inline void 41 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) 42 { 43 struct psp_assoc *pas; 44 45 pas = psp_sk_assoc(sk); 46 if (pas && pas->tx.spi) 47 skb->decrypted = 1; 48 } 49 50 static inline unsigned long 51 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 52 unsigned long diffs) 53 { 54 struct psp_skb_ext *a, *b; 55 56 a = skb_ext_find(one, SKB_EXT_PSP); 57 b = skb_ext_find(two, SKB_EXT_PSP); 58 59 diffs |= (!!a) ^ (!!b); 60 if (!diffs && unlikely(a)) 61 diffs |= memcmp(a, b, sizeof(*a)); 62 return diffs; 63 } 64 65 static inline bool 66 psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas) 67 { 68 bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN); 69 u32 end_seq = TCP_SKB_CB(skb)->end_seq; 70 u32 seq = TCP_SKB_CB(skb)->seq; 71 bool pure_fin; 72 73 pure_fin = fin && end_seq - seq == 1; 74 75 return seq == end_seq || (pure_fin && seq == pas->upgrade_seq); 76 } 77 78 static inline bool 79 psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas) 80 { 81 return pse && pas->rx.spi == pse->spi && 82 pas->generation == pse->generation && 83 pas->version == pse->version && 84 pas->dev_id == pse->dev_id; 85 } 86 87 static inline enum skb_drop_reason 88 __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas) 89 { 90 struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP); 91 92 if (!pas) 93 return pse ? SKB_DROP_REASON_PSP_INPUT : 0; 94 95 if (likely(psp_pse_matches_pas(pse, pas))) { 96 if (unlikely(!pas->peer_tx)) 97 pas->peer_tx = 1; 98 99 return 0; 100 } 101 102 if (!pse) { 103 if (!pas->tx.spi || 104 (!pas->peer_tx && psp_is_allowed_nondata(skb, pas))) 105 return 0; 106 } 107 108 return SKB_DROP_REASON_PSP_INPUT; 109 } 110 111 static inline enum skb_drop_reason 112 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 113 { 114 return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk)); 115 } 116 117 static inline enum skb_drop_reason 118 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 119 { 120 return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc)); 121 } 122 123 static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk) 124 { 125 struct inet_timewait_sock *tw; 126 struct psp_assoc *pas; 127 int state; 128 129 state = 1 << READ_ONCE(sk->sk_state); 130 if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV) 131 return NULL; 132 133 tw = inet_twsk(sk); 134 pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) : 135 rcu_dereference(sk->psp_assoc); 136 return pas; 137 } 138 139 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 140 { 141 if (!skb->decrypted || !skb->sk) 142 return NULL; 143 144 return psp_sk_get_assoc_rcu(skb->sk); 145 } 146 #else 147 static inline void psp_sk_assoc_free(struct sock *sk) { } 148 static inline void 149 psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { } 150 static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { } 151 static inline void 152 psp_reply_set_decrypted(struct sk_buff *skb) { } 153 154 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 155 { 156 return NULL; 157 } 158 159 static inline void 160 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { } 161 162 static inline unsigned long 163 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 164 unsigned long diffs) 165 { 166 return diffs; 167 } 168 169 static inline enum skb_drop_reason 170 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 171 { 172 return 0; 173 } 174 175 static inline enum skb_drop_reason 176 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 177 { 178 return 0; 179 } 180 181 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 182 { 183 return NULL; 184 } 185 #endif 186 187 static inline unsigned long 188 psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two) 189 { 190 return __psp_skb_coalesce_diff(one, two, 0); 191 } 192 193 #endif /* __NET_PSP_HELPERS_H */ 194