1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/skbuff.h> 4 #include <linux/xarray.h> 5 #include <net/genetlink.h> 6 #include <net/psp.h> 7 #include <net/sock.h> 8 9 #include "psp-nl-gen.h" 10 #include "psp.h" 11 12 /* Netlink helpers */ 13 14 static struct sk_buff *psp_nl_reply_new(struct genl_info *info) 15 { 16 struct sk_buff *rsp; 17 void *hdr; 18 19 rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); 20 if (!rsp) 21 return NULL; 22 23 hdr = genlmsg_iput(rsp, info); 24 if (!hdr) { 25 nlmsg_free(rsp); 26 return NULL; 27 } 28 29 return rsp; 30 } 31 32 static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info) 33 { 34 /* Note that this *only* works with a single message per skb! */ 35 nlmsg_end(rsp, (struct nlmsghdr *)rsp->data); 36 37 return genlmsg_reply(rsp, info); 38 } 39 40 /* Device stuff */ 41 42 static struct psp_dev * 43 psp_device_get_and_lock(struct net *net, struct nlattr *dev_id) 44 { 45 struct psp_dev *psd; 46 int err; 47 48 mutex_lock(&psp_devs_lock); 49 psd = xa_load(&psp_devs, nla_get_u32(dev_id)); 50 if (!psd) { 51 mutex_unlock(&psp_devs_lock); 52 return ERR_PTR(-ENODEV); 53 } 54 55 mutex_lock(&psd->lock); 56 mutex_unlock(&psp_devs_lock); 57 58 err = psp_dev_check_access(psd, net); 59 if (err) { 60 mutex_unlock(&psd->lock); 61 return ERR_PTR(err); 62 } 63 64 return psd; 65 } 66 67 int psp_device_get_locked(const struct genl_split_ops *ops, 68 struct sk_buff *skb, struct genl_info *info) 69 { 70 if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID)) 71 return -EINVAL; 72 73 info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info), 74 info->attrs[PSP_A_DEV_ID]); 75 return PTR_ERR_OR_ZERO(info->user_ptr[0]); 76 } 77 78 void 79 psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, 80 struct genl_info *info) 81 { 82 struct socket *socket = info->user_ptr[1]; 83 struct psp_dev *psd = info->user_ptr[0]; 84 85 mutex_unlock(&psd->lock); 86 if (socket) 87 sockfd_put(socket); 88 } 89 90 static int 91 psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp, 92 const struct genl_info *info) 93 { 94 void *hdr; 95 96 hdr = genlmsg_iput(rsp, info); 97 if (!hdr) 98 return -EMSGSIZE; 99 100 if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) || 101 nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) || 102 nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) || 103 nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions)) 104 goto err_cancel_msg; 105 106 genlmsg_end(rsp, hdr); 107 return 0; 108 109 err_cancel_msg: 110 genlmsg_cancel(rsp, hdr); 111 return -EMSGSIZE; 112 } 113 114 void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd) 115 { 116 struct genl_info info; 117 struct sk_buff *ntf; 118 119 if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev), 120 PSP_NLGRP_MGMT)) 121 return; 122 123 ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); 124 if (!ntf) 125 return; 126 127 genl_info_init_ntf(&info, &psp_nl_family, cmd); 128 if (psp_nl_dev_fill(psd, ntf, &info)) { 129 nlmsg_free(ntf); 130 return; 131 } 132 133 genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, 134 0, PSP_NLGRP_MGMT, GFP_KERNEL); 135 } 136 137 int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info) 138 { 139 struct psp_dev *psd = info->user_ptr[0]; 140 struct sk_buff *rsp; 141 int err; 142 143 rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); 144 if (!rsp) 145 return -ENOMEM; 146 147 err = psp_nl_dev_fill(psd, rsp, info); 148 if (err) 149 goto err_free_msg; 150 151 return genlmsg_reply(rsp, info); 152 153 err_free_msg: 154 nlmsg_free(rsp); 155 return err; 156 } 157 158 static int 159 psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb, 160 struct psp_dev *psd) 161 { 162 if (psp_dev_check_access(psd, sock_net(rsp->sk))) 163 return 0; 164 165 return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb)); 166 } 167 168 int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb) 169 { 170 struct psp_dev *psd; 171 int err = 0; 172 173 mutex_lock(&psp_devs_lock); 174 xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) { 175 mutex_lock(&psd->lock); 176 err = psp_nl_dev_get_dumpit_one(rsp, cb, psd); 177 mutex_unlock(&psd->lock); 178 if (err) 179 break; 180 } 181 mutex_unlock(&psp_devs_lock); 182 183 return err; 184 } 185 186 int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info) 187 { 188 struct psp_dev *psd = info->user_ptr[0]; 189 struct psp_dev_config new_config; 190 struct sk_buff *rsp; 191 int err; 192 193 memcpy(&new_config, &psd->config, sizeof(new_config)); 194 195 if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) { 196 new_config.versions = 197 nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]); 198 if (new_config.versions & ~psd->caps->versions) { 199 NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device"); 200 return -EINVAL; 201 } 202 } else { 203 NL_SET_ERR_MSG(info->extack, "No settings present"); 204 return -EINVAL; 205 } 206 207 rsp = psp_nl_reply_new(info); 208 if (!rsp) 209 return -ENOMEM; 210 211 if (memcmp(&new_config, &psd->config, sizeof(new_config))) { 212 err = psd->ops->set_config(psd, &new_config, info->extack); 213 if (err) 214 goto err_free_rsp; 215 216 memcpy(&psd->config, &new_config, sizeof(new_config)); 217 } 218 219 psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF); 220 221 return psp_nl_reply_send(rsp, info); 222 223 err_free_rsp: 224 nlmsg_free(rsp); 225 return err; 226 } 227 228 int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info) 229 { 230 struct psp_dev *psd = info->user_ptr[0]; 231 struct genl_info ntf_info; 232 struct sk_buff *ntf, *rsp; 233 u8 prev_gen; 234 int err; 235 236 rsp = psp_nl_reply_new(info); 237 if (!rsp) 238 return -ENOMEM; 239 240 genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF); 241 ntf = psp_nl_reply_new(&ntf_info); 242 if (!ntf) { 243 err = -ENOMEM; 244 goto err_free_rsp; 245 } 246 247 if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) || 248 nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) { 249 err = -EMSGSIZE; 250 goto err_free_ntf; 251 } 252 253 /* suggest the next gen number, driver can override */ 254 prev_gen = psd->generation; 255 psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK; 256 257 err = psd->ops->key_rotate(psd, info->extack); 258 if (err) 259 goto err_free_ntf; 260 261 WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) || 262 psd->generation & ~PSP_GEN_VALID_MASK); 263 264 psp_assocs_key_rotated(psd); 265 266 nlmsg_end(ntf, (struct nlmsghdr *)ntf->data); 267 genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, 268 0, PSP_NLGRP_USE, GFP_KERNEL); 269 return psp_nl_reply_send(rsp, info); 270 271 err_free_ntf: 272 nlmsg_free(ntf); 273 err_free_rsp: 274 nlmsg_free(rsp); 275 return err; 276 } 277 278 /* Key etc. */ 279 280 int psp_assoc_device_get_locked(const struct genl_split_ops *ops, 281 struct sk_buff *skb, struct genl_info *info) 282 { 283 struct socket *socket; 284 struct psp_dev *psd; 285 struct nlattr *id; 286 int fd, err; 287 288 if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD)) 289 return -EINVAL; 290 291 fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]); 292 socket = sockfd_lookup(fd, &err); 293 if (!socket) 294 return err; 295 296 if (!sk_is_tcp(socket->sk)) { 297 NL_SET_ERR_MSG_ATTR(info->extack, 298 info->attrs[PSP_A_ASSOC_SOCK_FD], 299 "Unsupported socket family and type"); 300 err = -EOPNOTSUPP; 301 goto err_sock_put; 302 } 303 304 psd = psp_dev_get_for_sock(socket->sk); 305 if (psd) { 306 err = psp_dev_check_access(psd, genl_info_net(info)); 307 if (err) { 308 psp_dev_put(psd); 309 psd = NULL; 310 } 311 } 312 313 if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) { 314 err = -EINVAL; 315 goto err_sock_put; 316 } 317 318 id = info->attrs[PSP_A_ASSOC_DEV_ID]; 319 if (psd) { 320 mutex_lock(&psd->lock); 321 if (id && psd->id != nla_get_u32(id)) { 322 mutex_unlock(&psd->lock); 323 NL_SET_ERR_MSG_ATTR(info->extack, id, 324 "Device id vs socket mismatch"); 325 err = -EINVAL; 326 goto err_psd_put; 327 } 328 329 psp_dev_put(psd); 330 } else { 331 psd = psp_device_get_and_lock(genl_info_net(info), id); 332 if (IS_ERR(psd)) { 333 err = PTR_ERR(psd); 334 goto err_sock_put; 335 } 336 } 337 338 info->user_ptr[0] = psd; 339 info->user_ptr[1] = socket; 340 341 return 0; 342 343 err_psd_put: 344 psp_dev_put(psd); 345 err_sock_put: 346 sockfd_put(socket); 347 return err; 348 } 349 350 static int 351 psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key, 352 unsigned int key_sz) 353 { 354 struct nlattr *nest = info->attrs[attr]; 355 struct nlattr *tb[PSP_A_KEYS_SPI + 1]; 356 u32 spi; 357 int err; 358 359 err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest, 360 psp_keys_nl_policy, info->extack); 361 if (err) 362 return err; 363 364 if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) || 365 NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI)) 366 return -EINVAL; 367 368 if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) { 369 NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], 370 "incorrect key length"); 371 return -EINVAL; 372 } 373 374 spi = nla_get_u32(tb[PSP_A_KEYS_SPI]); 375 if (!(spi & PSP_SPI_KEY_ID)) { 376 NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], 377 "invalid SPI: lower 31b must be non-zero"); 378 return -EINVAL; 379 } 380 381 key->spi = cpu_to_be32(spi); 382 memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz); 383 384 return 0; 385 } 386 387 static int 388 psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version, 389 struct psp_key_parsed *key) 390 { 391 int key_sz = psp_key_size(version); 392 void *nest; 393 394 nest = nla_nest_start(skb, attr); 395 396 if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) || 397 nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) { 398 nla_nest_cancel(skb, nest); 399 return -EMSGSIZE; 400 } 401 402 nla_nest_end(skb, nest); 403 404 return 0; 405 } 406 407 int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info) 408 { 409 struct socket *socket = info->user_ptr[1]; 410 struct psp_dev *psd = info->user_ptr[0]; 411 struct psp_key_parsed key; 412 struct psp_assoc *pas; 413 struct sk_buff *rsp; 414 u32 version; 415 int err; 416 417 if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION)) 418 return -EINVAL; 419 420 version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); 421 if (!(psd->caps->versions & (1 << version))) { 422 NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); 423 return -EOPNOTSUPP; 424 } 425 426 rsp = psp_nl_reply_new(info); 427 if (!rsp) 428 return -ENOMEM; 429 430 pas = psp_assoc_create(psd); 431 if (!pas) { 432 err = -ENOMEM; 433 goto err_free_rsp; 434 } 435 pas->version = version; 436 437 err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack); 438 if (err) 439 goto err_free_pas; 440 441 if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) || 442 psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) { 443 err = -EMSGSIZE; 444 goto err_free_pas; 445 } 446 447 err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack); 448 if (err) { 449 NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]); 450 goto err_free_pas; 451 } 452 psp_assoc_put(pas); 453 454 return psp_nl_reply_send(rsp, info); 455 456 err_free_pas: 457 psp_assoc_put(pas); 458 err_free_rsp: 459 nlmsg_free(rsp); 460 return err; 461 } 462 463 int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info) 464 { 465 struct socket *socket = info->user_ptr[1]; 466 struct psp_dev *psd = info->user_ptr[0]; 467 struct psp_key_parsed key; 468 struct sk_buff *rsp; 469 unsigned int key_sz; 470 u32 version; 471 int err; 472 473 if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) || 474 GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY)) 475 return -EINVAL; 476 477 version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); 478 if (!(psd->caps->versions & (1 << version))) { 479 NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); 480 return -EOPNOTSUPP; 481 } 482 483 key_sz = psp_key_size(version); 484 if (!key_sz) 485 return -EINVAL; 486 487 err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz); 488 if (err < 0) 489 return err; 490 491 rsp = psp_nl_reply_new(info); 492 if (!rsp) 493 return -ENOMEM; 494 495 err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key, 496 info->extack); 497 if (err) 498 goto err_free_msg; 499 500 return psp_nl_reply_send(rsp, info); 501 502 err_free_msg: 503 nlmsg_free(rsp); 504 return err; 505 } 506