1 // SPDX-License-Identifier: GPL-2.0-only
2
3 #include <linux/ethtool.h>
4 #include <linux/skbuff.h>
5 #include <linux/xarray.h>
6 #include <net/genetlink.h>
7 #include <net/psp.h>
8 #include <net/sock.h>
9
10 #include "psp-nl-gen.h"
11 #include "psp.h"
12
13 /* Netlink helpers */
14
psp_nl_reply_new(struct genl_info * info)15 static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
16 {
17 struct sk_buff *rsp;
18 void *hdr;
19
20 rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
21 if (!rsp)
22 return NULL;
23
24 hdr = genlmsg_iput(rsp, info);
25 if (!hdr) {
26 nlmsg_free(rsp);
27 return NULL;
28 }
29
30 return rsp;
31 }
32
psp_nl_reply_send(struct sk_buff * rsp,struct genl_info * info)33 static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
34 {
35 /* Note that this *only* works with a single message per skb! */
36 nlmsg_end(rsp, (struct nlmsghdr *)rsp->data);
37
38 return genlmsg_reply(rsp, info);
39 }
40
41 /* Device stuff */
42
43 static struct psp_dev *
psp_device_get_and_lock(struct net * net,struct nlattr * dev_id)44 psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
45 {
46 struct psp_dev *psd;
47 int err;
48
49 mutex_lock(&psp_devs_lock);
50 psd = xa_load(&psp_devs, nla_get_u32(dev_id));
51 if (!psd) {
52 mutex_unlock(&psp_devs_lock);
53 return ERR_PTR(-ENODEV);
54 }
55
56 mutex_lock(&psd->lock);
57 mutex_unlock(&psp_devs_lock);
58
59 err = psp_dev_check_access(psd, net);
60 if (err) {
61 mutex_unlock(&psd->lock);
62 return ERR_PTR(err);
63 }
64
65 return psd;
66 }
67
psp_device_get_locked(const struct genl_split_ops * ops,struct sk_buff * skb,struct genl_info * info)68 int psp_device_get_locked(const struct genl_split_ops *ops,
69 struct sk_buff *skb, struct genl_info *info)
70 {
71 if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
72 return -EINVAL;
73
74 info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info),
75 info->attrs[PSP_A_DEV_ID]);
76 return PTR_ERR_OR_ZERO(info->user_ptr[0]);
77 }
78
79 void
psp_device_unlock(const struct genl_split_ops * ops,struct sk_buff * skb,struct genl_info * info)80 psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
81 struct genl_info *info)
82 {
83 struct socket *socket = info->user_ptr[1];
84 struct psp_dev *psd = info->user_ptr[0];
85
86 mutex_unlock(&psd->lock);
87 if (socket)
88 sockfd_put(socket);
89 }
90
91 static int
psp_nl_dev_fill(struct psp_dev * psd,struct sk_buff * rsp,const struct genl_info * info)92 psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
93 const struct genl_info *info)
94 {
95 void *hdr;
96
97 hdr = genlmsg_iput(rsp, info);
98 if (!hdr)
99 return -EMSGSIZE;
100
101 if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
102 nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) ||
103 nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) ||
104 nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions))
105 goto err_cancel_msg;
106
107 genlmsg_end(rsp, hdr);
108 return 0;
109
110 err_cancel_msg:
111 genlmsg_cancel(rsp, hdr);
112 return -EMSGSIZE;
113 }
114
/* Multicast a device notification of type @cmd to the PSP_NLGRP_MGMT group
 * in the netns of the device's main netdev.
 * Best effort: silently does nothing if nobody listens, if allocation
 * fails, or if the message cannot be filled.
 */
