100c94ca2SJakub Kicinski // SPDX-License-Identifier: GPL-2.0-only 200c94ca2SJakub Kicinski 3*fc724515SRaed Salem #include <linux/bitfield.h> 400c94ca2SJakub Kicinski #include <linux/list.h> 500c94ca2SJakub Kicinski #include <linux/netdevice.h> 600c94ca2SJakub Kicinski #include <linux/xarray.h> 700c94ca2SJakub Kicinski #include <net/net_namespace.h> 800c94ca2SJakub Kicinski #include <net/psp.h> 9*fc724515SRaed Salem #include <net/udp.h> 1000c94ca2SJakub Kicinski 1100c94ca2SJakub Kicinski #include "psp.h" 1200c94ca2SJakub Kicinski #include "psp-nl-gen.h" 1300c94ca2SJakub Kicinski 1400c94ca2SJakub Kicinski DEFINE_XARRAY_ALLOC1(psp_devs); 1500c94ca2SJakub Kicinski struct mutex psp_devs_lock; 1600c94ca2SJakub Kicinski 1700c94ca2SJakub Kicinski /** 1800c94ca2SJakub Kicinski * DOC: PSP locking 1900c94ca2SJakub Kicinski * 2000c94ca2SJakub Kicinski * psp_devs_lock protects the psp_devs xarray. 2100c94ca2SJakub Kicinski * Ordering is take the psp_devs_lock and then the instance lock. 2200c94ca2SJakub Kicinski * Each instance is protected by RCU, and has a refcount. 2300c94ca2SJakub Kicinski * When driver unregisters the instance gets flushed, but struct sticks around. 2400c94ca2SJakub Kicinski */ 2500c94ca2SJakub Kicinski 2600c94ca2SJakub Kicinski /** 2700c94ca2SJakub Kicinski * psp_dev_check_access() - check if user in a given net ns can access PSP dev 2800c94ca2SJakub Kicinski * @psd: PSP device structure user is trying to access 2900c94ca2SJakub Kicinski * @net: net namespace user is in 3000c94ca2SJakub Kicinski * 3100c94ca2SJakub Kicinski * Return: 0 if PSP device should be visible in @net, errno otherwise. 3200c94ca2SJakub Kicinski */ 3300c94ca2SJakub Kicinski int psp_dev_check_access(struct psp_dev *psd, struct net *net) 3400c94ca2SJakub Kicinski { 3500c94ca2SJakub Kicinski if (dev_net(psd->main_netdev) == net) 3600c94ca2SJakub Kicinski return 0; 3700c94ca2SJakub Kicinski return -ENOENT; 3800c94ca2SJakub Kicinski } 3900c94ca2SJakub Kicinski 4000c94ca2SJakub Kicinski /** 4100c94ca2SJakub Kicinski * psp_dev_create() - create and register PSP device 4200c94ca2SJakub Kicinski * @netdev: main netdevice 4300c94ca2SJakub Kicinski * @psd_ops: driver callbacks 4400c94ca2SJakub Kicinski * @psd_caps: device capabilities 4500c94ca2SJakub Kicinski * @priv_ptr: back-pointer to driver private data 4600c94ca2SJakub Kicinski * 4700c94ca2SJakub Kicinski * Return: pointer to allocated PSP device, or ERR_PTR. 4800c94ca2SJakub Kicinski */ 4900c94ca2SJakub Kicinski struct psp_dev * 5000c94ca2SJakub Kicinski psp_dev_create(struct net_device *netdev, 5100c94ca2SJakub Kicinski struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps, 5200c94ca2SJakub Kicinski void *priv_ptr) 5300c94ca2SJakub Kicinski { 5400c94ca2SJakub Kicinski struct psp_dev *psd; 5500c94ca2SJakub Kicinski static u32 last_id; 5600c94ca2SJakub Kicinski int err; 5700c94ca2SJakub Kicinski 5800c94ca2SJakub Kicinski if (WARN_ON(!psd_caps->versions || 59117f02a4SJakub Kicinski !psd_ops->set_config || 606b46ca26SJakub Kicinski !psd_ops->key_rotate || 616b46ca26SJakub Kicinski !psd_ops->rx_spi_alloc || 626b46ca26SJakub Kicinski !psd_ops->tx_key_add || 636b46ca26SJakub Kicinski !psd_ops->tx_key_del)) 6400c94ca2SJakub Kicinski return ERR_PTR(-EINVAL); 6500c94ca2SJakub Kicinski 6600c94ca2SJakub Kicinski psd = kzalloc(sizeof(*psd), GFP_KERNEL); 6700c94ca2SJakub Kicinski if (!psd) 6800c94ca2SJakub Kicinski return ERR_PTR(-ENOMEM); 6900c94ca2SJakub Kicinski 7000c94ca2SJakub Kicinski psd->main_netdev = netdev; 7100c94ca2SJakub Kicinski psd->ops = psd_ops; 7200c94ca2SJakub Kicinski psd->caps = psd_caps; 7300c94ca2SJakub Kicinski psd->drv_priv = priv_ptr; 7400c94ca2SJakub Kicinski 7500c94ca2SJakub Kicinski mutex_init(&psd->lock); 766b46ca26SJakub Kicinski INIT_LIST_HEAD(&psd->active_assocs); 77e7885105SJakub Kicinski INIT_LIST_HEAD(&psd->prev_assocs); 78e7885105SJakub Kicinski INIT_LIST_HEAD(&psd->stale_assocs); 7900c94ca2SJakub Kicinski refcount_set(&psd->refcnt, 1); 8000c94ca2SJakub Kicinski 8100c94ca2SJakub Kicinski mutex_lock(&psp_devs_lock); 8200c94ca2SJakub Kicinski err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b, 8300c94ca2SJakub Kicinski &last_id, GFP_KERNEL); 8400c94ca2SJakub Kicinski if (err) { 8500c94ca2SJakub Kicinski mutex_unlock(&psp_devs_lock); 8600c94ca2SJakub Kicinski kfree(psd); 8700c94ca2SJakub Kicinski return ERR_PTR(err); 8800c94ca2SJakub Kicinski } 8900c94ca2SJakub Kicinski mutex_lock(&psd->lock); 9000c94ca2SJakub Kicinski mutex_unlock(&psp_devs_lock); 9100c94ca2SJakub Kicinski 9200c94ca2SJakub Kicinski psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF); 9300c94ca2SJakub Kicinski 9400c94ca2SJakub Kicinski rcu_assign_pointer(netdev->psp_dev, psd); 9500c94ca2SJakub Kicinski 9600c94ca2SJakub Kicinski mutex_unlock(&psd->lock); 9700c94ca2SJakub Kicinski 9800c94ca2SJakub Kicinski return psd; 9900c94ca2SJakub Kicinski } 10000c94ca2SJakub Kicinski EXPORT_SYMBOL(psp_dev_create); 10100c94ca2SJakub Kicinski 10200c94ca2SJakub Kicinski void psp_dev_destroy(struct psp_dev *psd) 10300c94ca2SJakub Kicinski { 10400c94ca2SJakub Kicinski mutex_lock(&psp_devs_lock); 10500c94ca2SJakub Kicinski xa_erase(&psp_devs, psd->id); 10600c94ca2SJakub Kicinski mutex_unlock(&psp_devs_lock); 10700c94ca2SJakub Kicinski 10800c94ca2SJakub Kicinski mutex_destroy(&psd->lock); 10900c94ca2SJakub Kicinski kfree_rcu(psd, rcu); 11000c94ca2SJakub Kicinski } 11100c94ca2SJakub Kicinski 11200c94ca2SJakub Kicinski /** 11300c94ca2SJakub Kicinski * psp_dev_unregister() - unregister PSP device 11400c94ca2SJakub Kicinski * @psd: PSP device structure 11500c94ca2SJakub Kicinski */ 11600c94ca2SJakub Kicinski void psp_dev_unregister(struct psp_dev *psd) 11700c94ca2SJakub Kicinski { 1186b46ca26SJakub Kicinski struct psp_assoc *pas, *next; 1196b46ca26SJakub Kicinski 12000c94ca2SJakub Kicinski mutex_lock(&psp_devs_lock); 12100c94ca2SJakub Kicinski mutex_lock(&psd->lock); 12200c94ca2SJakub Kicinski 12300c94ca2SJakub Kicinski psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF); 12400c94ca2SJakub Kicinski 12500c94ca2SJakub Kicinski /* Wait until psp_dev_destroy() to call xa_erase() to prevent a 12600c94ca2SJakub Kicinski * different psd from being added to the xarray with this id, while 12700c94ca2SJakub Kicinski * there are still references to this psd being held. 12800c94ca2SJakub Kicinski */ 12900c94ca2SJakub Kicinski xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); 13000c94ca2SJakub Kicinski mutex_unlock(&psp_devs_lock); 13100c94ca2SJakub Kicinski 132e7885105SJakub Kicinski list_splice_init(&psd->active_assocs, &psd->prev_assocs); 133e7885105SJakub Kicinski list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 134e7885105SJakub Kicinski list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list) 1356b46ca26SJakub Kicinski psp_dev_tx_key_del(psd, pas); 1366b46ca26SJakub Kicinski 13700c94ca2SJakub Kicinski rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); 13800c94ca2SJakub Kicinski 13900c94ca2SJakub Kicinski psd->ops = NULL; 14000c94ca2SJakub Kicinski psd->drv_priv = NULL; 14100c94ca2SJakub Kicinski 14200c94ca2SJakub Kicinski mutex_unlock(&psd->lock); 14300c94ca2SJakub Kicinski 14400c94ca2SJakub Kicinski psp_dev_put(psd); 14500c94ca2SJakub Kicinski } 14600c94ca2SJakub Kicinski EXPORT_SYMBOL(psp_dev_unregister); 14700c94ca2SJakub Kicinski 1486b46ca26SJakub Kicinski unsigned int psp_key_size(u32 version) 1496b46ca26SJakub Kicinski { 1506b46ca26SJakub Kicinski switch (version) { 1516b46ca26SJakub Kicinski case PSP_VERSION_HDR0_AES_GCM_128: 1526b46ca26SJakub Kicinski case PSP_VERSION_HDR0_AES_GMAC_128: 1536b46ca26SJakub Kicinski return 16; 1546b46ca26SJakub Kicinski case PSP_VERSION_HDR0_AES_GCM_256: 1556b46ca26SJakub Kicinski case PSP_VERSION_HDR0_AES_GMAC_256: 1566b46ca26SJakub Kicinski return 32; 1576b46ca26SJakub Kicinski default: 1586b46ca26SJakub Kicinski return 0; 1596b46ca26SJakub Kicinski } 1606b46ca26SJakub Kicinski } 1616b46ca26SJakub Kicinski EXPORT_SYMBOL(psp_key_size); 1626b46ca26SJakub Kicinski 163*fc724515SRaed Salem static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi, 164*fc724515SRaed Salem u8 ver, unsigned int udp_len, __be16 sport) 165*fc724515SRaed Salem { 166*fc724515SRaed Salem struct udphdr *uh = udp_hdr(skb); 167*fc724515SRaed Salem struct psphdr *psph = (struct psphdr *)(uh + 1); 168*fc724515SRaed Salem 169*fc724515SRaed Salem uh->dest = htons(PSP_DEFAULT_UDP_PORT); 170*fc724515SRaed Salem uh->source = udp_flow_src_port(net, skb, 0, 0, false); 171*fc724515SRaed Salem uh->check = 0; 172*fc724515SRaed Salem uh->len = htons(udp_len); 173*fc724515SRaed Salem 174*fc724515SRaed Salem psph->nexthdr = IPPROTO_TCP; 175*fc724515SRaed Salem psph->hdrlen = PSP_HDRLEN_NOOPT; 176*fc724515SRaed Salem psph->crypt_offset = 0; 177*fc724515SRaed Salem psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) | 178*fc724515SRaed Salem FIELD_PREP(PSPHDR_VERFL_ONE, 1); 179*fc724515SRaed Salem psph->spi = spi; 180*fc724515SRaed Salem memset(&psph->iv, 0, sizeof(psph->iv)); 181*fc724515SRaed Salem } 182*fc724515SRaed Salem 183*fc724515SRaed Salem /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling 184*fc724515SRaed Salem * them in. 185*fc724515SRaed Salem */ 186*fc724515SRaed Salem bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 187*fc724515SRaed Salem u8 ver, __be16 sport) 188*fc724515SRaed Salem { 189*fc724515SRaed Salem u32 network_len = skb_network_header_len(skb); 190*fc724515SRaed Salem u32 ethr_len = skb_mac_header_len(skb); 191*fc724515SRaed Salem u32 bufflen = ethr_len + network_len; 192*fc724515SRaed Salem 193*fc724515SRaed Salem if (skb_cow_head(skb, PSP_ENCAP_HLEN)) 194*fc724515SRaed Salem return false; 195*fc724515SRaed Salem 196*fc724515SRaed Salem skb_push(skb, PSP_ENCAP_HLEN); 197*fc724515SRaed Salem skb->mac_header -= PSP_ENCAP_HLEN; 198*fc724515SRaed Salem skb->network_header -= PSP_ENCAP_HLEN; 199*fc724515SRaed Salem skb->transport_header -= PSP_ENCAP_HLEN; 200*fc724515SRaed Salem memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen); 201*fc724515SRaed Salem 202*fc724515SRaed Salem if (skb->protocol == htons(ETH_P_IP)) { 203*fc724515SRaed Salem ip_hdr(skb)->protocol = IPPROTO_UDP; 204*fc724515SRaed Salem be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN); 205*fc724515SRaed Salem ip_hdr(skb)->check = 0; 206*fc724515SRaed Salem ip_hdr(skb)->check = 207*fc724515SRaed Salem ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl); 208*fc724515SRaed Salem } else if (skb->protocol == htons(ETH_P_IPV6)) { 209*fc724515SRaed Salem ipv6_hdr(skb)->nexthdr = IPPROTO_UDP; 210*fc724515SRaed Salem be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN); 211*fc724515SRaed Salem } else { 212*fc724515SRaed Salem return false; 213*fc724515SRaed Salem } 214*fc724515SRaed Salem 215*fc724515SRaed Salem skb_set_inner_ipproto(skb, IPPROTO_TCP); 216*fc724515SRaed Salem skb_set_inner_transport_header(skb, skb_transport_offset(skb) + 217*fc724515SRaed Salem PSP_ENCAP_HLEN); 218*fc724515SRaed Salem skb->encapsulation = 1; 219*fc724515SRaed Salem psp_write_headers(net, skb, spi, ver, 220*fc724515SRaed Salem skb->len - skb_transport_offset(skb), sport); 221*fc724515SRaed Salem 222*fc724515SRaed Salem return true; 223*fc724515SRaed Salem } 224*fc724515SRaed Salem EXPORT_SYMBOL(psp_dev_encapsulate); 225*fc724515SRaed Salem 22600c94ca2SJakub Kicinski static int __init psp_init(void) 22700c94ca2SJakub Kicinski { 22800c94ca2SJakub Kicinski mutex_init(&psp_devs_lock); 22900c94ca2SJakub Kicinski 23000c94ca2SJakub Kicinski return genl_register_family(&psp_nl_family); 23100c94ca2SJakub Kicinski } 23200c94ca2SJakub Kicinski 23300c94ca2SJakub Kicinski subsys_initcall(psp_init); 234