xref: /linux/include/net/psp/functions.h (revision 6b46ca260e2290e3453d1355ab5b6d283d73d780)
100c94ca2SJakub Kicinski /* SPDX-License-Identifier: GPL-2.0-only */
200c94ca2SJakub Kicinski 
300c94ca2SJakub Kicinski #ifndef __NET_PSP_HELPERS_H
400c94ca2SJakub Kicinski #define __NET_PSP_HELPERS_H
500c94ca2SJakub Kicinski 
6659a2899SJakub Kicinski #include <linux/skbuff.h>
7*6b46ca26SJakub Kicinski #include <linux/rcupdate.h>
8659a2899SJakub Kicinski #include <net/sock.h>
9*6b46ca26SJakub Kicinski #include <net/tcp.h>
1000c94ca2SJakub Kicinski #include <net/psp/types.h>
1100c94ca2SJakub Kicinski 
struct inet_timewait_sock;

/* Driver-facing API */

/*
 * Create a PSP device for @netdev from the driver's @psd_ops and @psd_caps;
 * @priv_ptr is an untyped pointer handed in by the driver.
 * NOTE(review): the failure convention (NULL vs. ERR_PTR) lives in the
 * implementation, which is not visible here - confirm before relying on it.
 */
struct psp_dev *psp_dev_create(struct net_device *netdev,
			       struct psp_dev_ops *psd_ops,
			       struct psp_dev_caps *psd_caps,
			       void *priv_ptr);

