1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/bitfield.h> 4 #include <linux/list.h> 5 #include <linux/netdevice.h> 6 #include <linux/xarray.h> 7 #include <net/net_namespace.h> 8 #include <net/psp.h> 9 #include <net/udp.h> 10 11 #include "psp.h" 12 #include "psp-nl-gen.h" 13 14 DEFINE_XARRAY_ALLOC1(psp_devs); 15 struct mutex psp_devs_lock; 16 17 /** 18 * DOC: PSP locking 19 * 20 * psp_devs_lock protects the psp_devs xarray. 21 * Ordering is take the psp_devs_lock and then the instance lock. 22 * Each instance is protected by RCU, and has a refcount. 23 * When driver unregisters the instance gets flushed, but struct sticks around. 24 */ 25 26 /** 27 * psp_dev_check_access() - check if user in a given net ns can access PSP dev 28 * @psd: PSP device structure user is trying to access 29 * @net: net namespace user is in 30 * 31 * Return: 0 if PSP device should be visible in @net, errno otherwise. 32 */ 33 int psp_dev_check_access(struct psp_dev *psd, struct net *net) 34 { 35 if (dev_net(psd->main_netdev) == net) 36 return 0; 37 return -ENOENT; 38 } 39 40 /** 41 * psp_dev_create() - create and register PSP device 42 * @netdev: main netdevice 43 * @psd_ops: driver callbacks 44 * @psd_caps: device capabilities 45 * @priv_ptr: back-pointer to driver private data 46 * 47 * Return: pointer to allocated PSP device, or ERR_PTR. 48 */ 49 struct psp_dev * 50 psp_dev_create(struct net_device *netdev, 51 struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps, 52 void *priv_ptr) 53 { 54 struct psp_dev *psd; 55 static u32 last_id; 56 int err; 57 58 if (WARN_ON(!psd_caps->versions || 59 !psd_ops->set_config || 60 !psd_ops->key_rotate || 61 !psd_ops->rx_spi_alloc || 62 !psd_ops->tx_key_add || 63 !psd_ops->tx_key_del || 64 !psd_ops->get_stats)) 65 return ERR_PTR(-EINVAL); 66 67 psd = kzalloc_obj(*psd); 68 if (!psd) 69 return ERR_PTR(-ENOMEM); 70 71 psd->main_netdev = netdev; 72 psd->ops = psd_ops; 73 psd->caps = psd_caps; 74 psd->drv_priv = priv_ptr; 75 76 mutex_init(&psd->lock); 77 INIT_LIST_HEAD(&psd->active_assocs); 78 INIT_LIST_HEAD(&psd->prev_assocs); 79 INIT_LIST_HEAD(&psd->stale_assocs); 80 refcount_set(&psd->refcnt, 1); 81 82 mutex_lock(&psp_devs_lock); 83 err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b, 84 &last_id, GFP_KERNEL); 85 if (err) { 86 mutex_unlock(&psp_devs_lock); 87 kfree(psd); 88 return ERR_PTR(err); 89 } 90 mutex_lock(&psd->lock); 91 mutex_unlock(&psp_devs_lock); 92 93 /* notify before netdev assignment 94 * There's no strong reason for it, but thinking is to avoid creating 95 * implicit expectations about the PSP dev <> netdev relationship. 96 */ 97 psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF); 98 99 rcu_assign_pointer(netdev->psp_dev, psd); 100 101 mutex_unlock(&psd->lock); 102 103 return psd; 104 } 105 EXPORT_SYMBOL(psp_dev_create); 106 107 void psp_dev_free(struct psp_dev *psd) 108 { 109 mutex_lock(&psp_devs_lock); 110 xa_erase(&psp_devs, psd->id); 111 mutex_unlock(&psp_devs_lock); 112 113 mutex_destroy(&psd->lock); 114 kfree_rcu(psd, rcu); 115 } 116 117 /** 118 * psp_dev_unregister() - unregister PSP device 119 * @psd: PSP device structure 120 */ 121 void psp_dev_unregister(struct psp_dev *psd) 122 { 123 struct psp_assoc *pas, *next; 124 125 mutex_lock(&psp_devs_lock); 126 mutex_lock(&psd->lock); 127 128 psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF); 129 130 /* Wait until psp_dev_free() to call xa_erase() to prevent a 131 * different psd from being added to the xarray with this id, while 132 * there are still references to this psd being held. 133 */ 134 xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); 135 mutex_unlock(&psp_devs_lock); 136 137 list_splice_init(&psd->active_assocs, &psd->prev_assocs); 138 list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 139 list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list) 140 psp_dev_tx_key_del(psd, pas); 141 142 rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); 143 144 psd->ops = NULL; 145 psd->drv_priv = NULL; 146 147 mutex_unlock(&psd->lock); 148 149 psp_dev_put(psd); 150 } 151 EXPORT_SYMBOL(psp_dev_unregister); 152 153 unsigned int psp_key_size(u32 version) 154 { 155 switch (version) { 156 case PSP_VERSION_HDR0_AES_GCM_128: 157 case PSP_VERSION_HDR0_AES_GMAC_128: 158 return 16; 159 case PSP_VERSION_HDR0_AES_GCM_256: 160 case PSP_VERSION_HDR0_AES_GMAC_256: 161 return 32; 162 default: 163 return 0; 164 } 165 } 166 EXPORT_SYMBOL(psp_key_size); 167 168 static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi, 169 u8 ver, unsigned int udp_len, __be16 sport) 170 { 171 struct udphdr *uh = udp_hdr(skb); 172 struct psphdr *psph = (struct psphdr *)(uh + 1); 173 const struct sock *sk = skb->sk; 174 175 uh->dest = htons(PSP_DEFAULT_UDP_PORT); 176 177 /* A bit of theory: Selection of the source port. 178 * 179 * We need some entropy, so that multiple flows use different 180 * source ports for better RSS spreading at the receiver. 181 * 182 * We also need that all packets belonging to one TCP flow 183 * use the same source port through their duration, 184 * so that all these packets land in the same receive queue. 185 * 186 * udp_flow_src_port() is using sk_txhash, inherited from 187 * skb_set_hash_from_sk() call in __tcp_transmit_skb(). 188 * This field is subject to reshuffling, thanks to 189 * sk_rethink_txhash() calls in various TCP functions. 190 * 191 * Instead, use sk->sk_hash which is constant through 192 * the whole flow duration. 193 */ 194 if (likely(sk)) { 195 u32 hash = sk->sk_hash; 196 int min, max; 197 198 /* These operations are cheap, no need to cache the result 199 * in another socket field. 200 */ 201 inet_get_local_port_range(net, &min, &max); 202 /* Since this is being sent on the wire obfuscate hash a bit 203 * to minimize possibility that any useful information to an 204 * attacker is leaked. Only upper 16 bits are relevant in the 205 * computation for 16 bit port value because we use a 206 * reciprocal divide. 207 */ 208 hash ^= hash << 16; 209 uh->source = htons(reciprocal_scale(hash, max - min + 1) + min); 210 } else { 211 uh->source = udp_flow_src_port(net, skb, 0, 0, false); 212 } 213 uh->check = 0; 214 uh->len = htons(udp_len); 215 216 psph->nexthdr = IPPROTO_TCP; 217 psph->hdrlen = PSP_HDRLEN_NOOPT; 218 psph->crypt_offset = 0; 219 psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) | 220 FIELD_PREP(PSPHDR_VERFL_ONE, 1); 221 psph->spi = spi; 222 memset(&psph->iv, 0, sizeof(psph->iv)); 223 } 224 225 /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling 226 * them in. 227 */ 228 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 229 u8 ver, __be16 sport) 230 { 231 u32 network_len = skb_network_header_len(skb); 232 u32 ethr_len = skb_mac_header_len(skb); 233 u32 bufflen = ethr_len + network_len; 234 235 if (skb->protocol != htons(ETH_P_IP) && 236 skb->protocol != htons(ETH_P_IPV6)) 237 return false; 238 239 if (skb_cow_head(skb, PSP_ENCAP_HLEN)) 240 return false; 241 242 skb_push(skb, PSP_ENCAP_HLEN); 243 skb->mac_header -= PSP_ENCAP_HLEN; 244 skb->network_header -= PSP_ENCAP_HLEN; 245 skb->transport_header -= PSP_ENCAP_HLEN; 246 memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen); 247 248 if (skb->protocol == htons(ETH_P_IP)) { 249 ip_hdr(skb)->protocol = IPPROTO_UDP; 250 be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN); 251 ip_hdr(skb)->check = 0; 252 ip_hdr(skb)->check = 253 ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl); 254 } else { 255 ipv6_hdr(skb)->nexthdr = IPPROTO_UDP; 256 be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN); 257 } 258 259 skb_set_inner_ipproto(skb, IPPROTO_TCP); 260 skb_set_inner_transport_header(skb, skb_transport_offset(skb) + 261 PSP_ENCAP_HLEN); 262 skb->encapsulation = 1; 263 psp_write_headers(net, skb, spi, ver, 264 skb->len - skb_transport_offset(skb), sport); 265 266 return true; 267 } 268 EXPORT_SYMBOL(psp_dev_encapsulate); 269 270 /* Receive handler for PSP packets. 271 * 272 * Accepts only already-authenticated packets. The full PSP header is 273 * stripped according to psph->hdrlen; any optional fields it advertises 274 * (virtualization cookies, etc.) are ignored and discarded along with the 275 * rest of the header. The caller should ensure that skb->data is pointing 276 * to the mac header, and that skb->mac_len is set. This function does not 277 * currently adjust skb->csum (CHECKSUM_COMPLETE is not supported). 278 */ 279 int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv) 280 { 281 int l2_hlen = 0, l3_hlen, encap, psp_hlen; 282 struct psp_skb_ext *pse; 283 struct psphdr *psph; 284 struct ethhdr *eth; 285 struct udphdr *uh; 286 __be16 proto; 287 bool is_udp; 288 289 eth = (struct ethhdr *)skb->data; 290 proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen); 291 if (proto == htons(ETH_P_IP)) 292 l3_hlen = sizeof(struct iphdr); 293 else if (proto == htons(ETH_P_IPV6)) 294 l3_hlen = sizeof(struct ipv6hdr); 295 else 296 return -EINVAL; 297 298 if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))) 299 return -EINVAL; 300 301 if (proto == htons(ETH_P_IP)) { 302 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 303 304 if (unlikely(iph->ihl < 5)) 305 return -EINVAL; 306 307 is_udp = iph->protocol == IPPROTO_UDP; 308 l3_hlen = iph->ihl * 4; 309 if (l3_hlen != sizeof(struct iphdr) && 310 !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)) 311 return -EINVAL; 312 } else { 313 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 314 315 is_udp = ipv6h->nexthdr == IPPROTO_UDP; 316 } 317 318 if (unlikely(!is_udp)) 319 return -EINVAL; 320 321 uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen); 322 if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT))) 323 return -EINVAL; 324 325 psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen + 326 sizeof(struct udphdr)); 327 328 /* Strip the full PSP header per psph->hdrlen; VC/options are pulled 329 * into the linear region only so they can be discarded with the 330 * rest of the header. 331 */ 332 psp_hlen = (psph->hdrlen + 1) * 8; 333 334 if (unlikely(psp_hlen < sizeof(struct psphdr))) 335 return -EINVAL; 336 337 if (psp_hlen > sizeof(struct psphdr) && 338 !pskb_may_pull(skb, l2_hlen + l3_hlen + 339 sizeof(struct udphdr) + psp_hlen)) 340 return -EINVAL; 341 342 psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen + 343 sizeof(struct udphdr)); 344 345 pse = skb_ext_add(skb, SKB_EXT_PSP); 346 if (!pse) 347 return -EINVAL; 348 349 pse->spi = psph->spi; 350 pse->dev_id = dev_id; 351 pse->generation = generation; 352 pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl); 353 354 encap = sizeof(struct udphdr) + psp_hlen; 355 encap += strip_icv ? PSP_TRL_SIZE : 0; 356 357 if (proto == htons(ETH_P_IP)) { 358 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 359 360 if (unlikely(ntohs(iph->tot_len) < l3_hlen + encap)) 361 return -EINVAL; 362 363 iph->protocol = psph->nexthdr; 364 iph->tot_len = htons(ntohs(iph->tot_len) - encap); 365 iph->check = 0; 366 iph->check = ip_fast_csum((u8 *)iph, iph->ihl); 367 } else { 368 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 369 370 if (unlikely(ntohs(ipv6h->payload_len) < encap)) 371 return -EINVAL; 372 373 ipv6h->nexthdr = psph->nexthdr; 374 ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap); 375 } 376 377 memmove(skb->data + sizeof(struct udphdr) + psp_hlen, 378 skb->data, l2_hlen + l3_hlen); 379 skb_pull(skb, sizeof(struct udphdr) + psp_hlen); 380 381 if (strip_icv) 382 pskb_trim(skb, skb->len - PSP_TRL_SIZE); 383 384 return 0; 385 } 386 EXPORT_SYMBOL(psp_dev_rcv); 387 388 static int __init psp_init(void) 389 { 390 mutex_init(&psp_devs_lock); 391 392 return genl_register_family(&psp_nl_family); 393 } 394 395 subsys_initcall(psp_init); 396