xref: /linux/net/psp/psp_nl.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <linux/skbuff.h>
4 #include <linux/xarray.h>
5 #include <net/genetlink.h>
6 #include <net/psp.h>
7 #include <net/sock.h>
8 
9 #include "psp-nl-gen.h"
10 #include "psp.h"
11 
12 /* Netlink helpers */
13 
14 static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
15 {
16 	struct sk_buff *rsp;
17 	void *hdr;
18 
19 	rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
20 	if (!rsp)
21 		return NULL;
22 
23 	hdr = genlmsg_iput(rsp, info);
24 	if (!hdr) {
25 		nlmsg_free(rsp);
26 		return NULL;
27 	}
28 
29 	return rsp;
30 }
31 
32 static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
33 {
34 	/* Note that this *only* works with a single message per skb! */
35 	nlmsg_end(rsp, (struct nlmsghdr *)rsp->data);
36 
37 	return genlmsg_reply(rsp, info);
38 }
39 
40 /* Device stuff */
41 
42 static struct psp_dev *
43 psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
44 {
45 	struct psp_dev *psd;
46 	int err;
47 
48 	mutex_lock(&psp_devs_lock);
49 	psd = xa_load(&psp_devs, nla_get_u32(dev_id));
50 	if (!psd) {
51 		mutex_unlock(&psp_devs_lock);
52 		return ERR_PTR(-ENODEV);
53 	}
54 
55 	mutex_lock(&psd->lock);
56 	mutex_unlock(&psp_devs_lock);
57 
58 	err = psp_dev_check_access(psd, net);
59 	if (err) {
60 		mutex_unlock(&psd->lock);
61 		return ERR_PTR(err);
62 	}
63 
64 	return psd;
65 }
66 
67 int psp_device_get_locked(const struct genl_split_ops *ops,
68 			  struct sk_buff *skb, struct genl_info *info)
69 {
70 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
71 		return -EINVAL;
72 
73 	info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info),
74 						    info->attrs[PSP_A_DEV_ID]);
75 	return PTR_ERR_OR_ZERO(info->user_ptr[0]);
76 }
77 
78 void
79 psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
80 		  struct genl_info *info)
81 {
82 	struct socket *socket = info->user_ptr[1];
83 	struct psp_dev *psd = info->user_ptr[0];
84 
85 	mutex_unlock(&psd->lock);
86 	if (socket)
87 		sockfd_put(socket);
88 }
89 
90 static int
91 psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
92 		const struct genl_info *info)
93 {
94 	void *hdr;
95 
96 	hdr = genlmsg_iput(rsp, info);
97 	if (!hdr)
98 		return -EMSGSIZE;
99 
100 	if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
101 	    nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) ||
102 	    nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) ||
103 	    nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions))
104 		goto err_cancel_msg;
105 
106 	genlmsg_end(rsp, hdr);
107 	return 0;
108 
109 err_cancel_msg:
110 	genlmsg_cancel(rsp, hdr);
111 	return -EMSGSIZE;
112 }
113 
114 void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
115 {
116 	struct genl_info info;
117 	struct sk_buff *ntf;
118 
119 	if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
120 				PSP_NLGRP_MGMT))
121 		return;
122 
123 	ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
124 	if (!ntf)
125 		return;
126 
127 	genl_info_init_ntf(&info, &psp_nl_family, cmd);
128 	if (psp_nl_dev_fill(psd, ntf, &info)) {
129 		nlmsg_free(ntf);
130 		return;
131 	}
132 
133 	genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
134 				0, PSP_NLGRP_MGMT, GFP_KERNEL);
135 }
136 
137 int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
138 {
139 	struct psp_dev *psd = info->user_ptr[0];
140 	struct sk_buff *rsp;
141 	int err;
142 
143 	rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
144 	if (!rsp)
145 		return -ENOMEM;
146 
147 	err = psp_nl_dev_fill(psd, rsp, info);
148 	if (err)
149 		goto err_free_msg;
150 
151 	return genlmsg_reply(rsp, info);
152 
153 err_free_msg:
154 	nlmsg_free(rsp);
155 	return err;
156 }
157 
158 static int
159 psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
160 			  struct psp_dev *psd)
161 {
162 	if (psp_dev_check_access(psd, sock_net(rsp->sk)))
163 		return 0;
164 
165 	return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));
166 }
167 
168 int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
169 {
170 	struct psp_dev *psd;
171 	int err = 0;
172 
173 	mutex_lock(&psp_devs_lock);
174 	xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
175 		mutex_lock(&psd->lock);
176 		err = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
177 		mutex_unlock(&psd->lock);
178 		if (err)
179 			break;
180 	}
181 	mutex_unlock(&psp_devs_lock);
182 
183 	return err;
184 }
185 
186 int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
187 {
188 	struct psp_dev *psd = info->user_ptr[0];
189 	struct psp_dev_config new_config;
190 	struct sk_buff *rsp;
191 	int err;
192 
193 	memcpy(&new_config, &psd->config, sizeof(new_config));
194 
195 	if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
196 		new_config.versions =
197 			nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
198 		if (new_config.versions & ~psd->caps->versions) {
199 			NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
200 			return -EINVAL;
201 		}
202 	} else {
203 		NL_SET_ERR_MSG(info->extack, "No settings present");
204 		return -EINVAL;
205 	}
206 
207 	rsp = psp_nl_reply_new(info);
208 	if (!rsp)
209 		return -ENOMEM;
210 
211 	if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
212 		err = psd->ops->set_config(psd, &new_config, info->extack);
213 		if (err)
214 			goto err_free_rsp;
215 
216 		memcpy(&psd->config, &new_config, sizeof(new_config));
217 	}
218 
219 	psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
220 
221 	return psp_nl_reply_send(rsp, info);
222 
223 err_free_rsp:
224 	nlmsg_free(rsp);
225 	return err;
226 }
227 
228 int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
229 {
230 	struct psp_dev *psd = info->user_ptr[0];
231 	struct genl_info ntf_info;
232 	struct sk_buff *ntf, *rsp;
233 	u8 prev_gen;
234 	int err;
235 
236 	rsp = psp_nl_reply_new(info);
237 	if (!rsp)
238 		return -ENOMEM;
239 
240 	genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
241 	ntf = psp_nl_reply_new(&ntf_info);
242 	if (!ntf) {
243 		err = -ENOMEM;
244 		goto err_free_rsp;
245 	}
246 
247 	if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
248 	    nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
249 		err = -EMSGSIZE;
250 		goto err_free_ntf;
251 	}
252 
253 	/* suggest the next gen number, driver can override */
254 	prev_gen = psd->generation;
255 	psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK;
256 
257 	err = psd->ops->key_rotate(psd, info->extack);
258 	if (err)
259 		goto err_free_ntf;
260 
261 	WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) ||
262 		     psd->generation & ~PSP_GEN_VALID_MASK);
263 
264 	psp_assocs_key_rotated(psd);
265 
266 	nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
267 	genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
268 				0, PSP_NLGRP_USE, GFP_KERNEL);
269 	return psp_nl_reply_send(rsp, info);
270 
271 err_free_ntf:
272 	nlmsg_free(ntf);
273 err_free_rsp:
274 	nlmsg_free(rsp);
275 	return err;
276 }
277 
278 /* Key etc. */
279 
280 int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
281 				struct sk_buff *skb, struct genl_info *info)
282 {
283 	struct socket *socket;
284 	struct psp_dev *psd;
285 	struct nlattr *id;
286 	int fd, err;
287 
288 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
289 		return -EINVAL;
290 
291 	fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
292 	socket = sockfd_lookup(fd, &err);
293 	if (!socket)
294 		return err;
295 
296 	if (!sk_is_tcp(socket->sk)) {
297 		NL_SET_ERR_MSG_ATTR(info->extack,
298 				    info->attrs[PSP_A_ASSOC_SOCK_FD],
299 				    "Unsupported socket family and type");
300 		err = -EOPNOTSUPP;
301 		goto err_sock_put;
302 	}
303 
304 	psd = psp_dev_get_for_sock(socket->sk);
305 	if (psd) {
306 		err = psp_dev_check_access(psd, genl_info_net(info));
307 		if (err) {
308 			psp_dev_put(psd);
309 			psd = NULL;
310 		}
311 	}
312 
313 	if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
314 		err = -EINVAL;
315 		goto err_sock_put;
316 	}
317 
318 	id = info->attrs[PSP_A_ASSOC_DEV_ID];
319 	if (psd) {
320 		mutex_lock(&psd->lock);
321 		if (id && psd->id != nla_get_u32(id)) {
322 			mutex_unlock(&psd->lock);
323 			NL_SET_ERR_MSG_ATTR(info->extack, id,
324 					    "Device id vs socket mismatch");
325 			err = -EINVAL;
326 			goto err_psd_put;
327 		}
328 
329 		psp_dev_put(psd);
330 	} else {
331 		psd = psp_device_get_and_lock(genl_info_net(info), id);
332 		if (IS_ERR(psd)) {
333 			err = PTR_ERR(psd);
334 			goto err_sock_put;
335 		}
336 	}
337 
338 	info->user_ptr[0] = psd;
339 	info->user_ptr[1] = socket;
340 
341 	return 0;
342 
343 err_psd_put:
344 	psp_dev_put(psd);
345 err_sock_put:
346 	sockfd_put(socket);
347 	return err;
348 }
349 
350 static int
351 psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
352 		 unsigned int key_sz)
353 {
354 	struct nlattr *nest = info->attrs[attr];
355 	struct nlattr *tb[PSP_A_KEYS_SPI + 1];
356 	u32 spi;
357 	int err;
358 
359 	err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
360 			       psp_keys_nl_policy, info->extack);
361 	if (err)
362 		return err;
363 
364 	if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
365 	    NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
366 		return -EINVAL;
367 
368 	if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
369 		NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
370 				    "incorrect key length");
371 		return -EINVAL;
372 	}
373 
374 	spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
375 	if (!(spi & PSP_SPI_KEY_ID)) {
376 		NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
377 				    "invalid SPI: lower 31b must be non-zero");
378 		return -EINVAL;
379 	}
380 
381 	key->spi = cpu_to_be32(spi);
382 	memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);
383 
384 	return 0;
385 }
386 
387 static int
388 psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
389 	       struct psp_key_parsed *key)
390 {
391 	int key_sz = psp_key_size(version);
392 	void *nest;
393 
394 	nest = nla_nest_start(skb, attr);
395 
396 	if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
397 	    nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
398 		nla_nest_cancel(skb, nest);
399 		return -EMSGSIZE;
400 	}
401 
402 	nla_nest_end(skb, nest);
403 
404 	return 0;
405 }
406 
407 int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
408 {
409 	struct socket *socket = info->user_ptr[1];
410 	struct psp_dev *psd = info->user_ptr[0];
411 	struct psp_key_parsed key;
412 	struct psp_assoc *pas;
413 	struct sk_buff *rsp;
414 	u32 version;
415 	int err;
416 
417 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
418 		return -EINVAL;
419 
420 	version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
421 	if (!(psd->caps->versions & (1 << version))) {
422 		NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
423 		return -EOPNOTSUPP;
424 	}
425 
426 	rsp = psp_nl_reply_new(info);
427 	if (!rsp)
428 		return -ENOMEM;
429 
430 	pas = psp_assoc_create(psd);
431 	if (!pas) {
432 		err = -ENOMEM;
433 		goto err_free_rsp;
434 	}
435 	pas->version = version;
436 
437 	err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
438 	if (err)
439 		goto err_free_pas;
440 
441 	if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
442 	    psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
443 		err = -EMSGSIZE;
444 		goto err_free_pas;
445 	}
446 
447 	err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
448 	if (err) {
449 		NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
450 		goto err_free_pas;
451 	}
452 	psp_assoc_put(pas);
453 
454 	return psp_nl_reply_send(rsp, info);
455 
456 err_free_pas:
457 	psp_assoc_put(pas);
458 err_free_rsp:
459 	nlmsg_free(rsp);
460 	return err;
461 }
462 
463 int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
464 {
465 	struct socket *socket = info->user_ptr[1];
466 	struct psp_dev *psd = info->user_ptr[0];
467 	struct psp_key_parsed key;
468 	struct sk_buff *rsp;
469 	unsigned int key_sz;
470 	u32 version;
471 	int err;
472 
473 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
474 	    GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
475 		return -EINVAL;
476 
477 	version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
478 	if (!(psd->caps->versions & (1 << version))) {
479 		NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
480 		return -EOPNOTSUPP;
481 	}
482 
483 	key_sz = psp_key_size(version);
484 	if (!key_sz)
485 		return -EINVAL;
486 
487 	err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
488 	if (err < 0)
489 		return err;
490 
491 	rsp = psp_nl_reply_new(info);
492 	if (!rsp)
493 		return -ENOMEM;
494 
495 	err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
496 				    info->extack);
497 	if (err)
498 		goto err_free_msg;
499 
500 	return psp_nl_reply_send(rsp, info);
501 
502 err_free_msg:
503 	nlmsg_free(rsp);
504 	return err;
505 }
506