1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/ethtool.h> 4 #include <linux/skbuff.h> 5 #include <linux/xarray.h> 6 #include <net/genetlink.h> 7 #include <net/psp.h> 8 #include <net/sock.h> 9 10 #include "psp-nl-gen.h" 11 #include "psp.h" 12 13 /* Netlink helpers */ 14 15 static struct sk_buff *psp_nl_reply_new(struct genl_info *info) 16 { 17 struct sk_buff *rsp; 18 void *hdr; 19 20 rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); 21 if (!rsp) 22 return NULL; 23 24 hdr = genlmsg_iput(rsp, info); 25 if (!hdr) { 26 nlmsg_free(rsp); 27 return NULL; 28 } 29 30 return rsp; 31 } 32 33 static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info) 34 { 35 /* Note that this *only* works with a single message per skb! */ 36 nlmsg_end(rsp, (struct nlmsghdr *)rsp->data); 37 38 return genlmsg_reply(rsp, info); 39 } 40 41 /* Device stuff */ 42 43 static struct psp_dev * 44 psp_device_get_and_lock(struct net *net, struct nlattr *dev_id) 45 { 46 struct psp_dev *psd; 47 int err; 48 49 mutex_lock(&psp_devs_lock); 50 psd = xa_load(&psp_devs, nla_get_u32(dev_id)); 51 if (!psd) { 52 mutex_unlock(&psp_devs_lock); 53 return ERR_PTR(-ENODEV); 54 } 55 56 mutex_lock(&psd->lock); 57 mutex_unlock(&psp_devs_lock); 58 59 err = psp_dev_check_access(psd, net); 60 if (err) { 61 mutex_unlock(&psd->lock); 62 return ERR_PTR(err); 63 } 64 65 return psd; 66 } 67 68 int psp_device_get_locked(const struct genl_split_ops *ops, 69 struct sk_buff *skb, struct genl_info *info) 70 { 71 if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID)) 72 return -EINVAL; 73 74 info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info), 75 info->attrs[PSP_A_DEV_ID]); 76 return PTR_ERR_OR_ZERO(info->user_ptr[0]); 77 } 78 79 void 80 psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, 81 struct genl_info *info) 82 { 83 struct socket *socket = info->user_ptr[1]; 84 struct psp_dev *psd = info->user_ptr[0]; 85 86 mutex_unlock(&psd->lock); 87 if (socket) 88 sockfd_put(socket); 89 } 90 91 static int 92 psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp, 93 const struct genl_info *info) 94 { 95 void *hdr; 96 97 hdr = genlmsg_iput(rsp, info); 98 if (!hdr) 99 return -EMSGSIZE; 100 101 if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) || 102 nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) || 103 nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) || 104 nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions)) 105 goto err_cancel_msg; 106 107 genlmsg_end(rsp, hdr); 108 return 0; 109 110 err_cancel_msg: 111 genlmsg_cancel(rsp, hdr); 112 return -EMSGSIZE; 113 } 114 115 void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd) 116 { 117 struct genl_info info; 118 struct sk_buff *ntf; 119 120 if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev), 121 PSP_NLGRP_MGMT)) 122 return; 123 124 ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); 125 if (!ntf) 126 return; 127 128 genl_info_init_ntf(&info, &psp_nl_family, cmd); 129 if (psp_nl_dev_fill(psd, ntf, &info)) { 130 nlmsg_free(ntf); 131 return; 132 } 133 134 genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, 135 0, PSP_NLGRP_MGMT, GFP_KERNEL); 136 } 137 138 int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info) 139 { 140 struct psp_dev *psd = info->user_ptr[0]; 141 struct sk_buff *rsp; 142 int err; 143 144 rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); 145 if (!rsp) 146 return -ENOMEM; 147 148 err = psp_nl_dev_fill(psd, rsp, info); 149 if (err) 150 goto err_free_msg; 151 152 return genlmsg_reply(rsp, info); 153 154 err_free_msg: 155 nlmsg_free(rsp); 156 return err; 157 } 158 159 static int 160 psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb, 161 struct psp_dev *psd) 162 { 163 if (psp_dev_check_access(psd, sock_net(rsp->sk))) 164 return 0; 165 166 return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb)); 167 } 168 169 int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb) 170 { 171 struct psp_dev *psd; 172 int err = 0; 173 174 mutex_lock(&psp_devs_lock); 175 xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) { 176 mutex_lock(&psd->lock); 177 err = psp_nl_dev_get_dumpit_one(rsp, cb, psd); 178 mutex_unlock(&psd->lock); 179 if (err) 180 break; 181 } 182 mutex_unlock(&psp_devs_lock); 183 184 return err; 185 } 186 187 int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info) 188 { 189 struct psp_dev *psd = info->user_ptr[0]; 190 struct psp_dev_config new_config; 191 struct sk_buff *rsp; 192 int err; 193 194 memcpy(&new_config, &psd->config, sizeof(new_config)); 195 196 if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) { 197 new_config.versions = 198 nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]); 199 if (new_config.versions & ~psd->caps->versions) { 200 NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device"); 201 return -EINVAL; 202 } 203 } else { 204 NL_SET_ERR_MSG(info->extack, "No settings present"); 205 return -EINVAL; 206 } 207 208 rsp = psp_nl_reply_new(info); 209 if (!rsp) 210 return -ENOMEM; 211 212 if (memcmp(&new_config, &psd->config, sizeof(new_config))) { 213 err = psd->ops->set_config(psd, &new_config, info->extack); 214 if (err) 215 goto err_free_rsp; 216 217 memcpy(&psd->config, &new_config, sizeof(new_config)); 218 } 219 220 psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF); 221 222 return psp_nl_reply_send(rsp, info); 223 224 err_free_rsp: 225 nlmsg_free(rsp); 226 return err; 227 } 228 229 int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info) 230 { 231 struct psp_dev *psd = info->user_ptr[0]; 232 struct genl_info ntf_info; 233 struct sk_buff *ntf, *rsp; 234 u8 prev_gen; 235 int err; 236 237 rsp = psp_nl_reply_new(info); 238 if (!rsp) 239 return -ENOMEM; 240 241 genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF); 242 ntf = psp_nl_reply_new(&ntf_info); 243 if (!ntf) { 244 err = -ENOMEM; 245 goto err_free_rsp; 246 } 247 248 if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) || 249 nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) { 250 err = -EMSGSIZE; 251 goto err_free_ntf; 252 } 253 254 /* suggest the next gen number, driver can override */ 255 prev_gen = psd->generation; 256 psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK; 257 258 err = psd->ops->key_rotate(psd, info->extack); 259 if (err) 260 goto err_free_ntf; 261 262 WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) || 263 psd->generation & ~PSP_GEN_VALID_MASK); 264 265 psp_assocs_key_rotated(psd); 266 psd->stats.rotations++; 267 268 nlmsg_end(ntf, (struct nlmsghdr *)ntf->data); 269 genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, 270 0, PSP_NLGRP_USE, GFP_KERNEL); 271 return psp_nl_reply_send(rsp, info); 272 273 err_free_ntf: 274 nlmsg_free(ntf); 275 err_free_rsp: 276 nlmsg_free(rsp); 277 return err; 278 } 279 280 /* Key etc. */ 281 282 int psp_assoc_device_get_locked(const struct genl_split_ops *ops, 283 struct sk_buff *skb, struct genl_info *info) 284 { 285 struct socket *socket; 286 struct psp_dev *psd; 287 struct nlattr *id; 288 int fd, err; 289 290 if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD)) 291 return -EINVAL; 292 293 fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]); 294 socket = sockfd_lookup(fd, &err); 295 if (!socket) 296 return err; 297 298 if (!sk_is_tcp(socket->sk)) { 299 NL_SET_ERR_MSG_ATTR(info->extack, 300 info->attrs[PSP_A_ASSOC_SOCK_FD], 301 "Unsupported socket family and type"); 302 err = -EOPNOTSUPP; 303 goto err_sock_put; 304 } 305 306 psd = psp_dev_get_for_sock(socket->sk); 307 if (psd) { 308 err = psp_dev_check_access(psd, genl_info_net(info)); 309 if (err) { 310 psp_dev_put(psd); 311 psd = NULL; 312 } 313 } 314 315 if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) { 316 err = -EINVAL; 317 goto err_sock_put; 318 } 319 320 id = info->attrs[PSP_A_ASSOC_DEV_ID]; 321 if (psd) { 322 mutex_lock(&psd->lock); 323 if (id && psd->id != nla_get_u32(id)) { 324 mutex_unlock(&psd->lock); 325 NL_SET_ERR_MSG_ATTR(info->extack, id, 326 "Device id vs socket mismatch"); 327 err = -EINVAL; 328 goto err_psd_put; 329 } 330 331 psp_dev_put(psd); 332 } else { 333 psd = psp_device_get_and_lock(genl_info_net(info), id); 334 if (IS_ERR(psd)) { 335 err = PTR_ERR(psd); 336 goto err_sock_put; 337 } 338 } 339 340 info->user_ptr[0] = psd; 341 info->user_ptr[1] = socket; 342 343 return 0; 344 345 err_psd_put: 346 psp_dev_put(psd); 347 err_sock_put: 348 sockfd_put(socket); 349 return err; 350 } 351 352 static int 353 psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key, 354 unsigned int key_sz) 355 { 356 struct nlattr *nest = info->attrs[attr]; 357 struct nlattr *tb[PSP_A_KEYS_SPI + 1]; 358 u32 spi; 359 int err; 360 361 err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest, 362 psp_keys_nl_policy, info->extack); 363 if (err) 364 return err; 365 366 if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) || 367 NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI)) 368 return -EINVAL; 369 370 if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) { 371 NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], 372 "incorrect key length"); 373 return -EINVAL; 374 } 375 376 spi = nla_get_u32(tb[PSP_A_KEYS_SPI]); 377 if (!(spi & PSP_SPI_KEY_ID)) { 378 NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], 379 "invalid SPI: lower 31b must be non-zero"); 380 return -EINVAL; 381 } 382 383 key->spi = cpu_to_be32(spi); 384 memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz); 385 386 return 0; 387 } 388 389 static int 390 psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version, 391 struct psp_key_parsed *key) 392 { 393 int key_sz = psp_key_size(version); 394 void *nest; 395 396 nest = nla_nest_start(skb, attr); 397 398 if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) || 399 nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) { 400 nla_nest_cancel(skb, nest); 401 return -EMSGSIZE; 402 } 403 404 nla_nest_end(skb, nest); 405 406 return 0; 407 } 408 409 int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info) 410 { 411 struct socket *socket = info->user_ptr[1]; 412 struct psp_dev *psd = info->user_ptr[0]; 413 struct psp_key_parsed key; 414 struct psp_assoc *pas; 415 struct sk_buff *rsp; 416 u32 version; 417 int err; 418 419 if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION)) 420 return -EINVAL; 421 422 version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); 423 if (!(psd->caps->versions & (1 << version))) { 424 NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); 425 return -EOPNOTSUPP; 426 } 427 428 rsp = psp_nl_reply_new(info); 429 if (!rsp) 430 return -ENOMEM; 431 432 pas = psp_assoc_create(psd); 433 if (!pas) { 434 err = -ENOMEM; 435 goto err_free_rsp; 436 } 437 pas->version = version; 438 439 err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack); 440 if (err) 441 goto err_free_pas; 442 443 if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) || 444 psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) { 445 err = -EMSGSIZE; 446 goto err_free_pas; 447 } 448 449 err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack); 450 if (err) { 451 NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]); 452 goto err_free_pas; 453 } 454 psp_assoc_put(pas); 455 456 return psp_nl_reply_send(rsp, info); 457 458 err_free_pas: 459 psp_assoc_put(pas); 460 err_free_rsp: 461 nlmsg_free(rsp); 462 return err; 463 } 464 465 int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info) 466 { 467 struct socket *socket = info->user_ptr[1]; 468 struct psp_dev *psd = info->user_ptr[0]; 469 struct psp_key_parsed key; 470 struct sk_buff *rsp; 471 unsigned int key_sz; 472 u32 version; 473 int err; 474 475 if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) || 476 GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY)) 477 return -EINVAL; 478 479 version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); 480 if (!(psd->caps->versions & (1 << version))) { 481 NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); 482 return -EOPNOTSUPP; 483 } 484 485 key_sz = psp_key_size(version); 486 if (!key_sz) 487 return -EINVAL; 488 489 err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz); 490 if (err < 0) 491 return err; 492 493 rsp = psp_nl_reply_new(info); 494 if (!rsp) 495 return -ENOMEM; 496 497 err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key, 498 info->extack); 499 if (err) 500 goto err_free_msg; 501 502 return psp_nl_reply_send(rsp, info); 503 504 err_free_msg: 505 nlmsg_free(rsp); 506 return err; 507 } 508 509 static int 510 psp_nl_stats_fill(struct psp_dev *psd, struct sk_buff *rsp, 511 const struct genl_info *info) 512 { 513 unsigned int required_cnt = sizeof(struct psp_dev_stats) / sizeof(u64); 514 struct psp_dev_stats stats; 515 void *hdr; 516 int i; 517 518 memset(&stats, 0xff, sizeof(stats)); 519 psd->ops->get_stats(psd, &stats); 520 521 for (i = 0; i < required_cnt; i++) 522 if (WARN_ON_ONCE(stats.required[i] == ETHTOOL_STAT_NOT_SET)) 523 return -EOPNOTSUPP; 524 525 hdr = genlmsg_iput(rsp, info); 526 if (!hdr) 527 return -EMSGSIZE; 528 529 if (nla_put_u32(rsp, PSP_A_STATS_DEV_ID, psd->id) || 530 nla_put_uint(rsp, PSP_A_STATS_KEY_ROTATIONS, 531 psd->stats.rotations) || 532 nla_put_uint(rsp, PSP_A_STATS_STALE_EVENTS, psd->stats.stales) || 533 nla_put_uint(rsp, PSP_A_STATS_RX_PACKETS, stats.rx_packets) || 534 nla_put_uint(rsp, PSP_A_STATS_RX_BYTES, stats.rx_bytes) || 535 nla_put_uint(rsp, PSP_A_STATS_RX_AUTH_FAIL, stats.rx_auth_fail) || 536 nla_put_uint(rsp, PSP_A_STATS_RX_ERROR, stats.rx_error) || 537 nla_put_uint(rsp, PSP_A_STATS_RX_BAD, stats.rx_bad) || 538 nla_put_uint(rsp, PSP_A_STATS_TX_PACKETS, stats.tx_packets) || 539 nla_put_uint(rsp, PSP_A_STATS_TX_BYTES, stats.tx_bytes) || 540 nla_put_uint(rsp, PSP_A_STATS_TX_ERROR, stats.tx_error)) 541 goto err_cancel_msg; 542 543 genlmsg_end(rsp, hdr); 544 return 0; 545 546 err_cancel_msg: 547 genlmsg_cancel(rsp, hdr); 548 return -EMSGSIZE; 549 } 550 551 int psp_nl_get_stats_doit(struct sk_buff *skb, struct genl_info *info) 552 { 553 struct psp_dev *psd = info->user_ptr[0]; 554 struct sk_buff *rsp; 555 int err; 556 557 rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); 558 if (!rsp) 559 return -ENOMEM; 560 561 err = psp_nl_stats_fill(psd, rsp, info); 562 if (err) 563 goto err_free_msg; 564 565 return genlmsg_reply(rsp, info); 566 567 err_free_msg: 568 nlmsg_free(rsp); 569 return err; 570 } 571 572 static int 573 psp_nl_stats_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb, 574 struct psp_dev *psd) 575 { 576 if (psp_dev_check_access(psd, sock_net(rsp->sk))) 577 return 0; 578 579 return psp_nl_stats_fill(psd, rsp, genl_info_dump(cb)); 580 } 581 582 int psp_nl_get_stats_dumpit(struct sk_buff *rsp, struct netlink_callback *cb) 583 { 584 struct psp_dev *psd; 585 int err = 0; 586 587 mutex_lock(&psp_devs_lock); 588 xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) { 589 mutex_lock(&psd->lock); 590 err = psp_nl_stats_get_dumpit_one(rsp, cb, psd); 591 mutex_unlock(&psd->lock); 592 if (err) 593 break; 594 } 595 mutex_unlock(&psp_devs_lock); 596 597 return err; 598 } 599