/* Presumably the teardown counterpart of psp_dev_create() - confirm. */
void psp_dev_unregister(struct psp_dev *psd);
1900c94ca2SJakub Kicinski 
20ed8a507bSJakub Kicinski /* Kernel-facing API */
21*6b46ca26SJakub Kicinski void psp_assoc_put(struct psp_assoc *pas);
22*6b46ca26SJakub Kicinski 
23*6b46ca26SJakub Kicinski static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
24*6b46ca26SJakub Kicinski {
25*6b46ca26SJakub Kicinski 	return pas->drv_data;
26*6b46ca26SJakub Kicinski }
27*6b46ca26SJakub Kicinski 
28659a2899SJakub Kicinski #if IS_ENABLED(CONFIG_INET_PSP)
29*6b46ca26SJakub Kicinski unsigned int psp_key_size(u32 version);
30*6b46ca26SJakub Kicinski void psp_sk_assoc_free(struct sock *sk);
31*6b46ca26SJakub Kicinski void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
32*6b46ca26SJakub Kicinski void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
33*6b46ca26SJakub Kicinski void psp_reply_set_decrypted(struct sk_buff *skb);
34*6b46ca26SJakub Kicinski 
35*6b46ca26SJakub Kicinski static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
36*6b46ca26SJakub Kicinski {
37*6b46ca26SJakub Kicinski 	return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
38*6b46ca26SJakub Kicinski }
39659a2899SJakub Kicinski 
40659a2899SJakub Kicinski static inline void
41659a2899SJakub Kicinski psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
42659a2899SJakub Kicinski {
43*6b46ca26SJakub Kicinski 	struct psp_assoc *pas;
44*6b46ca26SJakub Kicinski 
45*6b46ca26SJakub Kicinski 	pas = psp_sk_assoc(sk);
46*6b46ca26SJakub Kicinski 	if (pas && pas->tx.spi)
47*6b46ca26SJakub Kicinski 		skb->decrypted = 1;
48659a2899SJakub Kicinski }
49659a2899SJakub Kicinski 
50659a2899SJakub Kicinski static inline unsigned long
51659a2899SJakub Kicinski __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
52659a2899SJakub Kicinski 			unsigned long diffs)
53659a2899SJakub Kicinski {
54*6b46ca26SJakub Kicinski 	struct psp_skb_ext *a, *b;
55*6b46ca26SJakub Kicinski 
56*6b46ca26SJakub Kicinski 	a = skb_ext_find(one, SKB_EXT_PSP);
57*6b46ca26SJakub Kicinski 	b = skb_ext_find(two, SKB_EXT_PSP);
58*6b46ca26SJakub Kicinski 
59*6b46ca26SJakub Kicinski 	diffs |= (!!a) ^ (!!b);
60*6b46ca26SJakub Kicinski 	if (!diffs && unlikely(a))
61*6b46ca26SJakub Kicinski 		diffs |= memcmp(a, b, sizeof(*a));
62659a2899SJakub Kicinski 	return diffs;
63659a2899SJakub Kicinski }
64659a2899SJakub Kicinski 
65*6b46ca26SJakub Kicinski static inline bool
66*6b46ca26SJakub Kicinski psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
67*6b46ca26SJakub Kicinski {
68*6b46ca26SJakub Kicinski 	bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
69*6b46ca26SJakub Kicinski 	u32 end_seq = TCP_SKB_CB(skb)->end_seq;
70*6b46ca26SJakub Kicinski 	u32 seq = TCP_SKB_CB(skb)->seq;
71*6b46ca26SJakub Kicinski 	bool pure_fin;
72*6b46ca26SJakub Kicinski 
73*6b46ca26SJakub Kicinski 	pure_fin = fin && end_seq - seq == 1;
74*6b46ca26SJakub Kicinski 
75*6b46ca26SJakub Kicinski 	return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
76*6b46ca26SJakub Kicinski }
77*6b46ca26SJakub Kicinski 
78*6b46ca26SJakub Kicinski static inline bool
79*6b46ca26SJakub Kicinski psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
80*6b46ca26SJakub Kicinski {
81*6b46ca26SJakub Kicinski 	return pse && pas->rx.spi == pse->spi &&
82*6b46ca26SJakub Kicinski 	       pas->generation == pse->generation &&
83*6b46ca26SJakub Kicinski 	       pas->version == pse->version &&
84*6b46ca26SJakub Kicinski 	       pas->dev_id == pse->dev_id;
85*6b46ca26SJakub Kicinski }
86*6b46ca26SJakub Kicinski 
87*6b46ca26SJakub Kicinski static inline enum skb_drop_reason
88*6b46ca26SJakub Kicinski __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
89*6b46ca26SJakub Kicinski {
90*6b46ca26SJakub Kicinski 	struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
91*6b46ca26SJakub Kicinski 
92*6b46ca26SJakub Kicinski 	if (!pas)
93*6b46ca26SJakub Kicinski 		return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
94*6b46ca26SJakub Kicinski 
95*6b46ca26SJakub Kicinski 	if (likely(psp_pse_matches_pas(pse, pas))) {
96*6b46ca26SJakub Kicinski 		if (unlikely(!pas->peer_tx))
97*6b46ca26SJakub Kicinski 			pas->peer_tx = 1;
98*6b46ca26SJakub Kicinski 
99*6b46ca26SJakub Kicinski 		return 0;
100*6b46ca26SJakub Kicinski 	}
101*6b46ca26SJakub Kicinski 
102*6b46ca26SJakub Kicinski 	if (!pse) {
103*6b46ca26SJakub Kicinski 		if (!pas->tx.spi ||
104*6b46ca26SJakub Kicinski 		    (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
105*6b46ca26SJakub Kicinski 			return 0;
106*6b46ca26SJakub Kicinski 	}
107*6b46ca26SJakub Kicinski 
108*6b46ca26SJakub Kicinski 	return SKB_DROP_REASON_PSP_INPUT;
109*6b46ca26SJakub Kicinski }
110*6b46ca26SJakub Kicinski 
111659a2899SJakub Kicinski static inline enum skb_drop_reason
112659a2899SJakub Kicinski psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
113659a2899SJakub Kicinski {
114*6b46ca26SJakub Kicinski 	return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
115659a2899SJakub Kicinski }
116659a2899SJakub Kicinski 
117659a2899SJakub Kicinski static inline enum skb_drop_reason
118659a2899SJakub Kicinski psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
119659a2899SJakub Kicinski {
120*6b46ca26SJakub Kicinski 	return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
121*6b46ca26SJakub Kicinski }
122*6b46ca26SJakub Kicinski 
123*6b46ca26SJakub Kicinski static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk)
124*6b46ca26SJakub Kicinski {
125*6b46ca26SJakub Kicinski 	struct inet_timewait_sock *tw;
126*6b46ca26SJakub Kicinski 	struct psp_assoc *pas;
127*6b46ca26SJakub Kicinski 	int state;
128*6b46ca26SJakub Kicinski 
129*6b46ca26SJakub Kicinski 	state = 1 << READ_ONCE(sk->sk_state);
130*6b46ca26SJakub Kicinski 	if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV)
131*6b46ca26SJakub Kicinski 		return NULL;
132*6b46ca26SJakub Kicinski 
133*6b46ca26SJakub Kicinski 	tw = inet_twsk(sk);
134*6b46ca26SJakub Kicinski 	pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) :
135*6b46ca26SJakub Kicinski 				       rcu_dereference(sk->psp_assoc);
136*6b46ca26SJakub Kicinski 	return pas;
137659a2899SJakub Kicinski }
138659a2899SJakub Kicinski 
139659a2899SJakub Kicinski static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
140659a2899SJakub Kicinski {
141*6b46ca26SJakub Kicinski 	if (!skb->decrypted || !skb->sk)
142659a2899SJakub Kicinski 		return NULL;
143*6b46ca26SJakub Kicinski 
144*6b46ca26SJakub Kicinski 	return psp_sk_get_assoc_rcu(skb->sk);
145659a2899SJakub Kicinski }
146659a2899SJakub Kicinski #else
/* PSP compiled out: there is no association state to free or hand over. */
static inline void psp_sk_assoc_free(struct sock *sk)
{
}

