1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/bitfield.h> 4 #include <linux/list.h> 5 #include <linux/netdevice.h> 6 #include <linux/xarray.h> 7 #include <net/net_namespace.h> 8 #include <net/psp.h> 9 #include <net/udp.h> 10 11 #include "psp.h" 12 #include "psp-nl-gen.h" 13 14 DEFINE_XARRAY_ALLOC1(psp_devs); 15 struct mutex psp_devs_lock; 16 17 /** 18 * DOC: PSP locking 19 * 20 * psp_devs_lock protects the psp_devs xarray. 21 * Ordering is take the psp_devs_lock and then the instance lock. 22 * Each instance is protected by RCU, and has a refcount. 23 * When driver unregisters the instance gets flushed, but struct sticks around. 24 */ 25 26 /** 27 * psp_dev_check_access() - check if user in a given net ns can access PSP dev 28 * @psd: PSP device structure user is trying to access 29 * @net: net namespace user is in 30 * 31 * Return: 0 if PSP device should be visible in @net, errno otherwise. 32 */ 33 int psp_dev_check_access(struct psp_dev *psd, struct net *net) 34 { 35 if (dev_net(psd->main_netdev) == net) 36 return 0; 37 return -ENOENT; 38 } 39 40 /** 41 * psp_dev_create() - create and register PSP device 42 * @netdev: main netdevice 43 * @psd_ops: driver callbacks 44 * @psd_caps: device capabilities 45 * @priv_ptr: back-pointer to driver private data 46 * 47 * Return: pointer to allocated PSP device, or ERR_PTR. 48 */ 49 struct psp_dev * 50 psp_dev_create(struct net_device *netdev, 51 struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps, 52 void *priv_ptr) 53 { 54 struct psp_dev *psd; 55 static u32 last_id; 56 int err; 57 58 if (WARN_ON(!psd_caps->versions || 59 !psd_ops->set_config || 60 !psd_ops->key_rotate || 61 !psd_ops->rx_spi_alloc || 62 !psd_ops->tx_key_add || 63 !psd_ops->tx_key_del)) 64 return ERR_PTR(-EINVAL); 65 66 psd = kzalloc(sizeof(*psd), GFP_KERNEL); 67 if (!psd) 68 return ERR_PTR(-ENOMEM); 69 70 psd->main_netdev = netdev; 71 psd->ops = psd_ops; 72 psd->caps = psd_caps; 73 psd->drv_priv = priv_ptr; 74 75 mutex_init(&psd->lock); 76 INIT_LIST_HEAD(&psd->active_assocs); 77 INIT_LIST_HEAD(&psd->prev_assocs); 78 INIT_LIST_HEAD(&psd->stale_assocs); 79 refcount_set(&psd->refcnt, 1); 80 81 mutex_lock(&psp_devs_lock); 82 err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b, 83 &last_id, GFP_KERNEL); 84 if (err) { 85 mutex_unlock(&psp_devs_lock); 86 kfree(psd); 87 return ERR_PTR(err); 88 } 89 mutex_lock(&psd->lock); 90 mutex_unlock(&psp_devs_lock); 91 92 psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF); 93 94 rcu_assign_pointer(netdev->psp_dev, psd); 95 96 mutex_unlock(&psd->lock); 97 98 return psd; 99 } 100 EXPORT_SYMBOL(psp_dev_create); 101 102 void psp_dev_destroy(struct psp_dev *psd) 103 { 104 mutex_lock(&psp_devs_lock); 105 xa_erase(&psp_devs, psd->id); 106 mutex_unlock(&psp_devs_lock); 107 108 mutex_destroy(&psd->lock); 109 kfree_rcu(psd, rcu); 110 } 111 112 /** 113 * psp_dev_unregister() - unregister PSP device 114 * @psd: PSP device structure 115 */ 116 void psp_dev_unregister(struct psp_dev *psd) 117 { 118 struct psp_assoc *pas, *next; 119 120 mutex_lock(&psp_devs_lock); 121 mutex_lock(&psd->lock); 122 123 psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF); 124 125 /* Wait until psp_dev_destroy() to call xa_erase() to prevent a 126 * different psd from being added to the xarray with this id, while 127 * there are still references to this psd being held. 128 */ 129 xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); 130 mutex_unlock(&psp_devs_lock); 131 132 list_splice_init(&psd->active_assocs, &psd->prev_assocs); 133 list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 134 list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list) 135 psp_dev_tx_key_del(psd, pas); 136 137 rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); 138 139 psd->ops = NULL; 140 psd->drv_priv = NULL; 141 142 mutex_unlock(&psd->lock); 143 144 psp_dev_put(psd); 145 } 146 EXPORT_SYMBOL(psp_dev_unregister); 147 148 unsigned int psp_key_size(u32 version) 149 { 150 switch (version) { 151 case PSP_VERSION_HDR0_AES_GCM_128: 152 case PSP_VERSION_HDR0_AES_GMAC_128: 153 return 16; 154 case PSP_VERSION_HDR0_AES_GCM_256: 155 case PSP_VERSION_HDR0_AES_GMAC_256: 156 return 32; 157 default: 158 return 0; 159 } 160 } 161 EXPORT_SYMBOL(psp_key_size); 162 163 static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi, 164 u8 ver, unsigned int udp_len, __be16 sport) 165 { 166 struct udphdr *uh = udp_hdr(skb); 167 struct psphdr *psph = (struct psphdr *)(uh + 1); 168 169 uh->dest = htons(PSP_DEFAULT_UDP_PORT); 170 uh->source = udp_flow_src_port(net, skb, 0, 0, false); 171 uh->check = 0; 172 uh->len = htons(udp_len); 173 174 psph->nexthdr = IPPROTO_TCP; 175 psph->hdrlen = PSP_HDRLEN_NOOPT; 176 psph->crypt_offset = 0; 177 psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) | 178 FIELD_PREP(PSPHDR_VERFL_ONE, 1); 179 psph->spi = spi; 180 memset(&psph->iv, 0, sizeof(psph->iv)); 181 } 182 183 /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling 184 * them in. 185 */ 186 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 187 u8 ver, __be16 sport) 188 { 189 u32 network_len = skb_network_header_len(skb); 190 u32 ethr_len = skb_mac_header_len(skb); 191 u32 bufflen = ethr_len + network_len; 192 193 if (skb_cow_head(skb, PSP_ENCAP_HLEN)) 194 return false; 195 196 skb_push(skb, PSP_ENCAP_HLEN); 197 skb->mac_header -= PSP_ENCAP_HLEN; 198 skb->network_header -= PSP_ENCAP_HLEN; 199 skb->transport_header -= PSP_ENCAP_HLEN; 200 memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen); 201 202 if (skb->protocol == htons(ETH_P_IP)) { 203 ip_hdr(skb)->protocol = IPPROTO_UDP; 204 be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN); 205 ip_hdr(skb)->check = 0; 206 ip_hdr(skb)->check = 207 ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl); 208 } else if (skb->protocol == htons(ETH_P_IPV6)) { 209 ipv6_hdr(skb)->nexthdr = IPPROTO_UDP; 210 be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN); 211 } else { 212 return false; 213 } 214 215 skb_set_inner_ipproto(skb, IPPROTO_TCP); 216 skb_set_inner_transport_header(skb, skb_transport_offset(skb) + 217 PSP_ENCAP_HLEN); 218 skb->encapsulation = 1; 219 psp_write_headers(net, skb, spi, ver, 220 skb->len - skb_transport_offset(skb), sport); 221 222 return true; 223 } 224 EXPORT_SYMBOL(psp_dev_encapsulate); 225 226 static int __init psp_init(void) 227 { 228 mutex_init(&psp_devs_lock); 229 230 return genl_register_family(&psp_nl_family); 231 } 232 233 subsys_initcall(psp_init); 234