xref: /linux/include/net/psp/functions.h (revision 28bb24dadd0ed70aed43cf9af3a54c22c3ce04b2)
100c94ca2SJakub Kicinski /* SPDX-License-Identifier: GPL-2.0-only */
200c94ca2SJakub Kicinski 
300c94ca2SJakub Kicinski #ifndef __NET_PSP_HELPERS_H
400c94ca2SJakub Kicinski #define __NET_PSP_HELPERS_H
500c94ca2SJakub Kicinski 
6659a2899SJakub Kicinski #include <linux/skbuff.h>
76b46ca26SJakub Kicinski #include <linux/rcupdate.h>
8e9726925SJakub Kicinski #include <linux/udp.h>
9659a2899SJakub Kicinski #include <net/sock.h>
106b46ca26SJakub Kicinski #include <net/tcp.h>
1100c94ca2SJakub Kicinski #include <net/psp/types.h>
1200c94ca2SJakub Kicinski 
13ed8a507bSJakub Kicinski struct inet_timewait_sock;
14ed8a507bSJakub Kicinski 
1500c94ca2SJakub Kicinski /* Driver-facing API */
1600c94ca2SJakub Kicinski struct psp_dev *
1700c94ca2SJakub Kicinski psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
1800c94ca2SJakub Kicinski 	       struct psp_dev_caps *psd_caps, void *priv_ptr);
1900c94ca2SJakub Kicinski void psp_dev_unregister(struct psp_dev *psd);
20fc724515SRaed Salem bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
21fc724515SRaed Salem 			 u8 ver, __be16 sport);
220eddb802SRaed Salem int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv);
2300c94ca2SJakub Kicinski 
24ed8a507bSJakub Kicinski /* Kernel-facing API */
256b46ca26SJakub Kicinski void psp_assoc_put(struct psp_assoc *pas);
266b46ca26SJakub Kicinski 
276b46ca26SJakub Kicinski static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
286b46ca26SJakub Kicinski {
296b46ca26SJakub Kicinski 	return pas->drv_data;
306b46ca26SJakub Kicinski }
316b46ca26SJakub Kicinski 
32659a2899SJakub Kicinski #if IS_ENABLED(CONFIG_INET_PSP)
336b46ca26SJakub Kicinski unsigned int psp_key_size(u32 version);
346b46ca26SJakub Kicinski void psp_sk_assoc_free(struct sock *sk);
356b46ca26SJakub Kicinski void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
366b46ca26SJakub Kicinski void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
376b46ca26SJakub Kicinski void psp_reply_set_decrypted(struct sk_buff *skb);
386b46ca26SJakub Kicinski 
396b46ca26SJakub Kicinski static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
406b46ca26SJakub Kicinski {
416b46ca26SJakub Kicinski 	return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
426b46ca26SJakub Kicinski }
43659a2899SJakub Kicinski 
44659a2899SJakub Kicinski static inline void
45659a2899SJakub Kicinski psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
46659a2899SJakub Kicinski {
476b46ca26SJakub Kicinski 	struct psp_assoc *pas;
486b46ca26SJakub Kicinski 
496b46ca26SJakub Kicinski 	pas = psp_sk_assoc(sk);
506b46ca26SJakub Kicinski 	if (pas && pas->tx.spi)
516b46ca26SJakub Kicinski 		skb->decrypted = 1;
52659a2899SJakub Kicinski }
53659a2899SJakub Kicinski 
54659a2899SJakub Kicinski static inline unsigned long
55659a2899SJakub Kicinski __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
56659a2899SJakub Kicinski 			unsigned long diffs)
57659a2899SJakub Kicinski {
586b46ca26SJakub Kicinski 	struct psp_skb_ext *a, *b;
596b46ca26SJakub Kicinski 
606b46ca26SJakub Kicinski 	a = skb_ext_find(one, SKB_EXT_PSP);
616b46ca26SJakub Kicinski 	b = skb_ext_find(two, SKB_EXT_PSP);
626b46ca26SJakub Kicinski 
636b46ca26SJakub Kicinski 	diffs |= (!!a) ^ (!!b);
646b46ca26SJakub Kicinski 	if (!diffs && unlikely(a))
656b46ca26SJakub Kicinski 		diffs |= memcmp(a, b, sizeof(*a));
66659a2899SJakub Kicinski 	return diffs;
67659a2899SJakub Kicinski }
68659a2899SJakub Kicinski 
696b46ca26SJakub Kicinski static inline bool
706b46ca26SJakub Kicinski psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
716b46ca26SJakub Kicinski {
726b46ca26SJakub Kicinski 	bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
736b46ca26SJakub Kicinski 	u32 end_seq = TCP_SKB_CB(skb)->end_seq;
746b46ca26SJakub Kicinski 	u32 seq = TCP_SKB_CB(skb)->seq;
756b46ca26SJakub Kicinski 	bool pure_fin;
766b46ca26SJakub Kicinski 
776b46ca26SJakub Kicinski 	pure_fin = fin && end_seq - seq == 1;
786b46ca26SJakub Kicinski 
796b46ca26SJakub Kicinski 	return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
806b46ca26SJakub Kicinski }
816b46ca26SJakub Kicinski 
826b46ca26SJakub Kicinski static inline bool
836b46ca26SJakub Kicinski psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
846b46ca26SJakub Kicinski {
856b46ca26SJakub Kicinski 	return pse && pas->rx.spi == pse->spi &&
866b46ca26SJakub Kicinski 	       pas->generation == pse->generation &&
876b46ca26SJakub Kicinski 	       pas->version == pse->version &&
886b46ca26SJakub Kicinski 	       pas->dev_id == pse->dev_id;
896b46ca26SJakub Kicinski }
906b46ca26SJakub Kicinski 
916b46ca26SJakub Kicinski static inline enum skb_drop_reason
926b46ca26SJakub Kicinski __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
936b46ca26SJakub Kicinski {
946b46ca26SJakub Kicinski 	struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
956b46ca26SJakub Kicinski 
966b46ca26SJakub Kicinski 	if (!pas)
976b46ca26SJakub Kicinski 		return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
986b46ca26SJakub Kicinski 
996b46ca26SJakub Kicinski 	if (likely(psp_pse_matches_pas(pse, pas))) {
1006b46ca26SJakub Kicinski 		if (unlikely(!pas->peer_tx))
1016b46ca26SJakub Kicinski 			pas->peer_tx = 1;
1026b46ca26SJakub Kicinski 
1036b46ca26SJakub Kicinski 		return 0;
1046b46ca26SJakub Kicinski 	}
1056b46ca26SJakub Kicinski 
1066b46ca26SJakub Kicinski 	if (!pse) {
1076b46ca26SJakub Kicinski 		if (!pas->tx.spi ||
1086b46ca26SJakub Kicinski 		    (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
1096b46ca26SJakub Kicinski 			return 0;
1106b46ca26SJakub Kicinski 	}
1116b46ca26SJakub Kicinski 
1126b46ca26SJakub Kicinski 	return SKB_DROP_REASON_PSP_INPUT;
1136b46ca26SJakub Kicinski }
1146b46ca26SJakub Kicinski 
115659a2899SJakub Kicinski static inline enum skb_drop_reason
116659a2899SJakub Kicinski psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
117659a2899SJakub Kicinski {
1186b46ca26SJakub Kicinski 	return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
119659a2899SJakub Kicinski }
120659a2899SJakub Kicinski 
121659a2899SJakub Kicinski static inline enum skb_drop_reason
122659a2899SJakub Kicinski psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
123659a2899SJakub Kicinski {
1246b46ca26SJakub Kicinski 	return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
1256b46ca26SJakub Kicinski }
1266b46ca26SJakub Kicinski 
127f8d2f820SDaniel Zahka static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk)
1286b46ca26SJakub Kicinski {
1296b46ca26SJakub Kicinski 	struct psp_assoc *pas;
1306b46ca26SJakub Kicinski 	int state;
1316b46ca26SJakub Kicinski 
132*28bb24daSDaniel Zahka 	state = READ_ONCE(sk->sk_state);
133*28bb24daSDaniel Zahka 	if (!sk_is_inet(sk) || state == TCP_NEW_SYN_RECV)
1346b46ca26SJakub Kicinski 		return NULL;
1356b46ca26SJakub Kicinski 
136*28bb24daSDaniel Zahka 	pas = state == TCP_TIME_WAIT ?
137803cdb6dSDaniel Zahka 		      rcu_dereference(inet_twsk(sk)->psp_assoc) :
1386b46ca26SJakub Kicinski 		      rcu_dereference(sk->psp_assoc);
1396b46ca26SJakub Kicinski 	return pas;
140659a2899SJakub Kicinski }
141659a2899SJakub Kicinski 
142659a2899SJakub Kicinski static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
143659a2899SJakub Kicinski {
1446b46ca26SJakub Kicinski 	if (!skb->decrypted || !skb->sk)
145659a2899SJakub Kicinski 		return NULL;
1466b46ca26SJakub Kicinski 
1476b46ca26SJakub Kicinski 	return psp_sk_get_assoc_rcu(skb->sk);
148659a2899SJakub Kicinski }
149e9726925SJakub Kicinski 
150e9726925SJakub Kicinski static inline unsigned int psp_sk_overhead(const struct sock *sk)
151e9726925SJakub Kicinski {
152e9726925SJakub Kicinski 	int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE;
153e9726925SJakub Kicinski 	bool has_psp = rcu_access_pointer(sk->psp_assoc);
154e9726925SJakub Kicinski 
155e9726925SJakub Kicinski 	return has_psp ? psp_encap : 0;
156e9726925SJakub Kicinski }
157659a2899SJakub Kicinski #else
/* CONFIG_INET_PSP disabled: lifecycle hooks compile to no-ops. */
static inline void psp_sk_assoc_free(struct sock *sk)
{
}

static inline void psp_twsk_init(struct inet_timewait_sock *tw,
				 const struct sock *sk)
{
}

static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
{
}

static inline void psp_reply_set_decrypted(struct sk_buff *skb)
{
}
164659a2899SJakub Kicinski 
1656b46ca26SJakub Kicinski static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
1666b46ca26SJakub Kicinski {
1676b46ca26SJakub Kicinski 	return NULL;
1686b46ca26SJakub Kicinski }
1696b46ca26SJakub Kicinski 
170659a2899SJakub Kicinski static inline void
171659a2899SJakub Kicinski psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { }
172659a2899SJakub Kicinski 
173659a2899SJakub Kicinski static inline unsigned long
174659a2899SJakub Kicinski __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
175659a2899SJakub Kicinski 			unsigned long diffs)
176659a2899SJakub Kicinski {
177659a2899SJakub Kicinski 	return diffs;
178659a2899SJakub Kicinski }
179659a2899SJakub Kicinski 
180659a2899SJakub Kicinski static inline enum skb_drop_reason
181659a2899SJakub Kicinski psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
182659a2899SJakub Kicinski {
183659a2899SJakub Kicinski 	return 0;
184659a2899SJakub Kicinski }
185659a2899SJakub Kicinski 
186659a2899SJakub Kicinski static inline enum skb_drop_reason
187659a2899SJakub Kicinski psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
188659a2899SJakub Kicinski {
189659a2899SJakub Kicinski 	return 0;
190659a2899SJakub Kicinski }
191659a2899SJakub Kicinski 
192659a2899SJakub Kicinski static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
193659a2899SJakub Kicinski {
194659a2899SJakub Kicinski 	return NULL;
195659a2899SJakub Kicinski }
196e9726925SJakub Kicinski 
197e9726925SJakub Kicinski static inline unsigned int psp_sk_overhead(const struct sock *sk)
198e9726925SJakub Kicinski {
199e9726925SJakub Kicinski 	return 0;
200e9726925SJakub Kicinski }
201659a2899SJakub Kicinski #endif
202659a2899SJakub Kicinski 
/*
 * Return non-zero if @one and @two differ in their PSP skb extension.
 * The comparison starts from a zero accumulator. With CONFIG_INET_PSP
 * disabled, the helper it calls always yields 0.
 */
static inline unsigned long
psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
{
	unsigned long no_prior_diffs = 0;

	return __psp_skb_coalesce_diff(one, two, no_prior_diffs);
}
208ed8a507bSJakub Kicinski 
20900c94ca2SJakub Kicinski #endif /* __NET_PSP_HELPERS_H */
210