1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/bitfield.h> 4 #include <linux/list.h> 5 #include <linux/netdevice.h> 6 #include <linux/xarray.h> 7 #include <net/net_namespace.h> 8 #include <net/psp.h> 9 #include <net/udp.h> 10 11 #include "psp.h" 12 #include "psp-nl-gen.h" 13 14 DEFINE_XARRAY_ALLOC1(psp_devs); 15 struct mutex psp_devs_lock; 16 17 /** 18 * DOC: PSP locking 19 * 20 * psp_devs_lock protects the psp_devs xarray. 21 * Ordering is take the psp_devs_lock and then the instance lock. 22 * Each instance is protected by RCU, and has a refcount. 23 * When driver unregisters the instance gets flushed, but struct sticks around. 24 */ 25 26 /** 27 * psp_dev_check_access() - check if user in a given net ns can access PSP dev 28 * @psd: PSP device structure user is trying to access 29 * @net: net namespace user is in 30 * 31 * Return: 0 if PSP device should be visible in @net, errno otherwise. 32 */ 33 int psp_dev_check_access(struct psp_dev *psd, struct net *net) 34 { 35 if (dev_net(psd->main_netdev) == net) 36 return 0; 37 return -ENOENT; 38 } 39 40 /** 41 * psp_dev_create() - create and register PSP device 42 * @netdev: main netdevice 43 * @psd_ops: driver callbacks 44 * @psd_caps: device capabilities 45 * @priv_ptr: back-pointer to driver private data 46 * 47 * Return: pointer to allocated PSP device, or ERR_PTR. 48 */ 49 struct psp_dev * 50 psp_dev_create(struct net_device *netdev, 51 struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps, 52 void *priv_ptr) 53 { 54 struct psp_dev *psd; 55 static u32 last_id; 56 int err; 57 58 if (WARN_ON(!psd_caps->versions || 59 !psd_ops->set_config || 60 !psd_ops->key_rotate || 61 !psd_ops->rx_spi_alloc || 62 !psd_ops->tx_key_add || 63 !psd_ops->tx_key_del || 64 !psd_ops->get_stats)) 65 return ERR_PTR(-EINVAL); 66 67 psd = kzalloc_obj(*psd); 68 if (!psd) 69 return ERR_PTR(-ENOMEM); 70 71 psd->main_netdev = netdev; 72 psd->ops = psd_ops; 73 psd->caps = psd_caps; 74 psd->drv_priv = priv_ptr; 75 76 mutex_init(&psd->lock); 77 INIT_LIST_HEAD(&psd->active_assocs); 78 INIT_LIST_HEAD(&psd->prev_assocs); 79 INIT_LIST_HEAD(&psd->stale_assocs); 80 refcount_set(&psd->refcnt, 1); 81 82 mutex_lock(&psp_devs_lock); 83 err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b, 84 &last_id, GFP_KERNEL); 85 if (err) { 86 mutex_unlock(&psp_devs_lock); 87 kfree(psd); 88 return ERR_PTR(err); 89 } 90 mutex_lock(&psd->lock); 91 mutex_unlock(&psp_devs_lock); 92 93 psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF); 94 95 rcu_assign_pointer(netdev->psp_dev, psd); 96 97 mutex_unlock(&psd->lock); 98 99 return psd; 100 } 101 EXPORT_SYMBOL(psp_dev_create); 102 103 void psp_dev_free(struct psp_dev *psd) 104 { 105 mutex_lock(&psp_devs_lock); 106 xa_erase(&psp_devs, psd->id); 107 mutex_unlock(&psp_devs_lock); 108 109 mutex_destroy(&psd->lock); 110 kfree_rcu(psd, rcu); 111 } 112 113 /** 114 * psp_dev_unregister() - unregister PSP device 115 * @psd: PSP device structure 116 */ 117 void psp_dev_unregister(struct psp_dev *psd) 118 { 119 struct psp_assoc *pas, *next; 120 121 mutex_lock(&psp_devs_lock); 122 mutex_lock(&psd->lock); 123 124 psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF); 125 126 /* Wait until psp_dev_free() to call xa_erase() to prevent a 127 * different psd from being added to the xarray with this id, while 128 * there are still references to this psd being held. 129 */ 130 xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); 131 mutex_unlock(&psp_devs_lock); 132 133 list_splice_init(&psd->active_assocs, &psd->prev_assocs); 134 list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 135 list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list) 136 psp_dev_tx_key_del(psd, pas); 137 138 rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); 139 140 psd->ops = NULL; 141 psd->drv_priv = NULL; 142 143 mutex_unlock(&psd->lock); 144 145 psp_dev_put(psd); 146 } 147 EXPORT_SYMBOL(psp_dev_unregister); 148 149 unsigned int psp_key_size(u32 version) 150 { 151 switch (version) { 152 case PSP_VERSION_HDR0_AES_GCM_128: 153 case PSP_VERSION_HDR0_AES_GMAC_128: 154 return 16; 155 case PSP_VERSION_HDR0_AES_GCM_256: 156 case PSP_VERSION_HDR0_AES_GMAC_256: 157 return 32; 158 default: 159 return 0; 160 } 161 } 162 EXPORT_SYMBOL(psp_key_size); 163 164 static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi, 165 u8 ver, unsigned int udp_len, __be16 sport) 166 { 167 struct udphdr *uh = udp_hdr(skb); 168 struct psphdr *psph = (struct psphdr *)(uh + 1); 169 const struct sock *sk = skb->sk; 170 171 uh->dest = htons(PSP_DEFAULT_UDP_PORT); 172 173 /* A bit of theory: Selection of the source port. 174 * 175 * We need some entropy, so that multiple flows use different 176 * source ports for better RSS spreading at the receiver. 177 * 178 * We also need that all packets belonging to one TCP flow 179 * use the same source port through their duration, 180 * so that all these packets land in the same receive queue. 181 * 182 * udp_flow_src_port() is using sk_txhash, inherited from 183 * skb_set_hash_from_sk() call in __tcp_transmit_skb(). 184 * This field is subject to reshuffling, thanks to 185 * sk_rethink_txhash() calls in various TCP functions. 186 * 187 * Instead, use sk->sk_hash which is constant through 188 * the whole flow duration. 189 */ 190 if (likely(sk)) { 191 u32 hash = sk->sk_hash; 192 int min, max; 193 194 /* These operations are cheap, no need to cache the result 195 * in another socket field. 196 */ 197 inet_get_local_port_range(net, &min, &max); 198 /* Since this is being sent on the wire obfuscate hash a bit 199 * to minimize possibility that any useful information to an 200 * attacker is leaked. Only upper 16 bits are relevant in the 201 * computation for 16 bit port value because we use a 202 * reciprocal divide. 203 */ 204 hash ^= hash << 16; 205 uh->source = htons((((u64)hash * (max - min)) >> 32) + min); 206 } else { 207 uh->source = udp_flow_src_port(net, skb, 0, 0, false); 208 } 209 uh->check = 0; 210 uh->len = htons(udp_len); 211 212 psph->nexthdr = IPPROTO_TCP; 213 psph->hdrlen = PSP_HDRLEN_NOOPT; 214 psph->crypt_offset = 0; 215 psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) | 216 FIELD_PREP(PSPHDR_VERFL_ONE, 1); 217 psph->spi = spi; 218 memset(&psph->iv, 0, sizeof(psph->iv)); 219 } 220 221 /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling 222 * them in. 223 */ 224 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 225 u8 ver, __be16 sport) 226 { 227 u32 network_len = skb_network_header_len(skb); 228 u32 ethr_len = skb_mac_header_len(skb); 229 u32 bufflen = ethr_len + network_len; 230 231 if (skb_cow_head(skb, PSP_ENCAP_HLEN)) 232 return false; 233 234 skb_push(skb, PSP_ENCAP_HLEN); 235 skb->mac_header -= PSP_ENCAP_HLEN; 236 skb->network_header -= PSP_ENCAP_HLEN; 237 skb->transport_header -= PSP_ENCAP_HLEN; 238 memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen); 239 240 if (skb->protocol == htons(ETH_P_IP)) { 241 ip_hdr(skb)->protocol = IPPROTO_UDP; 242 be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN); 243 ip_hdr(skb)->check = 0; 244 ip_hdr(skb)->check = 245 ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl); 246 } else if (skb->protocol == htons(ETH_P_IPV6)) { 247 ipv6_hdr(skb)->nexthdr = IPPROTO_UDP; 248 be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN); 249 } else { 250 return false; 251 } 252 253 skb_set_inner_ipproto(skb, IPPROTO_TCP); 254 skb_set_inner_transport_header(skb, skb_transport_offset(skb) + 255 PSP_ENCAP_HLEN); 256 skb->encapsulation = 1; 257 psp_write_headers(net, skb, spi, ver, 258 skb->len - skb_transport_offset(skb), sport); 259 260 return true; 261 } 262 EXPORT_SYMBOL(psp_dev_encapsulate); 263 264 /* Receive handler for PSP packets. 265 * 266 * Presently it accepts only already-authenticated packets and does not 267 * support optional fields, such as virtualization cookies. The caller should 268 * ensure that skb->data is pointing to the mac header, and that skb->mac_len 269 * is set. This function does not currently adjust skb->csum (CHECKSUM_COMPLETE 270 * is not supported). 271 */ 272 int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv) 273 { 274 int l2_hlen = 0, l3_hlen, encap; 275 struct psp_skb_ext *pse; 276 struct psphdr *psph; 277 struct ethhdr *eth; 278 struct udphdr *uh; 279 __be16 proto; 280 bool is_udp; 281 282 eth = (struct ethhdr *)skb->data; 283 proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen); 284 if (proto == htons(ETH_P_IP)) 285 l3_hlen = sizeof(struct iphdr); 286 else if (proto == htons(ETH_P_IPV6)) 287 l3_hlen = sizeof(struct ipv6hdr); 288 else 289 return -EINVAL; 290 291 if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))) 292 return -EINVAL; 293 294 if (proto == htons(ETH_P_IP)) { 295 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 296 297 is_udp = iph->protocol == IPPROTO_UDP; 298 l3_hlen = iph->ihl * 4; 299 if (l3_hlen != sizeof(struct iphdr) && 300 !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)) 301 return -EINVAL; 302 } else { 303 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 304 305 is_udp = ipv6h->nexthdr == IPPROTO_UDP; 306 } 307 308 if (unlikely(!is_udp)) 309 return -EINVAL; 310 311 uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen); 312 if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT))) 313 return -EINVAL; 314 315 pse = skb_ext_add(skb, SKB_EXT_PSP); 316 if (!pse) 317 return -EINVAL; 318 319 psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen + 320 sizeof(struct udphdr)); 321 pse->spi = psph->spi; 322 pse->dev_id = dev_id; 323 pse->generation = generation; 324 pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl); 325 326 encap = PSP_ENCAP_HLEN; 327 encap += strip_icv ? PSP_TRL_SIZE : 0; 328 329 if (proto == htons(ETH_P_IP)) { 330 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 331 332 iph->protocol = psph->nexthdr; 333 iph->tot_len = htons(ntohs(iph->tot_len) - encap); 334 iph->check = 0; 335 iph->check = ip_fast_csum((u8 *)iph, iph->ihl); 336 } else { 337 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 338 339 ipv6h->nexthdr = psph->nexthdr; 340 ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap); 341 } 342 343 memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen); 344 skb_pull(skb, PSP_ENCAP_HLEN); 345 346 if (strip_icv) 347 pskb_trim(skb, skb->len - PSP_TRL_SIZE); 348 349 return 0; 350 } 351 EXPORT_SYMBOL(psp_dev_rcv); 352 353 static int __init psp_init(void) 354 { 355 mutex_init(&psp_devs_lock); 356 357 return genl_register_family(&psp_nl_family); 358 } 359 360 subsys_initcall(psp_init); 361