1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/bitfield.h> 4 #include <linux/list.h> 5 #include <linux/netdevice.h> 6 #include <linux/xarray.h> 7 #include <net/net_namespace.h> 8 #include <net/psp.h> 9 #include <net/udp.h> 10 11 #include "psp.h" 12 #include "psp-nl-gen.h" 13 14 DEFINE_XARRAY_ALLOC1(psp_devs); 15 struct mutex psp_devs_lock; 16 17 /** 18 * DOC: PSP locking 19 * 20 * psp_devs_lock protects the psp_devs xarray. 21 * Ordering is take the psp_devs_lock and then the instance lock. 22 * Each instance is protected by RCU, and has a refcount. 23 * When driver unregisters the instance gets flushed, but struct sticks around. 24 */ 25 26 /** 27 * psp_dev_check_access() - check if user in a given net ns can access PSP dev 28 * @psd: PSP device structure user is trying to access 29 * @net: net namespace user is in 30 * @admin: If true, only allow access from @psd's main device's netns, 31 * for admin operations like config changes and key rotation. 32 * If false, also allow access from network namespaces that have 33 * an associated device with @psd, for read-only and association 34 * management operations. 35 * 36 * Return: 0 if PSP device should be visible in @net, errno otherwise. 37 */ 38 int psp_dev_check_access(struct psp_dev *psd, struct net *net, bool admin) 39 { 40 if (dev_net(psd->main_netdev) == net) 41 return 0; 42 43 if (!admin && psp_has_assoc_dev_in_ns(psd, net)) 44 return 0; 45 46 return -ENOENT; 47 } 48 49 /** 50 * psp_dev_create() - create and register PSP device 51 * @netdev: main netdevice 52 * @psd_ops: driver callbacks 53 * @psd_caps: device capabilities 54 * @priv_ptr: back-pointer to driver private data 55 * 56 * Return: pointer to allocated PSP device, or ERR_PTR. 57 */ 58 struct psp_dev * 59 psp_dev_create(struct net_device *netdev, 60 struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps, 61 void *priv_ptr) 62 { 63 struct psp_dev *psd; 64 static u32 last_id; 65 int err; 66 67 if (WARN_ON(!psd_caps->versions || 68 !psd_ops->set_config || 69 !psd_ops->key_rotate || 70 !psd_ops->rx_spi_alloc || 71 !psd_ops->tx_key_add || 72 !psd_ops->tx_key_del || 73 !psd_ops->get_stats)) 74 return ERR_PTR(-EINVAL); 75 76 psd = kzalloc_obj(*psd); 77 if (!psd) 78 return ERR_PTR(-ENOMEM); 79 80 psd->main_netdev = netdev; 81 INIT_LIST_HEAD(&psd->assoc_dev_list); 82 psd->ops = psd_ops; 83 psd->caps = psd_caps; 84 psd->drv_priv = priv_ptr; 85 86 mutex_init(&psd->lock); 87 INIT_LIST_HEAD(&psd->active_assocs); 88 INIT_LIST_HEAD(&psd->prev_assocs); 89 INIT_LIST_HEAD(&psd->stale_assocs); 90 refcount_set(&psd->refcnt, 1); 91 92 mutex_lock(&psp_devs_lock); 93 err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b, 94 &last_id, GFP_KERNEL); 95 if (err) { 96 mutex_unlock(&psp_devs_lock); 97 kfree(psd); 98 return ERR_PTR(err); 99 } 100 mutex_lock(&psd->lock); 101 mutex_unlock(&psp_devs_lock); 102 103 /* notify before netdev assignment 104 * There's no strong reason for it, but thinking is to avoid creating 105 * implicit expectations about the PSP dev <> netdev relationship. 106 */ 107 psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF); 108 109 rcu_assign_pointer(netdev->psp_dev, psd); 110 111 mutex_unlock(&psd->lock); 112 113 return psd; 114 } 115 EXPORT_SYMBOL(psp_dev_create); 116 117 void psp_dev_free(struct psp_dev *psd) 118 { 119 mutex_lock(&psp_devs_lock); 120 xa_erase(&psp_devs, psd->id); 121 mutex_unlock(&psp_devs_lock); 122 123 mutex_destroy(&psd->lock); 124 kfree_rcu(psd, rcu); 125 } 126 127 /** 128 * psp_dev_unregister() - unregister PSP device 129 * @psd: PSP device structure 130 */ 131 void psp_dev_unregister(struct psp_dev *psd) 132 { 133 struct psp_assoc_dev *entry, *entry_tmp; 134 struct psp_assoc *pas, *next; 135 136 mutex_lock(&psp_devs_lock); 137 mutex_lock(&psd->lock); 138 139 psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF); 140 141 /* Wait until psp_dev_free() to call xa_erase() to prevent a 142 * different psd from being added to the xarray with this id, while 143 * there are still references to this psd being held. 144 */ 145 xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); 146 mutex_unlock(&psp_devs_lock); 147 148 list_splice_init(&psd->active_assocs, &psd->prev_assocs); 149 list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 150 list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list) 151 psp_dev_tx_key_del(psd, pas); 152 153 list_for_each_entry_safe(entry, entry_tmp, &psd->assoc_dev_list, 154 dev_list) { 155 list_del(&entry->dev_list); 156 rcu_assign_pointer(entry->assoc_dev->psp_dev, NULL); 157 netdev_put(entry->assoc_dev, &entry->dev_tracker); 158 kfree(entry); 159 } 160 psd->assoc_dev_cnt = 0; 161 162 rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); 163 164 psd->ops = NULL; 165 psd->drv_priv = NULL; 166 167 mutex_unlock(&psd->lock); 168 169 psp_dev_put(psd); 170 } 171 EXPORT_SYMBOL(psp_dev_unregister); 172 173 unsigned int psp_key_size(u32 version) 174 { 175 switch (version) { 176 case PSP_VERSION_HDR0_AES_GCM_128: 177 case PSP_VERSION_HDR0_AES_GMAC_128: 178 return 16; 179 case PSP_VERSION_HDR0_AES_GCM_256: 180 case PSP_VERSION_HDR0_AES_GMAC_256: 181 return 32; 182 default: 183 return 0; 184 } 185 } 186 EXPORT_SYMBOL(psp_key_size); 187 188 static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi, 189 u8 ver, unsigned int udp_len, __be16 sport) 190 { 191 struct udphdr *uh = udp_hdr(skb); 192 struct psphdr *psph = (struct psphdr *)(uh + 1); 193 const struct sock *sk = skb->sk; 194 195 uh->dest = htons(PSP_DEFAULT_UDP_PORT); 196 197 /* A bit of theory: Selection of the source port. 198 * 199 * We need some entropy, so that multiple flows use different 200 * source ports for better RSS spreading at the receiver. 201 * 202 * We also need that all packets belonging to one TCP flow 203 * use the same source port through their duration, 204 * so that all these packets land in the same receive queue. 205 * 206 * udp_flow_src_port() is using sk_txhash, inherited from 207 * skb_set_hash_from_sk() call in __tcp_transmit_skb(). 208 * This field is subject to reshuffling, thanks to 209 * sk_rethink_txhash() calls in various TCP functions. 210 * 211 * Instead, use sk->sk_hash which is constant through 212 * the whole flow duration. 213 */ 214 if (likely(sk)) { 215 u32 hash = sk->sk_hash; 216 int min, max; 217 218 /* These operations are cheap, no need to cache the result 219 * in another socket field. 220 */ 221 inet_get_local_port_range(net, &min, &max); 222 /* Since this is being sent on the wire obfuscate hash a bit 223 * to minimize possibility that any useful information to an 224 * attacker is leaked. Only upper 16 bits are relevant in the 225 * computation for 16 bit port value because we use a 226 * reciprocal divide. 227 */ 228 hash ^= hash << 16; 229 uh->source = htons(reciprocal_scale(hash, max - min + 1) + min); 230 } else { 231 uh->source = udp_flow_src_port(net, skb, 0, 0, false); 232 } 233 uh->check = 0; 234 uh->len = htons(udp_len); 235 236 psph->nexthdr = IPPROTO_TCP; 237 psph->hdrlen = PSP_HDRLEN_NOOPT; 238 psph->crypt_offset = 0; 239 psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) | 240 FIELD_PREP(PSPHDR_VERFL_ONE, 1); 241 psph->spi = spi; 242 memset(&psph->iv, 0, sizeof(psph->iv)); 243 } 244 245 /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling 246 * them in. 247 */ 248 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 249 u8 ver, __be16 sport) 250 { 251 u32 network_len = skb_network_header_len(skb); 252 u32 ethr_len = skb_mac_header_len(skb); 253 u32 bufflen = ethr_len + network_len; 254 255 if (skb->protocol != htons(ETH_P_IP) && 256 skb->protocol != htons(ETH_P_IPV6)) 257 return false; 258 259 if (skb_cow_head(skb, PSP_ENCAP_HLEN)) 260 return false; 261 262 skb_push(skb, PSP_ENCAP_HLEN); 263 skb->mac_header -= PSP_ENCAP_HLEN; 264 skb->network_header -= PSP_ENCAP_HLEN; 265 skb->transport_header -= PSP_ENCAP_HLEN; 266 memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen); 267 268 if (skb->protocol == htons(ETH_P_IP)) { 269 ip_hdr(skb)->protocol = IPPROTO_UDP; 270 be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN); 271 ip_hdr(skb)->check = 0; 272 ip_hdr(skb)->check = 273 ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl); 274 } else { 275 ipv6_hdr(skb)->nexthdr = IPPROTO_UDP; 276 be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN); 277 } 278 279 skb_set_inner_ipproto(skb, IPPROTO_TCP); 280 skb_set_inner_transport_header(skb, skb_transport_offset(skb) + 281 PSP_ENCAP_HLEN); 282 skb->encapsulation = 1; 283 psp_write_headers(net, skb, spi, ver, 284 skb->len - skb_transport_offset(skb), sport); 285 286 return true; 287 } 288 EXPORT_SYMBOL(psp_dev_encapsulate); 289 290 /* Receive handler for PSP packets. 291 * 292 * Accepts only already-authenticated packets. The full PSP header is 293 * stripped according to psph->hdrlen; any optional fields it advertises 294 * (virtualization cookies, etc.) are ignored and discarded along with the 295 * rest of the header. The caller should ensure that skb->data is pointing 296 * to the mac header, and that skb->mac_len is set. This function does not 297 * currently adjust skb->csum (CHECKSUM_COMPLETE is not supported). 298 */ 299 int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv) 300 { 301 int l2_hlen = 0, l3_hlen, encap, psp_hlen; 302 struct psp_skb_ext *pse; 303 struct psphdr *psph; 304 struct ethhdr *eth; 305 struct udphdr *uh; 306 __be16 proto; 307 bool is_udp; 308 309 eth = (struct ethhdr *)skb->data; 310 proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen); 311 if (proto == htons(ETH_P_IP)) 312 l3_hlen = sizeof(struct iphdr); 313 else if (proto == htons(ETH_P_IPV6)) 314 l3_hlen = sizeof(struct ipv6hdr); 315 else 316 return -EINVAL; 317 318 if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))) 319 return -EINVAL; 320 321 if (proto == htons(ETH_P_IP)) { 322 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 323 324 if (unlikely(iph->ihl < 5)) 325 return -EINVAL; 326 327 is_udp = iph->protocol == IPPROTO_UDP; 328 l3_hlen = iph->ihl * 4; 329 if (l3_hlen != sizeof(struct iphdr) && 330 !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)) 331 return -EINVAL; 332 } else { 333 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 334 335 is_udp = ipv6h->nexthdr == IPPROTO_UDP; 336 } 337 338 if (unlikely(!is_udp)) 339 return -EINVAL; 340 341 uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen); 342 if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT))) 343 return -EINVAL; 344 345 psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen + 346 sizeof(struct udphdr)); 347 348 /* Strip the full PSP header per psph->hdrlen; VC/options are pulled 349 * into the linear region only so they can be discarded with the 350 * rest of the header. 351 */ 352 psp_hlen = (psph->hdrlen + 1) * 8; 353 354 if (unlikely(psp_hlen < sizeof(struct psphdr))) 355 return -EINVAL; 356 357 if (psp_hlen > sizeof(struct psphdr) && 358 !pskb_may_pull(skb, l2_hlen + l3_hlen + 359 sizeof(struct udphdr) + psp_hlen)) 360 return -EINVAL; 361 362 psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen + 363 sizeof(struct udphdr)); 364 365 pse = skb_ext_add(skb, SKB_EXT_PSP); 366 if (!pse) 367 return -EINVAL; 368 369 pse->spi = psph->spi; 370 pse->dev_id = dev_id; 371 pse->generation = generation; 372 pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl); 373 374 encap = sizeof(struct udphdr) + psp_hlen; 375 encap += strip_icv ? PSP_TRL_SIZE : 0; 376 377 if (proto == htons(ETH_P_IP)) { 378 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 379 380 if (unlikely(ntohs(iph->tot_len) < l3_hlen + encap)) 381 return -EINVAL; 382 383 iph->protocol = psph->nexthdr; 384 iph->tot_len = htons(ntohs(iph->tot_len) - encap); 385 iph->check = 0; 386 iph->check = ip_fast_csum((u8 *)iph, iph->ihl); 387 } else { 388 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 389 390 if (unlikely(ntohs(ipv6h->payload_len) < encap)) 391 return -EINVAL; 392 393 ipv6h->nexthdr = psph->nexthdr; 394 ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap); 395 } 396 397 memmove(skb->data + sizeof(struct udphdr) + psp_hlen, 398 skb->data, l2_hlen + l3_hlen); 399 skb_pull(skb, sizeof(struct udphdr) + psp_hlen); 400 401 if (strip_icv) 402 pskb_trim(skb, skb->len - PSP_TRL_SIZE); 403 404 return 0; 405 } 406 EXPORT_SYMBOL(psp_dev_rcv); 407 408 static void psp_dev_disassoc_one(struct psp_dev *psd, struct net_device *dev) 409 { 410 struct psp_assoc_dev *entry; 411 412 list_for_each_entry(entry, &psd->assoc_dev_list, dev_list) { 413 if (entry->assoc_dev == dev) { 414 list_del(&entry->dev_list); 415 psd->assoc_dev_cnt--; 416 rcu_assign_pointer(entry->assoc_dev->psp_dev, NULL); 417 netdev_put(entry->assoc_dev, &entry->dev_tracker); 418 kfree(entry); 419 return; 420 } 421 } 422 } 423 424 static int psp_netdev_event(struct notifier_block *nb, unsigned long event, 425 void *ptr) 426 { 427 struct net_device *dev = netdev_notifier_info_to_dev(ptr); 428 struct psp_dev *psd; 429 430 if (event != NETDEV_UNREGISTER) 431 return NOTIFY_DONE; 432 433 rcu_read_lock(); 434 psd = rcu_dereference(dev->psp_dev); 435 if (psd && psp_dev_tryget(psd)) { 436 rcu_read_unlock(); 437 mutex_lock(&psd->lock); 438 if (psp_dev_is_registered(psd)) 439 psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF); 440 psp_dev_disassoc_one(psd, dev); 441 mutex_unlock(&psd->lock); 442 psp_dev_put(psd); 443 } else { 444 rcu_read_unlock(); 445 } 446 447 return NOTIFY_DONE; 448 } 449 450 static struct notifier_block psp_netdev_notifier = { 451 .notifier_call = psp_netdev_event, 452 }; 453 454 static DEFINE_MUTEX(psp_notifier_lock); 455 static bool psp_notifier_registered; 456 457 /* Register the netdevice notifier when the first device association 458 * is created. In many installations no associations will be created and 459 * the notifier won't be needed. 460 * 461 * Must be called without psd->lock held, due to lock ordering: 462 * rtnl_lock -> psd->lock (the notifier callback runs under rtnl_lock 463 * and takes psd->lock). 464 */ 465 int psp_attach_netdev_notifier(void) 466 { 467 int err = 0; 468 469 if (READ_ONCE(psp_notifier_registered)) 470 return 0; 471 472 mutex_lock(&psp_notifier_lock); 473 if (!psp_notifier_registered) { 474 err = register_netdevice_notifier(&psp_netdev_notifier); 475 if (!err) 476 WRITE_ONCE(psp_notifier_registered, true); 477 } 478 mutex_unlock(&psp_notifier_lock); 479 480 return err; 481 } 482 483 static int __init psp_init(void) 484 { 485 mutex_init(&psp_devs_lock); 486 487 return genl_register_family(&psp_nl_family); 488 } 489 490 subsys_initcall(psp_init); 491