xref: /linux/include/net/psp/functions.h (revision e97269257fe437910cddc7c642a636ca3cf9fb1d)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 
3 #ifndef __NET_PSP_HELPERS_H
4 #define __NET_PSP_HELPERS_H
5 
6 #include <linux/skbuff.h>
7 #include <linux/rcupdate.h>
8 #include <linux/udp.h>
9 #include <net/sock.h>
10 #include <net/tcp.h>
11 #include <net/psp/types.h>
12 
13 struct inet_timewait_sock;
14 
/* Driver-facing API */

/*
 * psp_dev_create() - create a PSP device on top of @netdev.
 * @netdev:   network device the PSP device is created for
 * @psd_ops:  driver callback table
 * @psd_caps: driver-reported capabilities
 * @priv_ptr: opaque driver pointer
 *
 * NOTE(review): the failure convention (NULL vs ERR_PTR()) is defined by the
 * implementation and is not visible from this header — confirm before use.
 */
struct psp_dev *
psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
	       struct psp_dev_caps *psd_caps, void *priv_ptr);
/*
 * psp_dev_unregister() - counterpart of psp_dev_create() for @psd.
 * NOTE(review): presumably releases/unregisters @psd — verify in psp core.
 */
void psp_dev_unregister(struct psp_dev *psd);
20 
/* Kernel-facing API */

/*
 * psp_assoc_put() - release the caller's hold on @pas.
 * NOTE(review): assumed to drop a reference (refcounted association) —
 * confirm against the implementation.
 */
void psp_assoc_put(struct psp_assoc *pas);
23 
24 static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
25 {
26 	return pas->drv_data;
27 }
28 
29 #if IS_ENABLED(CONFIG_INET_PSP)
/* Presumably the key length for PSP protocol @version — see implementation.
 * Note: unlike the helpers below, this has no stub when CONFIG_INET_PSP is off.
 */
unsigned int psp_key_size(u32 version);
/* Socket / time-wait lifecycle hooks; no-op stubs exist for !CONFIG_INET_PSP. */
void psp_sk_assoc_free(struct sock *sk);
void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
void psp_reply_set_decrypted(struct sk_buff *skb);
35 
36 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
37 {
38 	return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
39 }
40 
41 static inline void
42 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
43 {
44 	struct psp_assoc *pas;
45 
46 	pas = psp_sk_assoc(sk);
47 	if (pas && pas->tx.spi)
48 		skb->decrypted = 1;
49 }
50 
51 static inline unsigned long
52 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
53 			unsigned long diffs)
54 {
55 	struct psp_skb_ext *a, *b;
56 
57 	a = skb_ext_find(one, SKB_EXT_PSP);
58 	b = skb_ext_find(two, SKB_EXT_PSP);
59 
60 	diffs |= (!!a) ^ (!!b);
61 	if (!diffs && unlikely(a))
62 		diffs |= memcmp(a, b, sizeof(*a));
63 	return diffs;
64 }
65 
66 static inline bool
67 psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
68 {
69 	bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
70 	u32 end_seq = TCP_SKB_CB(skb)->end_seq;
71 	u32 seq = TCP_SKB_CB(skb)->seq;
72 	bool pure_fin;
73 
74 	pure_fin = fin && end_seq - seq == 1;
75 
76 	return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
77 }
78 
79 static inline bool
80 psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
81 {
82 	return pse && pas->rx.spi == pse->spi &&
83 	       pas->generation == pse->generation &&
84 	       pas->version == pse->version &&
85 	       pas->dev_id == pse->dev_id;
86 }
87 
88 static inline enum skb_drop_reason
89 __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
90 {
91 	struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
92 
93 	if (!pas)
94 		return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
95 
96 	if (likely(psp_pse_matches_pas(pse, pas))) {
97 		if (unlikely(!pas->peer_tx))
98 			pas->peer_tx = 1;
99 
100 		return 0;
101 	}
102 
103 	if (!pse) {
104 		if (!pas->tx.spi ||
105 		    (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
106 			return 0;
107 	}
108 
109 	return SKB_DROP_REASON_PSP_INPUT;
110 }
111 
112 static inline enum skb_drop_reason
113 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
114 {
115 	return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
116 }
117 
118 static inline enum skb_drop_reason
119 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
120 {
121 	return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
122 }
123 
124 static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk)
125 {
126 	struct inet_timewait_sock *tw;
127 	struct psp_assoc *pas;
128 	int state;
129 
130 	state = 1 << READ_ONCE(sk->sk_state);
131 	if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV)
132 		return NULL;
133 
134 	tw = inet_twsk(sk);
135 	pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) :
136 				       rcu_dereference(sk->psp_assoc);
137 	return pas;
138 }
139 
140 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
141 {
142 	if (!skb->decrypted || !skb->sk)
143 		return NULL;
144 
145 	return psp_sk_get_assoc_rcu(skb->sk);
146 }
147 
148 static inline unsigned int psp_sk_overhead(const struct sock *sk)
149 {
150 	int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE;
151 	bool has_psp = rcu_access_pointer(sk->psp_assoc);
152 
153 	return has_psp ? psp_encap : 0;
154 }
155 #else
/* CONFIG_INET_PSP disabled: lifecycle hooks compile to no-ops. */
static inline void psp_sk_assoc_free(struct sock *sk)
{
}

static inline void
psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
{
}

static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
{
}

static inline void psp_reply_set_decrypted(struct sk_buff *skb)
{
}
162 
163 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
164 {
165 	return NULL;
166 }
167 
/* Without CONFIG_INET_PSP skbs are never marked decrypted here. */
static inline void
psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
{
}
170 
/*
 * Without CONFIG_INET_PSP there are no PSP extensions to compare, so the
 * caller's accumulated @diffs is passed through untouched.
 */
static inline unsigned long
__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
			unsigned long diffs)
{
	unsigned long unchanged = diffs;

	return unchanged;
}
177 
178 static inline enum skb_drop_reason
179 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
180 {
181 	return 0;
182 }
183 
184 static inline enum skb_drop_reason
185 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
186 {
187 	return 0;
188 }
189 
190 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
191 {
192 	return NULL;
193 }
194 
/* Without CONFIG_INET_PSP no PSP encapsulation is ever added. */
static inline unsigned int psp_sk_overhead(const struct sock *sk)
{
	unsigned int no_overhead = 0;

	return no_overhead;
}
199 #endif
200 
/**
 * psp_skb_coalesce_diff() - compare the PSP state of two skbs
 * @one: first skb
 * @two: second skb
 *
 * Return: 0 when the PSP extensions of @one and @two are equivalent (always 0
 * without CONFIG_INET_PSP), non-zero otherwise.
 */
static inline unsigned long
psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
{
	const unsigned long no_prior_diffs = 0;

	return __psp_skb_coalesce_diff(one, two, no_prior_diffs);
}
206 
207 #endif /* __NET_PSP_HELPERS_H */
208