100c94ca2SJakub Kicinski // SPDX-License-Identifier: GPL-2.0-only 200c94ca2SJakub Kicinski 3fc724515SRaed Salem #include <linux/bitfield.h> 400c94ca2SJakub Kicinski #include <linux/list.h> 500c94ca2SJakub Kicinski #include <linux/netdevice.h> 600c94ca2SJakub Kicinski #include <linux/xarray.h> 700c94ca2SJakub Kicinski #include <net/net_namespace.h> 800c94ca2SJakub Kicinski #include <net/psp.h> 9fc724515SRaed Salem #include <net/udp.h> 1000c94ca2SJakub Kicinski 1100c94ca2SJakub Kicinski #include "psp.h" 1200c94ca2SJakub Kicinski #include "psp-nl-gen.h" 1300c94ca2SJakub Kicinski 1400c94ca2SJakub Kicinski DEFINE_XARRAY_ALLOC1(psp_devs); 1500c94ca2SJakub Kicinski struct mutex psp_devs_lock; 1600c94ca2SJakub Kicinski 1700c94ca2SJakub Kicinski /** 1800c94ca2SJakub Kicinski * DOC: PSP locking 1900c94ca2SJakub Kicinski * 2000c94ca2SJakub Kicinski * psp_devs_lock protects the psp_devs xarray. 2100c94ca2SJakub Kicinski * Ordering is take the psp_devs_lock and then the instance lock. 2200c94ca2SJakub Kicinski * Each instance is protected by RCU, and has a refcount. 2300c94ca2SJakub Kicinski * When driver unregisters the instance gets flushed, but struct sticks around. 2400c94ca2SJakub Kicinski */ 2500c94ca2SJakub Kicinski 2600c94ca2SJakub Kicinski /** 2700c94ca2SJakub Kicinski * psp_dev_check_access() - check if user in a given net ns can access PSP dev 2800c94ca2SJakub Kicinski * @psd: PSP device structure user is trying to access 2900c94ca2SJakub Kicinski * @net: net namespace user is in 3000c94ca2SJakub Kicinski * 3100c94ca2SJakub Kicinski * Return: 0 if PSP device should be visible in @net, errno otherwise. 3200c94ca2SJakub Kicinski */ 3300c94ca2SJakub Kicinski int psp_dev_check_access(struct psp_dev *psd, struct net *net) 3400c94ca2SJakub Kicinski { 3500c94ca2SJakub Kicinski if (dev_net(psd->main_netdev) == net) 3600c94ca2SJakub Kicinski return 0; 3700c94ca2SJakub Kicinski return -ENOENT; 3800c94ca2SJakub Kicinski } 3900c94ca2SJakub Kicinski 4000c94ca2SJakub Kicinski /** 4100c94ca2SJakub Kicinski * psp_dev_create() - create and register PSP device 4200c94ca2SJakub Kicinski * @netdev: main netdevice 4300c94ca2SJakub Kicinski * @psd_ops: driver callbacks 4400c94ca2SJakub Kicinski * @psd_caps: device capabilities 4500c94ca2SJakub Kicinski * @priv_ptr: back-pointer to driver private data 4600c94ca2SJakub Kicinski * 4700c94ca2SJakub Kicinski * Return: pointer to allocated PSP device, or ERR_PTR. 4800c94ca2SJakub Kicinski */ 4900c94ca2SJakub Kicinski struct psp_dev * 5000c94ca2SJakub Kicinski psp_dev_create(struct net_device *netdev, 5100c94ca2SJakub Kicinski struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps, 5200c94ca2SJakub Kicinski void *priv_ptr) 5300c94ca2SJakub Kicinski { 5400c94ca2SJakub Kicinski struct psp_dev *psd; 5500c94ca2SJakub Kicinski static u32 last_id; 5600c94ca2SJakub Kicinski int err; 5700c94ca2SJakub Kicinski 5800c94ca2SJakub Kicinski if (WARN_ON(!psd_caps->versions || 59117f02a4SJakub Kicinski !psd_ops->set_config || 606b46ca26SJakub Kicinski !psd_ops->key_rotate || 616b46ca26SJakub Kicinski !psd_ops->rx_spi_alloc || 626b46ca26SJakub Kicinski !psd_ops->tx_key_add || 636b46ca26SJakub Kicinski !psd_ops->tx_key_del)) 6400c94ca2SJakub Kicinski return ERR_PTR(-EINVAL); 6500c94ca2SJakub Kicinski 6600c94ca2SJakub Kicinski psd = kzalloc(sizeof(*psd), GFP_KERNEL); 6700c94ca2SJakub Kicinski if (!psd) 6800c94ca2SJakub Kicinski return ERR_PTR(-ENOMEM); 6900c94ca2SJakub Kicinski 7000c94ca2SJakub Kicinski psd->main_netdev = netdev; 7100c94ca2SJakub Kicinski psd->ops = psd_ops; 7200c94ca2SJakub Kicinski psd->caps = psd_caps; 7300c94ca2SJakub Kicinski psd->drv_priv = priv_ptr; 7400c94ca2SJakub Kicinski 7500c94ca2SJakub Kicinski mutex_init(&psd->lock); 766b46ca26SJakub Kicinski INIT_LIST_HEAD(&psd->active_assocs); 77e7885105SJakub Kicinski INIT_LIST_HEAD(&psd->prev_assocs); 78e7885105SJakub Kicinski INIT_LIST_HEAD(&psd->stale_assocs); 7900c94ca2SJakub Kicinski refcount_set(&psd->refcnt, 1); 8000c94ca2SJakub Kicinski 8100c94ca2SJakub Kicinski mutex_lock(&psp_devs_lock); 8200c94ca2SJakub Kicinski err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b, 8300c94ca2SJakub Kicinski &last_id, GFP_KERNEL); 8400c94ca2SJakub Kicinski if (err) { 8500c94ca2SJakub Kicinski mutex_unlock(&psp_devs_lock); 8600c94ca2SJakub Kicinski kfree(psd); 8700c94ca2SJakub Kicinski return ERR_PTR(err); 8800c94ca2SJakub Kicinski } 8900c94ca2SJakub Kicinski mutex_lock(&psd->lock); 9000c94ca2SJakub Kicinski mutex_unlock(&psp_devs_lock); 9100c94ca2SJakub Kicinski 9200c94ca2SJakub Kicinski psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF); 9300c94ca2SJakub Kicinski 9400c94ca2SJakub Kicinski rcu_assign_pointer(netdev->psp_dev, psd); 9500c94ca2SJakub Kicinski 9600c94ca2SJakub Kicinski mutex_unlock(&psd->lock); 9700c94ca2SJakub Kicinski 9800c94ca2SJakub Kicinski return psd; 9900c94ca2SJakub Kicinski } 10000c94ca2SJakub Kicinski EXPORT_SYMBOL(psp_dev_create); 10100c94ca2SJakub Kicinski 102672beab0SEric Dumazet void psp_dev_free(struct psp_dev *psd) 10300c94ca2SJakub Kicinski { 10400c94ca2SJakub Kicinski mutex_lock(&psp_devs_lock); 10500c94ca2SJakub Kicinski xa_erase(&psp_devs, psd->id); 10600c94ca2SJakub Kicinski mutex_unlock(&psp_devs_lock); 10700c94ca2SJakub Kicinski 10800c94ca2SJakub Kicinski mutex_destroy(&psd->lock); 10900c94ca2SJakub Kicinski kfree_rcu(psd, rcu); 11000c94ca2SJakub Kicinski } 11100c94ca2SJakub Kicinski 11200c94ca2SJakub Kicinski /** 11300c94ca2SJakub Kicinski * psp_dev_unregister() - unregister PSP device 11400c94ca2SJakub Kicinski * @psd: PSP device structure 11500c94ca2SJakub Kicinski */ 11600c94ca2SJakub Kicinski void psp_dev_unregister(struct psp_dev *psd) 11700c94ca2SJakub Kicinski { 1186b46ca26SJakub Kicinski struct psp_assoc *pas, *next; 1196b46ca26SJakub Kicinski 12000c94ca2SJakub Kicinski mutex_lock(&psp_devs_lock); 12100c94ca2SJakub Kicinski mutex_lock(&psd->lock); 12200c94ca2SJakub Kicinski 12300c94ca2SJakub Kicinski psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF); 12400c94ca2SJakub Kicinski 125672beab0SEric Dumazet /* Wait until psp_dev_free() to call xa_erase() to prevent a 12600c94ca2SJakub Kicinski * different psd from being added to the xarray with this id, while 12700c94ca2SJakub Kicinski * there are still references to this psd being held. 12800c94ca2SJakub Kicinski */ 12900c94ca2SJakub Kicinski xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); 13000c94ca2SJakub Kicinski mutex_unlock(&psp_devs_lock); 13100c94ca2SJakub Kicinski 132e7885105SJakub Kicinski list_splice_init(&psd->active_assocs, &psd->prev_assocs); 133e7885105SJakub Kicinski list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 134e7885105SJakub Kicinski list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list) 1356b46ca26SJakub Kicinski psp_dev_tx_key_del(psd, pas); 1366b46ca26SJakub Kicinski 13700c94ca2SJakub Kicinski rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); 13800c94ca2SJakub Kicinski 13900c94ca2SJakub Kicinski psd->ops = NULL; 14000c94ca2SJakub Kicinski psd->drv_priv = NULL; 14100c94ca2SJakub Kicinski 14200c94ca2SJakub Kicinski mutex_unlock(&psd->lock); 14300c94ca2SJakub Kicinski 14400c94ca2SJakub Kicinski psp_dev_put(psd); 14500c94ca2SJakub Kicinski } 14600c94ca2SJakub Kicinski EXPORT_SYMBOL(psp_dev_unregister); 14700c94ca2SJakub Kicinski 1486b46ca26SJakub Kicinski unsigned int psp_key_size(u32 version) 1496b46ca26SJakub Kicinski { 1506b46ca26SJakub Kicinski switch (version) { 1516b46ca26SJakub Kicinski case PSP_VERSION_HDR0_AES_GCM_128: 1526b46ca26SJakub Kicinski case PSP_VERSION_HDR0_AES_GMAC_128: 1536b46ca26SJakub Kicinski return 16; 1546b46ca26SJakub Kicinski case PSP_VERSION_HDR0_AES_GCM_256: 1556b46ca26SJakub Kicinski case PSP_VERSION_HDR0_AES_GMAC_256: 1566b46ca26SJakub Kicinski return 32; 1576b46ca26SJakub Kicinski default: 1586b46ca26SJakub Kicinski return 0; 1596b46ca26SJakub Kicinski } 1606b46ca26SJakub Kicinski } 1616b46ca26SJakub Kicinski EXPORT_SYMBOL(psp_key_size); 1626b46ca26SJakub Kicinski 163fc724515SRaed Salem static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi, 164fc724515SRaed Salem u8 ver, unsigned int udp_len, __be16 sport) 165fc724515SRaed Salem { 166fc724515SRaed Salem struct udphdr *uh = udp_hdr(skb); 167fc724515SRaed Salem struct psphdr *psph = (struct psphdr *)(uh + 1); 168fc724515SRaed Salem 169fc724515SRaed Salem uh->dest = htons(PSP_DEFAULT_UDP_PORT); 170fc724515SRaed Salem uh->source = udp_flow_src_port(net, skb, 0, 0, false); 171fc724515SRaed Salem uh->check = 0; 172fc724515SRaed Salem uh->len = htons(udp_len); 173fc724515SRaed Salem 174fc724515SRaed Salem psph->nexthdr = IPPROTO_TCP; 175fc724515SRaed Salem psph->hdrlen = PSP_HDRLEN_NOOPT; 176fc724515SRaed Salem psph->crypt_offset = 0; 177fc724515SRaed Salem psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) | 178fc724515SRaed Salem FIELD_PREP(PSPHDR_VERFL_ONE, 1); 179fc724515SRaed Salem psph->spi = spi; 180fc724515SRaed Salem memset(&psph->iv, 0, sizeof(psph->iv)); 181fc724515SRaed Salem } 182fc724515SRaed Salem 183fc724515SRaed Salem /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling 184fc724515SRaed Salem * them in. 185fc724515SRaed Salem */ 186fc724515SRaed Salem bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 187fc724515SRaed Salem u8 ver, __be16 sport) 188fc724515SRaed Salem { 189fc724515SRaed Salem u32 network_len = skb_network_header_len(skb); 190fc724515SRaed Salem u32 ethr_len = skb_mac_header_len(skb); 191fc724515SRaed Salem u32 bufflen = ethr_len + network_len; 192fc724515SRaed Salem 193fc724515SRaed Salem if (skb_cow_head(skb, PSP_ENCAP_HLEN)) 194fc724515SRaed Salem return false; 195fc724515SRaed Salem 196fc724515SRaed Salem skb_push(skb, PSP_ENCAP_HLEN); 197fc724515SRaed Salem skb->mac_header -= PSP_ENCAP_HLEN; 198fc724515SRaed Salem skb->network_header -= PSP_ENCAP_HLEN; 199fc724515SRaed Salem skb->transport_header -= PSP_ENCAP_HLEN; 200fc724515SRaed Salem memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen); 201fc724515SRaed Salem 202fc724515SRaed Salem if (skb->protocol == htons(ETH_P_IP)) { 203fc724515SRaed Salem ip_hdr(skb)->protocol = IPPROTO_UDP; 204fc724515SRaed Salem be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN); 205fc724515SRaed Salem ip_hdr(skb)->check = 0; 206fc724515SRaed Salem ip_hdr(skb)->check = 207fc724515SRaed Salem ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl); 208fc724515SRaed Salem } else if (skb->protocol == htons(ETH_P_IPV6)) { 209fc724515SRaed Salem ipv6_hdr(skb)->nexthdr = IPPROTO_UDP; 210fc724515SRaed Salem be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN); 211fc724515SRaed Salem } else { 212fc724515SRaed Salem return false; 213fc724515SRaed Salem } 214fc724515SRaed Salem 215fc724515SRaed Salem skb_set_inner_ipproto(skb, IPPROTO_TCP); 216fc724515SRaed Salem skb_set_inner_transport_header(skb, skb_transport_offset(skb) + 217fc724515SRaed Salem PSP_ENCAP_HLEN); 218fc724515SRaed Salem skb->encapsulation = 1; 219fc724515SRaed Salem psp_write_headers(net, skb, spi, ver, 220fc724515SRaed Salem skb->len - skb_transport_offset(skb), sport); 221fc724515SRaed Salem 222fc724515SRaed Salem return true; 223fc724515SRaed Salem } 224fc724515SRaed Salem EXPORT_SYMBOL(psp_dev_encapsulate); 225fc724515SRaed Salem 2260eddb802SRaed Salem /* Receive handler for PSP packets. 2270eddb802SRaed Salem * 2280eddb802SRaed Salem * Presently it accepts only already-authenticated packets and does not 2290eddb802SRaed Salem * support optional fields, such as virtualization cookies. The caller should 2300eddb802SRaed Salem * ensure that skb->data is pointing to the mac header, and that skb->mac_len 231*85c7333cSDaniel Zahka * is set. This function does not currently adjust skb->csum (CHECKSUM_COMPLETE 232*85c7333cSDaniel Zahka * is not supported). 2330eddb802SRaed Salem */ 2340eddb802SRaed Salem int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv) 2350eddb802SRaed Salem { 2360eddb802SRaed Salem int l2_hlen = 0, l3_hlen, encap; 2370eddb802SRaed Salem struct psp_skb_ext *pse; 2380eddb802SRaed Salem struct psphdr *psph; 2390eddb802SRaed Salem struct ethhdr *eth; 2400eddb802SRaed Salem struct udphdr *uh; 2410eddb802SRaed Salem __be16 proto; 2420eddb802SRaed Salem bool is_udp; 2430eddb802SRaed Salem 2440eddb802SRaed Salem eth = (struct ethhdr *)skb->data; 2450eddb802SRaed Salem proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen); 2460eddb802SRaed Salem if (proto == htons(ETH_P_IP)) 2470eddb802SRaed Salem l3_hlen = sizeof(struct iphdr); 2480eddb802SRaed Salem else if (proto == htons(ETH_P_IPV6)) 2490eddb802SRaed Salem l3_hlen = sizeof(struct ipv6hdr); 2500eddb802SRaed Salem else 2510eddb802SRaed Salem return -EINVAL; 2520eddb802SRaed Salem 2530eddb802SRaed Salem if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))) 2540eddb802SRaed Salem return -EINVAL; 2550eddb802SRaed Salem 2560eddb802SRaed Salem if (proto == htons(ETH_P_IP)) { 2570eddb802SRaed Salem struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 2580eddb802SRaed Salem 2590eddb802SRaed Salem is_udp = iph->protocol == IPPROTO_UDP; 2600eddb802SRaed Salem l3_hlen = iph->ihl * 4; 2610eddb802SRaed Salem if (l3_hlen != sizeof(struct iphdr) && 2620eddb802SRaed Salem !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)) 2630eddb802SRaed Salem return -EINVAL; 2640eddb802SRaed Salem } else { 2650eddb802SRaed Salem struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 2660eddb802SRaed Salem 2670eddb802SRaed Salem is_udp = ipv6h->nexthdr == IPPROTO_UDP; 2680eddb802SRaed Salem } 2690eddb802SRaed Salem 2700eddb802SRaed Salem if (unlikely(!is_udp)) 2710eddb802SRaed Salem return -EINVAL; 2720eddb802SRaed Salem 2730eddb802SRaed Salem uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen); 2740eddb802SRaed Salem if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT))) 2750eddb802SRaed Salem return -EINVAL; 2760eddb802SRaed Salem 2770eddb802SRaed Salem pse = skb_ext_add(skb, SKB_EXT_PSP); 2780eddb802SRaed Salem if (!pse) 2790eddb802SRaed Salem return -EINVAL; 2800eddb802SRaed Salem 2810eddb802SRaed Salem psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen + 2820eddb802SRaed Salem sizeof(struct udphdr)); 2830eddb802SRaed Salem pse->spi = psph->spi; 2840eddb802SRaed Salem pse->dev_id = dev_id; 2850eddb802SRaed Salem pse->generation = generation; 2860eddb802SRaed Salem pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl); 2870eddb802SRaed Salem 2880eddb802SRaed Salem encap = PSP_ENCAP_HLEN; 2890eddb802SRaed Salem encap += strip_icv ? PSP_TRL_SIZE : 0; 2900eddb802SRaed Salem 2910eddb802SRaed Salem if (proto == htons(ETH_P_IP)) { 2920eddb802SRaed Salem struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 2930eddb802SRaed Salem 2940eddb802SRaed Salem iph->protocol = psph->nexthdr; 2950eddb802SRaed Salem iph->tot_len = htons(ntohs(iph->tot_len) - encap); 2960eddb802SRaed Salem iph->check = 0; 2970eddb802SRaed Salem iph->check = ip_fast_csum((u8 *)iph, iph->ihl); 2980eddb802SRaed Salem } else { 2990eddb802SRaed Salem struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 3000eddb802SRaed Salem 3010eddb802SRaed Salem ipv6h->nexthdr = psph->nexthdr; 3020eddb802SRaed Salem ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap); 3030eddb802SRaed Salem } 3040eddb802SRaed Salem 3050eddb802SRaed Salem memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen); 3060eddb802SRaed Salem skb_pull(skb, PSP_ENCAP_HLEN); 3070eddb802SRaed Salem 3080eddb802SRaed Salem if (strip_icv) 3090eddb802SRaed Salem pskb_trim(skb, skb->len - PSP_TRL_SIZE); 3100eddb802SRaed Salem 3110eddb802SRaed Salem return 0; 3120eddb802SRaed Salem } 3130eddb802SRaed Salem EXPORT_SYMBOL(psp_dev_rcv); 3140eddb802SRaed Salem 31500c94ca2SJakub Kicinski static int __init psp_init(void) 31600c94ca2SJakub Kicinski { 31700c94ca2SJakub Kicinski mutex_init(&psp_devs_lock); 31800c94ca2SJakub Kicinski 31900c94ca2SJakub Kicinski return genl_register_family(&psp_nl_family); 32000c94ca2SJakub Kicinski } 32100c94ca2SJakub Kicinski 32200c94ca2SJakub Kicinski subsys_initcall(psp_init); 323