void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
{
	struct genl_info ntf_info;
	struct sk_buff *msg;

	/* don't bother building a message nobody will receive */
	if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
				PSP_NLGRP_MGMT))
		return;

	msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
	if (!msg)
		return;

	genl_info_init_ntf(&ntf_info, &psp_nl_family, cmd);
	if (!psp_nl_dev_fill(psd, msg, &ntf_info)) {
		genlmsg_multicast_netns(&psp_nl_family,
					dev_net(psd->main_netdev), msg,
					0, PSP_NLGRP_MGMT, GFP_KERNEL);
		return;
	}

	nlmsg_free(msg);
}
137
psp_nl_dev_get_doit(struct sk_buff * req,struct genl_info * info)138 int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
139 {
140 struct psp_dev *psd = info->user_ptr[0];
141 struct sk_buff *rsp;
142 int err;
143
144 rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
145 if (!rsp)
146 return -ENOMEM;
147
148 err = psp_nl_dev_fill(psd, rsp, info);
149 if (err)
150 goto err_free_msg;
151
152 return genlmsg_reply(rsp, info);
153
154 err_free_msg:
155 nlmsg_free(rsp);
156 return err;
157 }
158
159 static int
psp_nl_dev_get_dumpit_one(struct sk_buff * rsp,struct netlink_callback * cb,struct psp_dev * psd)160 psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
161 struct psp_dev *psd)
162 {
163 if (psp_dev_check_access(psd, sock_net(rsp->sk)))
164 return 0;
165
166 return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));
167 }
168
psp_nl_dev_get_dumpit(struct sk_buff * rsp,struct netlink_callback * cb)169 int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
170 {
171 struct psp_dev *psd;
172 int err = 0;
173
174 mutex_lock(&psp_devs_lock);
175 xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
176 mutex_lock(&psd->lock);
177 err = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
178 mutex_unlock(&psd->lock);
179 if (err)
180 break;
181 }
182 mutex_unlock(&psp_devs_lock);
183
184 return err;
185 }
186
psp_nl_dev_set_doit(struct sk_buff * skb,struct genl_info * info)187 int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
188 {
189 struct psp_dev *psd = info->user_ptr[0];
190 struct psp_dev_config new_config;
191 struct sk_buff *rsp;
192 int err;
193
194 memcpy(&new_config, &psd->config, sizeof(new_config));
195
196 if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
197 new_config.versions =
198 nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
199 if (new_config.versions & ~psd->caps->versions) {
200 NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
201 return -EINVAL;
202 }
203 } else {
204 NL_SET_ERR_MSG(info->extack, "No settings present");
205 return -EINVAL;
206 }
207
208 rsp = psp_nl_reply_new(info);
209 if (!rsp)
210 return -ENOMEM;
211
212 if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
213 err = psd->ops->set_config(psd, &new_config, info->extack);
214 if (err)
215 goto err_free_rsp;
216
217 memcpy(&psd->config, &new_config, sizeof(new_config));
218 }
219
220 psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
221
222 return psp_nl_reply_send(rsp, info);
223
224 err_free_rsp:
225 nlmsg_free(rsp);
226 return err;
227 }
228
psp_nl_key_rotate_doit(struct sk_buff * skb,struct genl_info * info)229 int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
230 {
231 struct psp_dev *psd = info->user_ptr[0];
232 struct genl_info ntf_info;
233 struct sk_buff *ntf, *rsp;
234 u8 prev_gen;
235 int err;
236
237 rsp = psp_nl_reply_new(info);
238 if (!rsp)
239 return -ENOMEM;
240
241 genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
242 ntf = psp_nl_reply_new(&ntf_info);
243 if (!ntf) {
244 err = -ENOMEM;
245 goto err_free_rsp;
246 }
247
248 if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
249 nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
250 err = -EMSGSIZE;
251 goto err_free_ntf;
252 }
253
254 /* suggest the next gen number, driver can override */
255 prev_gen = psd->generation;
256 psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK;
257
258 err = psd->ops->key_rotate(psd, info->extack);
259 if (err)
260 goto err_free_ntf;
261
262 WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) ||
263 psd->generation & ~PSP_GEN_VALID_MASK);
264
265 psp_assocs_key_rotated(psd);
266 psd->stats.rotations++;
267
268 nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
269 genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
270 0, PSP_NLGRP_USE, GFP_KERNEL);
271 return psp_nl_reply_send(rsp, info);
272
273 err_free_ntf:
274 nlmsg_free(ntf);
275 err_free_rsp:
276 nlmsg_free(rsp);
277 return err;
278 }
279
280 /* Key etc. */
281
psp_assoc_device_get_locked(const struct genl_split_ops * ops,struct sk_buff * skb,struct genl_info * info)282 int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
283 struct sk_buff *skb, struct genl_info *info)
284 {
285 struct socket *socket;
286 struct psp_dev *psd;
287 struct nlattr *id;
288 int fd, err;
289
290 if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
291 return -EINVAL;
292
293 fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
294 socket = sockfd_lookup(fd, &err);
295 if (!socket)
296 return err;
297
298 if (!sk_is_tcp(socket->sk)) {
299 NL_SET_ERR_MSG_ATTR(info->extack,
300 info->attrs[PSP_A_ASSOC_SOCK_FD],
301 "Unsupported socket family and type");
302 err = -EOPNOTSUPP;
303 goto err_sock_put;
304 }
305
306 psd = psp_dev_get_for_sock(socket->sk);
307 if (psd) {
308 /* Extra care needed here, psp_dev_get_for_sock() only gives
309 * us access to struct psp_dev's memory, which is quite weak.
310 */
311 mutex_lock(&psd->lock);
312 if (!psp_dev_is_registered(psd) ||
313 psp_dev_check_access(psd, genl_info_net(info))) {
314 mutex_unlock(&psd->lock);
315 psp_dev_put(psd);
316 psd = NULL;
317 }
318 }
319
320 if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
321 err = -EINVAL;
322 goto err_sock_put;
323 }
324
325 id = info->attrs[PSP_A_ASSOC_DEV_ID];
326 if (psd) {
327 if (id && psd->id != nla_get_u32(id)) {
328 mutex_unlock(&psd->lock);
329 NL_SET_ERR_MSG_ATTR(info->extack, id,
330 "Device id vs socket mismatch");
331 err = -EINVAL;
332 goto err_psd_put;
333 }
334
335 psp_dev_put(psd);
336 } else {
337 psd = psp_device_get_and_lock(genl_info_net(info), id);
338 if (IS_ERR(psd)) {
339 err = PTR_ERR(psd);
340 goto err_sock_put;
341 }
342 }
343
344 info->user_ptr[0] = psd;
345 info->user_ptr[1] = socket;
346
347 return 0;
348
349 err_psd_put:
350 psp_dev_put(psd);
351 err_sock_put:
352 sockfd_put(socket);
353 return err;
354 }
355
356 static int
psp_nl_parse_key(struct genl_info * info,u32 attr,struct psp_key_parsed * key,unsigned int key_sz)357 psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
358 unsigned int key_sz)
359 {
360 struct nlattr *nest = info->attrs[attr];
361 struct nlattr *tb[PSP_A_KEYS_SPI + 1];
362 u32 spi;
363 int err;
364
365 err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
366 psp_keys_nl_policy, info->extack);
367 if (err)
368 return err;
369
370 if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
371 NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
372 return -EINVAL;
373
374 if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
375 NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
376 "incorrect key length");
377 return -EINVAL;
378 }
379
380 spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
381 if (!(spi & PSP_SPI_KEY_ID)) {
382 NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
383 "invalid SPI: lower 31b must be non-zero");
384 return -EINVAL;
385 }
386
387 key->spi = cpu_to_be32(spi);
388 memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);
389
390 return 0;
391 }
392
393 static int
psp_nl_put_key(struct sk_buff * skb,u32 attr,u32 version,struct psp_key_parsed * key)394 psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
395 struct psp_key_parsed *key)
396 {
397 int key_sz = psp_key_size(version);
398 void *nest;
399
400 nest = nla_nest_start(skb, attr);
401
402 if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
403 nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
404 nla_nest_cancel(skb, nest);
405 return -EMSGSIZE;
406 }
407
408 nla_nest_end(skb, nest);
409
410 return 0;
411 }
412
psp_nl_rx_assoc_doit(struct sk_buff * skb,struct genl_info * info)413 int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
414 {
415 struct socket *socket = info->user_ptr[1];
416 struct psp_dev *psd = info->user_ptr[0];
417 struct psp_key_parsed key;
418 struct psp_assoc *pas;
419 struct sk_buff *rsp;
420 u32 version;
421 int err;
422
423 if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
424 return -EINVAL;
425
426 version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
427 if (!(psd->caps->versions & (1 << version))) {
428 NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
429 return -EOPNOTSUPP;
430 }
431
432 rsp = psp_nl_reply_new(info);
433 if (!rsp)
434 return -ENOMEM;
435
436 pas = psp_assoc_create(psd);
437 if (!pas) {
438 err = -ENOMEM;
439 goto err_free_rsp;
440 }
441 pas->version = version;
442
443 err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
444 if (err)
445 goto err_free_pas;
446
447 if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
448 psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
449 err = -EMSGSIZE;
450 goto err_free_pas;
451 }
452
453 err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
454 if (err) {
455 NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
456 goto err_free_pas;
457 }
458 psp_assoc_put(pas);
459
460 return psp_nl_reply_send(rsp, info);
461
462 err_free_pas:
463 psp_assoc_put(pas);
464 err_free_rsp:
465 nlmsg_free(rsp);
466 return err;
467 }
468
psp_nl_tx_assoc_doit(struct sk_buff * skb,struct genl_info * info)469 int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
470 {
471 struct socket *socket = info->user_ptr[1];
472 struct psp_dev *psd = info->user_ptr[0];
473 struct psp_key_parsed key;
474 struct sk_buff *rsp;
475 unsigned int key_sz;
476 u32 version;
477 int err;
478
479 if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
480 GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
481 return -EINVAL;
482
483 version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
484 if (!(psd->caps->versions & (1 << version))) {
485 NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
486 return -EOPNOTSUPP;
487 }
488
489 key_sz = psp_key_size(version);
490 if (!key_sz)
491 return -EINVAL;
492
493 err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
494 if (err < 0)
495 return err;
496
497 rsp = psp_nl_reply_new(info);
498 if (!rsp)
499 return -ENOMEM;
500
501 err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
502 info->extack);
503 if (err)
504 goto err_free_msg;
505
506 return psp_nl_reply_send(rsp, info);
507
508 err_free_msg:
509 nlmsg_free(rsp);
510 return err;
511 }
512
513 static int
psp_nl_stats_fill(struct psp_dev * psd,struct sk_buff * rsp,const struct genl_info * info)514 psp_nl_stats_fill(struct psp_dev *psd, struct sk_buff *rsp,
515 const struct genl_info *info)
516 {
517 unsigned int required_cnt = sizeof(struct psp_dev_stats) / sizeof(u64);
518 struct psp_dev_stats stats;
519 void *hdr;
520 int i;
521
522 memset(&stats, 0xff, sizeof(stats));
523 psd->ops->get_stats(psd, &stats);
524
525 for (i = 0; i < required_cnt; i++)
526 if (WARN_ON_ONCE(stats.required[i] == ETHTOOL_STAT_NOT_SET))
527 return -EOPNOTSUPP;
528
529 hdr = genlmsg_iput(rsp, info);
530 if (!hdr)
531 return -EMSGSIZE;
532
533 if (nla_put_u32(rsp, PSP_A_STATS_DEV_ID, psd->id) ||
534 nla_put_uint(rsp, PSP_A_STATS_KEY_ROTATIONS,
535 psd->stats.rotations) ||
536 nla_put_uint(rsp, PSP_A_STATS_STALE_EVENTS, psd->stats.stales) ||
537 nla_put_uint(rsp, PSP_A_STATS_RX_PACKETS, stats.rx_packets) ||
538 nla_put_uint(rsp, PSP_A_STATS_RX_BYTES, stats.rx_bytes) ||
539 nla_put_uint(rsp, PSP_A_STATS_RX_AUTH_FAIL, stats.rx_auth_fail) ||
540 nla_put_uint(rsp, PSP_A_STATS_RX_ERROR, stats.rx_error) ||
541 nla_put_uint(rsp, PSP_A_STATS_RX_BAD, stats.rx_bad) ||
542 nla_put_uint(rsp, PSP_A_STATS_TX_PACKETS, stats.tx_packets) ||
543 nla_put_uint(rsp, PSP_A_STATS_TX_BYTES, stats.tx_bytes) ||
544 nla_put_uint(rsp, PSP_A_STATS_TX_ERROR, stats.tx_error))
545 goto err_cancel_msg;
546
547 genlmsg_end(rsp, hdr);
548 return 0;
549
550 err_cancel_msg:
551 genlmsg_cancel(rsp, hdr);
552 return -EMSGSIZE;
553 }
554
psp_nl_get_stats_doit(struct sk_buff * skb,struct genl_info * info)555 int psp_nl_get_stats_doit(struct sk_buff *skb, struct genl_info *info)
556 {
557 struct psp_dev *psd = info->user_ptr[0];
558 struct sk_buff *rsp;
559 int err;
560
561 rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
562 if (!rsp)
563 return -ENOMEM;
564
565 err = psp_nl_stats_fill(psd, rsp, info);
566 if (err)
567 goto err_free_msg;
568
569 return genlmsg_reply(rsp, info);
570
571 err_free_msg:
572 nlmsg_free(rsp);
573 return err;
574 }
575
576 static int
psp_nl_stats_get_dumpit_one(struct sk_buff * rsp,struct netlink_callback * cb,struct psp_dev * psd)577 psp_nl_stats_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
578 struct psp_dev *psd)
579 {
580 if (psp_dev_check_access(psd, sock_net(rsp->sk)))
581 return 0;
582
583 return psp_nl_stats_fill(psd, rsp, genl_info_dump(cb));
584 }
585
psp_nl_get_stats_dumpit(struct sk_buff * rsp,struct netlink_callback * cb)586 int psp_nl_get_stats_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
587 {
588 struct psp_dev *psd;
589 int err = 0;
590
591 mutex_lock(&psp_devs_lock);
592 xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
593 mutex_lock(&psd->lock);
594 err = psp_nl_stats_get_dumpit_one(rsp, cb, psd);
595 mutex_unlock(&psd->lock);
596 if (err)
597 break;
598 }
599 mutex_unlock(&psp_devs_lock);
600
601 return err;
602 }
603