static inline void psp_twsk_init(struct inet_timewait_sock *tw,
				 const struct sock *sk)
{
}

static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
{
}

static inline void psp_reply_set_decrypted(struct sk_buff *skb)
{
}
153659a2899SJakub Kicinski 
/* Without CONFIG_INET_PSP no socket carries a PSP association. */
static inline struct psp_assoc *
psp_sk_assoc(const struct sock *sk)
{
	return NULL;
}
158*6b46ca26SJakub Kicinski 
/* PSP compiled out: @skb is never marked decrypted. */
static inline void psp_enqueue_set_decrypted(struct sock *sk,
					     struct sk_buff *skb)
{
}
161659a2899SJakub Kicinski 
/* PSP compiled out: no PSP extensions to compare; @diffs passes through. */
static inline unsigned long __psp_skb_coalesce_diff(const struct sk_buff *one,
						    const struct sk_buff *two,
						    unsigned long diffs)
{
	return diffs;
}
168659a2899SJakub Kicinski 
169659a2899SJakub Kicinski static inline enum skb_drop_reason
170659a2899SJakub Kicinski psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
171659a2899SJakub Kicinski {
172659a2899SJakub Kicinski 	return 0;
173659a2899SJakub Kicinski }
174659a2899SJakub Kicinski 
175659a2899SJakub Kicinski static inline enum skb_drop_reason
176659a2899SJakub Kicinski psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
177659a2899SJakub Kicinski {
178659a2899SJakub Kicinski 	return 0;
179659a2899SJakub Kicinski }
180659a2899SJakub Kicinski 
/* PSP compiled out: no skb can be tied to a PSP association. */
static inline struct psp_assoc *
psp_skb_get_assoc_rcu(struct sk_buff *skb)
{
	return NULL;
}
185659a2899SJakub Kicinski #endif
186659a2899SJakub Kicinski 
/*
 * Return non-zero if @one and @two differ in PSP state as judged by
 * __psp_skb_coalesce_diff(), starting from an empty difference mask.
 */
static inline unsigned long
psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
{
	const unsigned long no_prior_diffs = 0;

	return __psp_skb_coalesce_diff(one, two, no_prior_diffs);
}
192ed8a507bSJakub Kicinski 
19300c94ca2SJakub Kicinski #endif /* __NET_PSP_HELPERS_H */
194