xref: /linux/net/psp/psp_nl.c (revision d30c1683aaecb93d2ab95685dc4300a33d3cea7a)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <linux/ethtool.h>
4 #include <linux/skbuff.h>
5 #include <linux/xarray.h>
6 #include <net/genetlink.h>
7 #include <net/psp.h>
8 #include <net/sock.h>
9 
10 #include "psp-nl-gen.h"
11 #include "psp.h"
12 
13 /* Netlink helpers */
14 
15 static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
16 {
17 	struct sk_buff *rsp;
18 	void *hdr;
19 
20 	rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
21 	if (!rsp)
22 		return NULL;
23 
24 	hdr = genlmsg_iput(rsp, info);
25 	if (!hdr) {
26 		nlmsg_free(rsp);
27 		return NULL;
28 	}
29 
30 	return rsp;
31 }
32 
33 static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
34 {
35 	/* Note that this *only* works with a single message per skb! */
36 	nlmsg_end(rsp, (struct nlmsghdr *)rsp->data);
37 
38 	return genlmsg_reply(rsp, info);
39 }
40 
41 /* Device stuff */
42 
43 static struct psp_dev *
44 psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
45 {
46 	struct psp_dev *psd;
47 	int err;
48 
49 	mutex_lock(&psp_devs_lock);
50 	psd = xa_load(&psp_devs, nla_get_u32(dev_id));
51 	if (!psd) {
52 		mutex_unlock(&psp_devs_lock);
53 		return ERR_PTR(-ENODEV);
54 	}
55 
56 	mutex_lock(&psd->lock);
57 	mutex_unlock(&psp_devs_lock);
58 
59 	err = psp_dev_check_access(psd, net);
60 	if (err) {
61 		mutex_unlock(&psd->lock);
62 		return ERR_PTR(err);
63 	}
64 
65 	return psd;
66 }
67 
68 int psp_device_get_locked(const struct genl_split_ops *ops,
69 			  struct sk_buff *skb, struct genl_info *info)
70 {
71 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
72 		return -EINVAL;
73 
74 	info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info),
75 						    info->attrs[PSP_A_DEV_ID]);
76 	return PTR_ERR_OR_ZERO(info->user_ptr[0]);
77 }
78 
79 void
80 psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
81 		  struct genl_info *info)
82 {
83 	struct socket *socket = info->user_ptr[1];
84 	struct psp_dev *psd = info->user_ptr[0];
85 
86 	mutex_unlock(&psd->lock);
87 	if (socket)
88 		sockfd_put(socket);
89 }
90 
91 static int
92 psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
93 		const struct genl_info *info)
94 {
95 	void *hdr;
96 
97 	hdr = genlmsg_iput(rsp, info);
98 	if (!hdr)
99 		return -EMSGSIZE;
100 
101 	if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
102 	    nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) ||
103 	    nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) ||
104 	    nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions))
105 		goto err_cancel_msg;
106 
107 	genlmsg_end(rsp, hdr);
108 	return 0;
109 
110 err_cancel_msg:
111 	genlmsg_cancel(rsp, hdr);
112 	return -EMSGSIZE;
113 }
114 
115 void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
116 {
117 	struct genl_info info;
118 	struct sk_buff *ntf;
119 
120 	if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
121 				PSP_NLGRP_MGMT))
122 		return;
123 
124 	ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
125 	if (!ntf)
126 		return;
127 
128 	genl_info_init_ntf(&info, &psp_nl_family, cmd);
129 	if (psp_nl_dev_fill(psd, ntf, &info)) {
130 		nlmsg_free(ntf);
131 		return;
132 	}
133 
134 	genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
135 				0, PSP_NLGRP_MGMT, GFP_KERNEL);
136 }
137 
138 int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
139 {
140 	struct psp_dev *psd = info->user_ptr[0];
141 	struct sk_buff *rsp;
142 	int err;
143 
144 	rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
145 	if (!rsp)
146 		return -ENOMEM;
147 
148 	err = psp_nl_dev_fill(psd, rsp, info);
149 	if (err)
150 		goto err_free_msg;
151 
152 	return genlmsg_reply(rsp, info);
153 
154 err_free_msg:
155 	nlmsg_free(rsp);
156 	return err;
157 }
158 
159 static int
160 psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
161 			  struct psp_dev *psd)
162 {
163 	if (psp_dev_check_access(psd, sock_net(rsp->sk)))
164 		return 0;
165 
166 	return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));
167 }
168 
169 int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
170 {
171 	struct psp_dev *psd;
172 	int err = 0;
173 
174 	mutex_lock(&psp_devs_lock);
175 	xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
176 		mutex_lock(&psd->lock);
177 		err = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
178 		mutex_unlock(&psd->lock);
179 		if (err)
180 			break;
181 	}
182 	mutex_unlock(&psp_devs_lock);
183 
184 	return err;
185 }
186 
187 int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
188 {
189 	struct psp_dev *psd = info->user_ptr[0];
190 	struct psp_dev_config new_config;
191 	struct sk_buff *rsp;
192 	int err;
193 
194 	memcpy(&new_config, &psd->config, sizeof(new_config));
195 
196 	if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
197 		new_config.versions =
198 			nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
199 		if (new_config.versions & ~psd->caps->versions) {
200 			NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
201 			return -EINVAL;
202 		}
203 	} else {
204 		NL_SET_ERR_MSG(info->extack, "No settings present");
205 		return -EINVAL;
206 	}
207 
208 	rsp = psp_nl_reply_new(info);
209 	if (!rsp)
210 		return -ENOMEM;
211 
212 	if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
213 		err = psd->ops->set_config(psd, &new_config, info->extack);
214 		if (err)
215 			goto err_free_rsp;
216 
217 		memcpy(&psd->config, &new_config, sizeof(new_config));
218 	}
219 
220 	psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
221 
222 	return psp_nl_reply_send(rsp, info);
223 
224 err_free_rsp:
225 	nlmsg_free(rsp);
226 	return err;
227 }
228 
229 int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
230 {
231 	struct psp_dev *psd = info->user_ptr[0];
232 	struct genl_info ntf_info;
233 	struct sk_buff *ntf, *rsp;
234 	u8 prev_gen;
235 	int err;
236 
237 	rsp = psp_nl_reply_new(info);
238 	if (!rsp)
239 		return -ENOMEM;
240 
241 	genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
242 	ntf = psp_nl_reply_new(&ntf_info);
243 	if (!ntf) {
244 		err = -ENOMEM;
245 		goto err_free_rsp;
246 	}
247 
248 	if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
249 	    nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
250 		err = -EMSGSIZE;
251 		goto err_free_ntf;
252 	}
253 
254 	/* suggest the next gen number, driver can override */
255 	prev_gen = psd->generation;
256 	psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK;
257 
258 	err = psd->ops->key_rotate(psd, info->extack);
259 	if (err)
260 		goto err_free_ntf;
261 
262 	WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) ||
263 		     psd->generation & ~PSP_GEN_VALID_MASK);
264 
265 	psp_assocs_key_rotated(psd);
266 	psd->stats.rotations++;
267 
268 	nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
269 	genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
270 				0, PSP_NLGRP_USE, GFP_KERNEL);
271 	return psp_nl_reply_send(rsp, info);
272 
273 err_free_ntf:
274 	nlmsg_free(ntf);
275 err_free_rsp:
276 	nlmsg_free(rsp);
277 	return err;
278 }
279 
280 /* Key etc. */
281 
282 int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
283 				struct sk_buff *skb, struct genl_info *info)
284 {
285 	struct socket *socket;
286 	struct psp_dev *psd;
287 	struct nlattr *id;
288 	int fd, err;
289 
290 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
291 		return -EINVAL;
292 
293 	fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
294 	socket = sockfd_lookup(fd, &err);
295 	if (!socket)
296 		return err;
297 
298 	if (!sk_is_tcp(socket->sk)) {
299 		NL_SET_ERR_MSG_ATTR(info->extack,
300 				    info->attrs[PSP_A_ASSOC_SOCK_FD],
301 				    "Unsupported socket family and type");
302 		err = -EOPNOTSUPP;
303 		goto err_sock_put;
304 	}
305 
306 	psd = psp_dev_get_for_sock(socket->sk);
307 	if (psd) {
308 		err = psp_dev_check_access(psd, genl_info_net(info));
309 		if (err) {
310 			psp_dev_put(psd);
311 			psd = NULL;
312 		}
313 	}
314 
315 	if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
316 		err = -EINVAL;
317 		goto err_sock_put;
318 	}
319 
320 	id = info->attrs[PSP_A_ASSOC_DEV_ID];
321 	if (psd) {
322 		mutex_lock(&psd->lock);
323 		if (id && psd->id != nla_get_u32(id)) {
324 			mutex_unlock(&psd->lock);
325 			NL_SET_ERR_MSG_ATTR(info->extack, id,
326 					    "Device id vs socket mismatch");
327 			err = -EINVAL;
328 			goto err_psd_put;
329 		}
330 
331 		psp_dev_put(psd);
332 	} else {
333 		psd = psp_device_get_and_lock(genl_info_net(info), id);
334 		if (IS_ERR(psd)) {
335 			err = PTR_ERR(psd);
336 			goto err_sock_put;
337 		}
338 	}
339 
340 	info->user_ptr[0] = psd;
341 	info->user_ptr[1] = socket;
342 
343 	return 0;
344 
345 err_psd_put:
346 	psp_dev_put(psd);
347 err_sock_put:
348 	sockfd_put(socket);
349 	return err;
350 }
351 
352 static int
353 psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
354 		 unsigned int key_sz)
355 {
356 	struct nlattr *nest = info->attrs[attr];
357 	struct nlattr *tb[PSP_A_KEYS_SPI + 1];
358 	u32 spi;
359 	int err;
360 
361 	err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
362 			       psp_keys_nl_policy, info->extack);
363 	if (err)
364 		return err;
365 
366 	if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
367 	    NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
368 		return -EINVAL;
369 
370 	if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
371 		NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
372 				    "incorrect key length");
373 		return -EINVAL;
374 	}
375 
376 	spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
377 	if (!(spi & PSP_SPI_KEY_ID)) {
378 		NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
379 				    "invalid SPI: lower 31b must be non-zero");
380 		return -EINVAL;
381 	}
382 
383 	key->spi = cpu_to_be32(spi);
384 	memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);
385 
386 	return 0;
387 }
388 
389 static int
390 psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
391 	       struct psp_key_parsed *key)
392 {
393 	int key_sz = psp_key_size(version);
394 	void *nest;
395 
396 	nest = nla_nest_start(skb, attr);
397 
398 	if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
399 	    nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
400 		nla_nest_cancel(skb, nest);
401 		return -EMSGSIZE;
402 	}
403 
404 	nla_nest_end(skb, nest);
405 
406 	return 0;
407 }
408 
409 int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
410 {
411 	struct socket *socket = info->user_ptr[1];
412 	struct psp_dev *psd = info->user_ptr[0];
413 	struct psp_key_parsed key;
414 	struct psp_assoc *pas;
415 	struct sk_buff *rsp;
416 	u32 version;
417 	int err;
418 
419 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
420 		return -EINVAL;
421 
422 	version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
423 	if (!(psd->caps->versions & (1 << version))) {
424 		NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
425 		return -EOPNOTSUPP;
426 	}
427 
428 	rsp = psp_nl_reply_new(info);
429 	if (!rsp)
430 		return -ENOMEM;
431 
432 	pas = psp_assoc_create(psd);
433 	if (!pas) {
434 		err = -ENOMEM;
435 		goto err_free_rsp;
436 	}
437 	pas->version = version;
438 
439 	err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
440 	if (err)
441 		goto err_free_pas;
442 
443 	if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
444 	    psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
445 		err = -EMSGSIZE;
446 		goto err_free_pas;
447 	}
448 
449 	err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
450 	if (err) {
451 		NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
452 		goto err_free_pas;
453 	}
454 	psp_assoc_put(pas);
455 
456 	return psp_nl_reply_send(rsp, info);
457 
458 err_free_pas:
459 	psp_assoc_put(pas);
460 err_free_rsp:
461 	nlmsg_free(rsp);
462 	return err;
463 }
464 
465 int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
466 {
467 	struct socket *socket = info->user_ptr[1];
468 	struct psp_dev *psd = info->user_ptr[0];
469 	struct psp_key_parsed key;
470 	struct sk_buff *rsp;
471 	unsigned int key_sz;
472 	u32 version;
473 	int err;
474 
475 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
476 	    GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
477 		return -EINVAL;
478 
479 	version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
480 	if (!(psd->caps->versions & (1 << version))) {
481 		NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
482 		return -EOPNOTSUPP;
483 	}
484 
485 	key_sz = psp_key_size(version);
486 	if (!key_sz)
487 		return -EINVAL;
488 
489 	err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
490 	if (err < 0)
491 		return err;
492 
493 	rsp = psp_nl_reply_new(info);
494 	if (!rsp)
495 		return -ENOMEM;
496 
497 	err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
498 				    info->extack);
499 	if (err)
500 		goto err_free_msg;
501 
502 	return psp_nl_reply_send(rsp, info);
503 
504 err_free_msg:
505 	nlmsg_free(rsp);
506 	return err;
507 }
508 
509 static int
510 psp_nl_stats_fill(struct psp_dev *psd, struct sk_buff *rsp,
511 		  const struct genl_info *info)
512 {
513 	unsigned int required_cnt = sizeof(struct psp_dev_stats) / sizeof(u64);
514 	struct psp_dev_stats stats;
515 	void *hdr;
516 	int i;
517 
518 	memset(&stats, 0xff, sizeof(stats));
519 	psd->ops->get_stats(psd, &stats);
520 
521 	for (i = 0; i < required_cnt; i++)
522 		if (WARN_ON_ONCE(stats.required[i] == ETHTOOL_STAT_NOT_SET))
523 			return -EOPNOTSUPP;
524 
525 	hdr = genlmsg_iput(rsp, info);
526 	if (!hdr)
527 		return -EMSGSIZE;
528 
529 	if (nla_put_u32(rsp, PSP_A_STATS_DEV_ID, psd->id) ||
530 	    nla_put_uint(rsp, PSP_A_STATS_KEY_ROTATIONS,
531 			 psd->stats.rotations) ||
532 	    nla_put_uint(rsp, PSP_A_STATS_STALE_EVENTS, psd->stats.stales) ||
533 	    nla_put_uint(rsp, PSP_A_STATS_RX_PACKETS, stats.rx_packets) ||
534 	    nla_put_uint(rsp, PSP_A_STATS_RX_BYTES, stats.rx_bytes) ||
535 	    nla_put_uint(rsp, PSP_A_STATS_RX_AUTH_FAIL, stats.rx_auth_fail) ||
536 	    nla_put_uint(rsp, PSP_A_STATS_RX_ERROR, stats.rx_error) ||
537 	    nla_put_uint(rsp, PSP_A_STATS_RX_BAD, stats.rx_bad) ||
538 	    nla_put_uint(rsp, PSP_A_STATS_TX_PACKETS, stats.tx_packets) ||
539 	    nla_put_uint(rsp, PSP_A_STATS_TX_BYTES, stats.tx_bytes) ||
540 	    nla_put_uint(rsp, PSP_A_STATS_TX_ERROR, stats.tx_error))
541 		goto err_cancel_msg;
542 
543 	genlmsg_end(rsp, hdr);
544 	return 0;
545 
546 err_cancel_msg:
547 	genlmsg_cancel(rsp, hdr);
548 	return -EMSGSIZE;
549 }
550 
551 int psp_nl_get_stats_doit(struct sk_buff *skb, struct genl_info *info)
552 {
553 	struct psp_dev *psd = info->user_ptr[0];
554 	struct sk_buff *rsp;
555 	int err;
556 
557 	rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
558 	if (!rsp)
559 		return -ENOMEM;
560 
561 	err = psp_nl_stats_fill(psd, rsp, info);
562 	if (err)
563 		goto err_free_msg;
564 
565 	return genlmsg_reply(rsp, info);
566 
567 err_free_msg:
568 	nlmsg_free(rsp);
569 	return err;
570 }
571 
572 static int
573 psp_nl_stats_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
574 			    struct psp_dev *psd)
575 {
576 	if (psp_dev_check_access(psd, sock_net(rsp->sk)))
577 		return 0;
578 
579 	return psp_nl_stats_fill(psd, rsp, genl_info_dump(cb));
580 }
581 
582 int psp_nl_get_stats_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
583 {
584 	struct psp_dev *psd;
585 	int err = 0;
586 
587 	mutex_lock(&psp_devs_lock);
588 	xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
589 		mutex_lock(&psd->lock);
590 		err = psp_nl_stats_get_dumpit_one(rsp, cb, psd);
591 		mutex_unlock(&psd->lock);
592 		if (err)
593 			break;
594 	}
595 	mutex_unlock(&psp_devs_lock);
596 
597 	return err;
598 }
599