xref: /linux/net/psp/psp_sock.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <linux/file.h>
4 #include <linux/net.h>
5 #include <linux/rcupdate.h>
6 #include <linux/tcp.h>
7 
8 #include <net/ip.h>
9 #include <net/psp.h>
10 #include "psp.h"
11 
12 struct psp_dev *psp_dev_get_for_sock(struct sock *sk)
13 {
14 	struct psp_dev *psd = NULL;
15 	struct dst_entry *dst;
16 
17 	rcu_read_lock();
18 	dst = __sk_dst_get(sk);
19 	if (dst) {
20 		psd = rcu_dereference(dst_dev_rcu(dst)->psp_dev);
21 		if (psd && !psp_dev_tryget(psd))
22 			psd = NULL;
23 	}
24 	rcu_read_unlock();
25 
26 	return psd;
27 }
28 
29 static struct sk_buff *
30 psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
31 {
32 	struct psp_assoc *pas;
33 	bool good;
34 
35 	rcu_read_lock();
36 	pas = psp_skb_get_assoc_rcu(skb);
37 	good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd;
38 	rcu_read_unlock();
39 	if (!good) {
40 		sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_PSP_OUTPUT);
41 		return NULL;
42 	}
43 
44 	return skb;
45 }
46 
47 struct psp_assoc *psp_assoc_create(struct psp_dev *psd)
48 {
49 	struct psp_assoc *pas;
50 
51 	lockdep_assert_held(&psd->lock);
52 
53 	pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc),
54 		      GFP_KERNEL_ACCOUNT);
55 	if (!pas)
56 		return NULL;
57 
58 	pas->psd = psd;
59 	pas->dev_id = psd->id;
60 	pas->generation = psd->generation;
61 	psp_dev_get(psd);
62 	refcount_set(&pas->refcnt, 1);
63 
64 	list_add_tail(&pas->assocs_list, &psd->active_assocs);
65 
66 	return pas;
67 }
68 
69 static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas)
70 {
71 	struct psp_dev *psd = pas->psd;
72 	size_t sz;
73 
74 	lockdep_assert_held(&psd->lock);
75 
76 	sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc);
77 	return kmemdup(pas, sz, GFP_KERNEL);
78 }
79 
80 static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas,
81 			      struct netlink_ext_ack *extack)
82 {
83 	return psd->ops->tx_key_add(psd, pas, extack);
84 }
85 
86 void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas)
87 {
88 	if (pas->tx.spi)
89 		psd->ops->tx_key_del(psd, pas);
90 	list_del(&pas->assocs_list);
91 }
92 
93 static void psp_assoc_free(struct work_struct *work)
94 {
95 	struct psp_assoc *pas = container_of(work, struct psp_assoc, work);
96 	struct psp_dev *psd = pas->psd;
97 
98 	mutex_lock(&psd->lock);
99 	if (psd->ops)
100 		psp_dev_tx_key_del(psd, pas);
101 	mutex_unlock(&psd->lock);
102 	psp_dev_put(psd);
103 	kfree(pas);
104 }
105 
106 static void psp_assoc_free_queue(struct rcu_head *head)
107 {
108 	struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu);
109 
110 	INIT_WORK(&pas->work, psp_assoc_free);
111 	schedule_work(&pas->work);
112 }
113 
114 /**
115  * psp_assoc_put() - release a reference on a PSP association
116  * @pas: association to release
117  */
118 void psp_assoc_put(struct psp_assoc *pas)
119 {
120 	if (pas && refcount_dec_and_test(&pas->refcnt))
121 		call_rcu(&pas->rcu, psp_assoc_free_queue);
122 }
123 
124 void psp_sk_assoc_free(struct sock *sk)
125 {
126 	struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1);
127 
128 	rcu_assign_pointer(sk->psp_assoc, NULL);
129 	psp_assoc_put(pas);
130 }
131 
132 int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
133 			  struct psp_key_parsed *key,
134 			  struct netlink_ext_ack *extack)
135 {
136 	int err;
137 
138 	memcpy(&pas->rx, key, sizeof(*key));
139 
140 	lock_sock(sk);
141 
142 	if (psp_sk_assoc(sk)) {
143 		NL_SET_ERR_MSG(extack, "Socket already has PSP state");
144 		err = -EBUSY;
145 		goto exit_unlock;
146 	}
147 
148 	refcount_inc(&pas->refcnt);
149 	rcu_assign_pointer(sk->psp_assoc, pas);
150 	err = 0;
151 
152 exit_unlock:
153 	release_sock(sk);
154 
155 	return err;
156 }
157 
158 static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas)
159 {
160 	struct psp_skb_ext *pse;
161 	struct sk_buff *skb;
162 
163 	skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) {
164 		pse = skb_ext_find(skb, SKB_EXT_PSP);
165 		if (!psp_pse_matches_pas(pse, pas))
166 			return -EBUSY;
167 	}
168 
169 	skb_queue_walk(&sk->sk_receive_queue, skb) {
170 		pse = skb_ext_find(skb, SKB_EXT_PSP);
171 		if (!psp_pse_matches_pas(pse, pas))
172 			return -EBUSY;
173 	}
174 	return 0;
175 }
176 
177 int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
178 			  u32 version, struct psp_key_parsed *key,
179 			  struct netlink_ext_ack *extack)
180 {
181 	struct inet_connection_sock *icsk;
182 	struct psp_assoc *pas, *dummy;
183 	int err;
184 
185 	lock_sock(sk);
186 
187 	pas = psp_sk_assoc(sk);
188 	if (!pas) {
189 		NL_SET_ERR_MSG(extack, "Socket has no Rx key");
190 		err = -EINVAL;
191 		goto exit_unlock;
192 	}
193 	if (pas->psd != psd) {
194 		NL_SET_ERR_MSG(extack, "Rx key from different device");
195 		err = -EINVAL;
196 		goto exit_unlock;
197 	}
198 	if (pas->version != version) {
199 		NL_SET_ERR_MSG(extack,
200 			       "PSP version mismatch with existing state");
201 		err = -EINVAL;
202 		goto exit_unlock;
203 	}
204 	if (pas->tx.spi) {
205 		NL_SET_ERR_MSG(extack, "Tx key already set");
206 		err = -EBUSY;
207 		goto exit_unlock;
208 	}
209 
210 	err = psp_sock_recv_queue_check(sk, pas);
211 	if (err) {
212 		NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue");
213 		goto exit_unlock;
214 	}
215 
216 	/* Pass a fake association to drivers to make sure they don't
217 	 * try to store pointers to it. For re-keying we'll need to
218 	 * re-allocate the assoc structures.
219 	 */
220 	dummy = psp_assoc_dummy(pas);
221 	if (!dummy) {
222 		err = -ENOMEM;
223 		goto exit_unlock;
224 	}
225 
226 	memcpy(&dummy->tx, key, sizeof(*key));
227 	err = psp_dev_tx_key_add(psd, dummy, extack);
228 	if (err)
229 		goto exit_free_dummy;
230 
231 	memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc);
232 	memcpy(&pas->tx, key, sizeof(*key));
233 
234 	WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit);
235 	tcp_write_collapse_fence(sk);
236 	pas->upgrade_seq = tcp_sk(sk)->rcv_nxt;
237 
238 	icsk = inet_csk(sk);
239 	icsk->icsk_ext_hdr_len += psp_sk_overhead(sk);
240 	icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie);
241 
242 exit_free_dummy:
243 	kfree(dummy);
244 exit_unlock:
245 	release_sock(sk);
246 	return err;
247 }
248 
249 void psp_assocs_key_rotated(struct psp_dev *psd)
250 {
251 	struct psp_assoc *pas, *next;
252 
253 	/* Mark the stale associations as invalid, they will no longer
254 	 * be able to Rx any traffic.
255 	 */
256 	list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list)
257 		pas->generation |= ~PSP_GEN_VALID_MASK;
258 	list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
259 	list_splice_init(&psd->active_assocs, &psd->prev_assocs);
260 
261 	/* TODO: we should inform the sockets that got shut down */
262 }
263 
264 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
265 {
266 	struct psp_assoc *pas = psp_sk_assoc(sk);
267 
268 	if (pas)
269 		refcount_inc(&pas->refcnt);
270 	rcu_assign_pointer(tw->psp_assoc, pas);
271 	tw->tw_validate_xmit_skb = psp_validate_xmit;
272 }
273 
274 void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
275 {
276 	struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1);
277 
278 	rcu_assign_pointer(tw->psp_assoc, NULL);
279 	psp_assoc_put(pas);
280 }
281 
282 void psp_reply_set_decrypted(struct sk_buff *skb)
283 {
284 	struct psp_assoc *pas;
285 
286 	rcu_read_lock();
287 	pas = psp_sk_get_assoc_rcu(skb->sk);
288 	if (pas && pas->tx.spi)
289 		skb->decrypted = 1;
290 	rcu_read_unlock();
291 }
292 EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);
293