xref: /linux/net/psp/psp_sock.c (revision 6b46ca260e2290e3453d1355ab5b6d283d73d780)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <linux/file.h>
4 #include <linux/net.h>
5 #include <linux/rcupdate.h>
6 #include <linux/tcp.h>
7 
8 #include <net/ip.h>
9 #include <net/psp.h>
10 #include "psp.h"
11 
12 struct psp_dev *psp_dev_get_for_sock(struct sock *sk)
13 {
14 	struct dst_entry *dst;
15 	struct psp_dev *psd;
16 
17 	dst = sk_dst_get(sk);
18 	if (!dst)
19 		return NULL;
20 
21 	rcu_read_lock();
22 	psd = rcu_dereference(dst->dev->psp_dev);
23 	if (psd && !psp_dev_tryget(psd))
24 		psd = NULL;
25 	rcu_read_unlock();
26 
27 	dst_release(dst);
28 
29 	return psd;
30 }
31 
32 static struct sk_buff *
33 psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
34 {
35 	struct psp_assoc *pas;
36 	bool good;
37 
38 	rcu_read_lock();
39 	pas = psp_skb_get_assoc_rcu(skb);
40 	good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd;
41 	rcu_read_unlock();
42 	if (!good) {
43 		kfree_skb_reason(skb, SKB_DROP_REASON_PSP_OUTPUT);
44 		return NULL;
45 	}
46 
47 	return skb;
48 }
49 
50 struct psp_assoc *psp_assoc_create(struct psp_dev *psd)
51 {
52 	struct psp_assoc *pas;
53 
54 	lockdep_assert_held(&psd->lock);
55 
56 	pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc),
57 		      GFP_KERNEL_ACCOUNT);
58 	if (!pas)
59 		return NULL;
60 
61 	pas->psd = psd;
62 	pas->dev_id = psd->id;
63 	psp_dev_get(psd);
64 	refcount_set(&pas->refcnt, 1);
65 
66 	list_add_tail(&pas->assocs_list, &psd->active_assocs);
67 
68 	return pas;
69 }
70 
71 static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas)
72 {
73 	struct psp_dev *psd = pas->psd;
74 	size_t sz;
75 
76 	lockdep_assert_held(&psd->lock);
77 
78 	sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc);
79 	return kmemdup(pas, sz, GFP_KERNEL);
80 }
81 
82 static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas,
83 			      struct netlink_ext_ack *extack)
84 {
85 	return psd->ops->tx_key_add(psd, pas, extack);
86 }
87 
88 void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas)
89 {
90 	if (pas->tx.spi)
91 		psd->ops->tx_key_del(psd, pas);
92 	list_del(&pas->assocs_list);
93 }
94 
95 static void psp_assoc_free(struct work_struct *work)
96 {
97 	struct psp_assoc *pas = container_of(work, struct psp_assoc, work);
98 	struct psp_dev *psd = pas->psd;
99 
100 	mutex_lock(&psd->lock);
101 	if (psd->ops)
102 		psp_dev_tx_key_del(psd, pas);
103 	mutex_unlock(&psd->lock);
104 	psp_dev_put(psd);
105 	kfree(pas);
106 }
107 
108 static void psp_assoc_free_queue(struct rcu_head *head)
109 {
110 	struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu);
111 
112 	INIT_WORK(&pas->work, psp_assoc_free);
113 	schedule_work(&pas->work);
114 }
115 
116 /**
117  * psp_assoc_put() - release a reference on a PSP association
118  * @pas: association to release
119  */
120 void psp_assoc_put(struct psp_assoc *pas)
121 {
122 	if (pas && refcount_dec_and_test(&pas->refcnt))
123 		call_rcu(&pas->rcu, psp_assoc_free_queue);
124 }
125 
126 void psp_sk_assoc_free(struct sock *sk)
127 {
128 	struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1);
129 
130 	rcu_assign_pointer(sk->psp_assoc, NULL);
131 	psp_assoc_put(pas);
132 }
133 
134 int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
135 			  struct psp_key_parsed *key,
136 			  struct netlink_ext_ack *extack)
137 {
138 	int err;
139 
140 	memcpy(&pas->rx, key, sizeof(*key));
141 
142 	lock_sock(sk);
143 
144 	if (psp_sk_assoc(sk)) {
145 		NL_SET_ERR_MSG(extack, "Socket already has PSP state");
146 		err = -EBUSY;
147 		goto exit_unlock;
148 	}
149 
150 	refcount_inc(&pas->refcnt);
151 	rcu_assign_pointer(sk->psp_assoc, pas);
152 	err = 0;
153 
154 exit_unlock:
155 	release_sock(sk);
156 
157 	return err;
158 }
159 
160 static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas)
161 {
162 	struct psp_skb_ext *pse;
163 	struct sk_buff *skb;
164 
165 	skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) {
166 		pse = skb_ext_find(skb, SKB_EXT_PSP);
167 		if (!psp_pse_matches_pas(pse, pas))
168 			return -EBUSY;
169 	}
170 
171 	skb_queue_walk(&sk->sk_receive_queue, skb) {
172 		pse = skb_ext_find(skb, SKB_EXT_PSP);
173 		if (!psp_pse_matches_pas(pse, pas))
174 			return -EBUSY;
175 	}
176 	return 0;
177 }
178 
179 int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
180 			  u32 version, struct psp_key_parsed *key,
181 			  struct netlink_ext_ack *extack)
182 {
183 	struct psp_assoc *pas, *dummy;
184 	int err;
185 
186 	lock_sock(sk);
187 
188 	pas = psp_sk_assoc(sk);
189 	if (!pas) {
190 		NL_SET_ERR_MSG(extack, "Socket has no Rx key");
191 		err = -EINVAL;
192 		goto exit_unlock;
193 	}
194 	if (pas->psd != psd) {
195 		NL_SET_ERR_MSG(extack, "Rx key from different device");
196 		err = -EINVAL;
197 		goto exit_unlock;
198 	}
199 	if (pas->version != version) {
200 		NL_SET_ERR_MSG(extack,
201 			       "PSP version mismatch with existing state");
202 		err = -EINVAL;
203 		goto exit_unlock;
204 	}
205 	if (pas->tx.spi) {
206 		NL_SET_ERR_MSG(extack, "Tx key already set");
207 		err = -EBUSY;
208 		goto exit_unlock;
209 	}
210 
211 	err = psp_sock_recv_queue_check(sk, pas);
212 	if (err) {
213 		NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue");
214 		goto exit_unlock;
215 	}
216 
217 	/* Pass a fake association to drivers to make sure they don't
218 	 * try to store pointers to it. For re-keying we'll need to
219 	 * re-allocate the assoc structures.
220 	 */
221 	dummy = psp_assoc_dummy(pas);
222 	if (!dummy) {
223 		err = -ENOMEM;
224 		goto exit_unlock;
225 	}
226 
227 	memcpy(&dummy->tx, key, sizeof(*key));
228 	err = psp_dev_tx_key_add(psd, dummy, extack);
229 	if (err)
230 		goto exit_free_dummy;
231 
232 	memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc);
233 	memcpy(&pas->tx, key, sizeof(*key));
234 
235 	WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit);
236 	tcp_write_collapse_fence(sk);
237 	pas->upgrade_seq = tcp_sk(sk)->rcv_nxt;
238 
239 exit_free_dummy:
240 	kfree(dummy);
241 exit_unlock:
242 	release_sock(sk);
243 	return err;
244 }
245 
246 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
247 {
248 	struct psp_assoc *pas = psp_sk_assoc(sk);
249 
250 	if (pas)
251 		refcount_inc(&pas->refcnt);
252 	rcu_assign_pointer(tw->psp_assoc, pas);
253 	tw->tw_validate_xmit_skb = psp_validate_xmit;
254 }
255 
256 void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
257 {
258 	struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1);
259 
260 	rcu_assign_pointer(tw->psp_assoc, NULL);
261 	psp_assoc_put(pas);
262 }
263 
264 void psp_reply_set_decrypted(struct sk_buff *skb)
265 {
266 	struct psp_assoc *pas;
267 
268 	rcu_read_lock();
269 	pas = psp_sk_get_assoc_rcu(skb->sk);
270 	if (pas && pas->tx.spi)
271 		skb->decrypted = 1;
272 	rcu_read_unlock();
273 }
274 EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);
275