1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/bitfield.h> 4 #include <linux/list.h> 5 #include <linux/netdevice.h> 6 #include <linux/xarray.h> 7 #include <net/net_namespace.h> 8 #include <net/psp.h> 9 #include <net/udp.h> 10 11 #include "psp.h" 12 #include "psp-nl-gen.h" 13 14 DEFINE_XARRAY_ALLOC1(psp_devs); 15 struct mutex psp_devs_lock; 16 17 /** 18 * DOC: PSP locking 19 * 20 * psp_devs_lock protects the psp_devs xarray. 21 * Ordering is take the psp_devs_lock and then the instance lock. 22 * Each instance is protected by RCU, and has a refcount. 23 * When driver unregisters the instance gets flushed, but struct sticks around. 24 */ 25 26 /** 27 * psp_dev_check_access() - check if user in a given net ns can access PSP dev 28 * @psd: PSP device structure user is trying to access 29 * @net: net namespace user is in 30 * 31 * Return: 0 if PSP device should be visible in @net, errno otherwise. 32 */ 33 int psp_dev_check_access(struct psp_dev *psd, struct net *net) 34 { 35 if (dev_net(psd->main_netdev) == net) 36 return 0; 37 return -ENOENT; 38 } 39 40 /** 41 * psp_dev_create() - create and register PSP device 42 * @netdev: main netdevice 43 * @psd_ops: driver callbacks 44 * @psd_caps: device capabilities 45 * @priv_ptr: back-pointer to driver private data 46 * 47 * Return: pointer to allocated PSP device, or ERR_PTR. 48 */ 49 struct psp_dev * 50 psp_dev_create(struct net_device *netdev, 51 struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps, 52 void *priv_ptr) 53 { 54 struct psp_dev *psd; 55 static u32 last_id; 56 int err; 57 58 if (WARN_ON(!psd_caps->versions || 59 !psd_ops->set_config || 60 !psd_ops->key_rotate || 61 !psd_ops->rx_spi_alloc || 62 !psd_ops->tx_key_add || 63 !psd_ops->tx_key_del || 64 !psd_ops->get_stats)) 65 return ERR_PTR(-EINVAL); 66 67 psd = kzalloc(sizeof(*psd), GFP_KERNEL); 68 if (!psd) 69 return ERR_PTR(-ENOMEM); 70 71 psd->main_netdev = netdev; 72 psd->ops = psd_ops; 73 psd->caps = psd_caps; 74 psd->drv_priv = priv_ptr; 75 76 mutex_init(&psd->lock); 77 INIT_LIST_HEAD(&psd->active_assocs); 78 INIT_LIST_HEAD(&psd->prev_assocs); 79 INIT_LIST_HEAD(&psd->stale_assocs); 80 refcount_set(&psd->refcnt, 1); 81 82 mutex_lock(&psp_devs_lock); 83 err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b, 84 &last_id, GFP_KERNEL); 85 if (err) { 86 mutex_unlock(&psp_devs_lock); 87 kfree(psd); 88 return ERR_PTR(err); 89 } 90 mutex_lock(&psd->lock); 91 mutex_unlock(&psp_devs_lock); 92 93 psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF); 94 95 rcu_assign_pointer(netdev->psp_dev, psd); 96 97 mutex_unlock(&psd->lock); 98 99 return psd; 100 } 101 EXPORT_SYMBOL(psp_dev_create); 102 103 void psp_dev_free(struct psp_dev *psd) 104 { 105 mutex_lock(&psp_devs_lock); 106 xa_erase(&psp_devs, psd->id); 107 mutex_unlock(&psp_devs_lock); 108 109 mutex_destroy(&psd->lock); 110 kfree_rcu(psd, rcu); 111 } 112 113 /** 114 * psp_dev_unregister() - unregister PSP device 115 * @psd: PSP device structure 116 */ 117 void psp_dev_unregister(struct psp_dev *psd) 118 { 119 struct psp_assoc *pas, *next; 120 121 mutex_lock(&psp_devs_lock); 122 mutex_lock(&psd->lock); 123 124 psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF); 125 126 /* Wait until psp_dev_free() to call xa_erase() to prevent a 127 * different psd from being added to the xarray with this id, while 128 * there are still references to this psd being held. 129 */ 130 xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); 131 mutex_unlock(&psp_devs_lock); 132 133 list_splice_init(&psd->active_assocs, &psd->prev_assocs); 134 list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 135 list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list) 136 psp_dev_tx_key_del(psd, pas); 137 138 rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); 139 140 psd->ops = NULL; 141 psd->drv_priv = NULL; 142 143 mutex_unlock(&psd->lock); 144 145 psp_dev_put(psd); 146 } 147 EXPORT_SYMBOL(psp_dev_unregister); 148 149 unsigned int psp_key_size(u32 version) 150 { 151 switch (version) { 152 case PSP_VERSION_HDR0_AES_GCM_128: 153 case PSP_VERSION_HDR0_AES_GMAC_128: 154 return 16; 155 case PSP_VERSION_HDR0_AES_GCM_256: 156 case PSP_VERSION_HDR0_AES_GMAC_256: 157 return 32; 158 default: 159 return 0; 160 } 161 } 162 EXPORT_SYMBOL(psp_key_size); 163 164 static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi, 165 u8 ver, unsigned int udp_len, __be16 sport) 166 { 167 struct udphdr *uh = udp_hdr(skb); 168 struct psphdr *psph = (struct psphdr *)(uh + 1); 169 170 uh->dest = htons(PSP_DEFAULT_UDP_PORT); 171 uh->source = udp_flow_src_port(net, skb, 0, 0, false); 172 uh->check = 0; 173 uh->len = htons(udp_len); 174 175 psph->nexthdr = IPPROTO_TCP; 176 psph->hdrlen = PSP_HDRLEN_NOOPT; 177 psph->crypt_offset = 0; 178 psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) | 179 FIELD_PREP(PSPHDR_VERFL_ONE, 1); 180 psph->spi = spi; 181 memset(&psph->iv, 0, sizeof(psph->iv)); 182 } 183 184 /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling 185 * them in. 186 */ 187 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 188 u8 ver, __be16 sport) 189 { 190 u32 network_len = skb_network_header_len(skb); 191 u32 ethr_len = skb_mac_header_len(skb); 192 u32 bufflen = ethr_len + network_len; 193 194 if (skb_cow_head(skb, PSP_ENCAP_HLEN)) 195 return false; 196 197 skb_push(skb, PSP_ENCAP_HLEN); 198 skb->mac_header -= PSP_ENCAP_HLEN; 199 skb->network_header -= PSP_ENCAP_HLEN; 200 skb->transport_header -= PSP_ENCAP_HLEN; 201 memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen); 202 203 if (skb->protocol == htons(ETH_P_IP)) { 204 ip_hdr(skb)->protocol = IPPROTO_UDP; 205 be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN); 206 ip_hdr(skb)->check = 0; 207 ip_hdr(skb)->check = 208 ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl); 209 } else if (skb->protocol == htons(ETH_P_IPV6)) { 210 ipv6_hdr(skb)->nexthdr = IPPROTO_UDP; 211 be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN); 212 } else { 213 return false; 214 } 215 216 skb_set_inner_ipproto(skb, IPPROTO_TCP); 217 skb_set_inner_transport_header(skb, skb_transport_offset(skb) + 218 PSP_ENCAP_HLEN); 219 skb->encapsulation = 1; 220 psp_write_headers(net, skb, spi, ver, 221 skb->len - skb_transport_offset(skb), sport); 222 223 return true; 224 } 225 EXPORT_SYMBOL(psp_dev_encapsulate); 226 227 /* Receive handler for PSP packets. 228 * 229 * Presently it accepts only already-authenticated packets and does not 230 * support optional fields, such as virtualization cookies. The caller should 231 * ensure that skb->data is pointing to the mac header, and that skb->mac_len 232 * is set. This function does not currently adjust skb->csum (CHECKSUM_COMPLETE 233 * is not supported). 234 */ 235 int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv) 236 { 237 int l2_hlen = 0, l3_hlen, encap; 238 struct psp_skb_ext *pse; 239 struct psphdr *psph; 240 struct ethhdr *eth; 241 struct udphdr *uh; 242 __be16 proto; 243 bool is_udp; 244 245 eth = (struct ethhdr *)skb->data; 246 proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen); 247 if (proto == htons(ETH_P_IP)) 248 l3_hlen = sizeof(struct iphdr); 249 else if (proto == htons(ETH_P_IPV6)) 250 l3_hlen = sizeof(struct ipv6hdr); 251 else 252 return -EINVAL; 253 254 if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))) 255 return -EINVAL; 256 257 if (proto == htons(ETH_P_IP)) { 258 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 259 260 is_udp = iph->protocol == IPPROTO_UDP; 261 l3_hlen = iph->ihl * 4; 262 if (l3_hlen != sizeof(struct iphdr) && 263 !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)) 264 return -EINVAL; 265 } else { 266 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 267 268 is_udp = ipv6h->nexthdr == IPPROTO_UDP; 269 } 270 271 if (unlikely(!is_udp)) 272 return -EINVAL; 273 274 uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen); 275 if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT))) 276 return -EINVAL; 277 278 pse = skb_ext_add(skb, SKB_EXT_PSP); 279 if (!pse) 280 return -EINVAL; 281 282 psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen + 283 sizeof(struct udphdr)); 284 pse->spi = psph->spi; 285 pse->dev_id = dev_id; 286 pse->generation = generation; 287 pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl); 288 289 encap = PSP_ENCAP_HLEN; 290 encap += strip_icv ? PSP_TRL_SIZE : 0; 291 292 if (proto == htons(ETH_P_IP)) { 293 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 294 295 iph->protocol = psph->nexthdr; 296 iph->tot_len = htons(ntohs(iph->tot_len) - encap); 297 iph->check = 0; 298 iph->check = ip_fast_csum((u8 *)iph, iph->ihl); 299 } else { 300 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 301 302 ipv6h->nexthdr = psph->nexthdr; 303 ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap); 304 } 305 306 memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen); 307 skb_pull(skb, PSP_ENCAP_HLEN); 308 309 if (strip_icv) 310 pskb_trim(skb, skb->len - PSP_TRL_SIZE); 311 312 return 0; 313 } 314 EXPORT_SYMBOL(psp_dev_rcv); 315 316 static int __init psp_init(void) 317 { 318 mutex_init(&psp_devs_lock); 319 320 return genl_register_family(&psp_nl_family); 321 } 322 323 subsys_initcall(psp_init); 324