1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/bitfield.h> 4 #include <linux/list.h> 5 #include <linux/netdevice.h> 6 #include <linux/xarray.h> 7 #include <net/net_namespace.h> 8 #include <net/psp.h> 9 #include <net/udp.h> 10 11 #include "psp.h" 12 #include "psp-nl-gen.h" 13 14 DEFINE_XARRAY_ALLOC1(psp_devs); 15 struct mutex psp_devs_lock; 16 17 /** 18 * DOC: PSP locking 19 * 20 * psp_devs_lock protects the psp_devs xarray. 21 * Ordering is take the psp_devs_lock and then the instance lock. 22 * Each instance is protected by RCU, and has a refcount. 23 * When driver unregisters the instance gets flushed, but struct sticks around. 24 */ 25 26 /** 27 * psp_dev_check_access() - check if user in a given net ns can access PSP dev 28 * @psd: PSP device structure user is trying to access 29 * @net: net namespace user is in 30 * 31 * Return: 0 if PSP device should be visible in @net, errno otherwise. 32 */ 33 int psp_dev_check_access(struct psp_dev *psd, struct net *net) 34 { 35 if (dev_net(psd->main_netdev) == net) 36 return 0; 37 return -ENOENT; 38 } 39 40 /** 41 * psp_dev_create() - create and register PSP device 42 * @netdev: main netdevice 43 * @psd_ops: driver callbacks 44 * @psd_caps: device capabilities 45 * @priv_ptr: back-pointer to driver private data 46 * 47 * Return: pointer to allocated PSP device, or ERR_PTR. 48 */ 49 struct psp_dev * 50 psp_dev_create(struct net_device *netdev, 51 struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps, 52 void *priv_ptr) 53 { 54 struct psp_dev *psd; 55 static u32 last_id; 56 int err; 57 58 if (WARN_ON(!psd_caps->versions || 59 !psd_ops->set_config || 60 !psd_ops->key_rotate || 61 !psd_ops->rx_spi_alloc || 62 !psd_ops->tx_key_add || 63 !psd_ops->tx_key_del)) 64 return ERR_PTR(-EINVAL); 65 66 psd = kzalloc(sizeof(*psd), GFP_KERNEL); 67 if (!psd) 68 return ERR_PTR(-ENOMEM); 69 70 psd->main_netdev = netdev; 71 psd->ops = psd_ops; 72 psd->caps = psd_caps; 73 psd->drv_priv = priv_ptr; 74 75 mutex_init(&psd->lock); 76 INIT_LIST_HEAD(&psd->active_assocs); 77 INIT_LIST_HEAD(&psd->prev_assocs); 78 INIT_LIST_HEAD(&psd->stale_assocs); 79 refcount_set(&psd->refcnt, 1); 80 81 mutex_lock(&psp_devs_lock); 82 err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b, 83 &last_id, GFP_KERNEL); 84 if (err) { 85 mutex_unlock(&psp_devs_lock); 86 kfree(psd); 87 return ERR_PTR(err); 88 } 89 mutex_lock(&psd->lock); 90 mutex_unlock(&psp_devs_lock); 91 92 psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF); 93 94 rcu_assign_pointer(netdev->psp_dev, psd); 95 96 mutex_unlock(&psd->lock); 97 98 return psd; 99 } 100 EXPORT_SYMBOL(psp_dev_create); 101 102 void psp_dev_free(struct psp_dev *psd) 103 { 104 mutex_lock(&psp_devs_lock); 105 xa_erase(&psp_devs, psd->id); 106 mutex_unlock(&psp_devs_lock); 107 108 mutex_destroy(&psd->lock); 109 kfree_rcu(psd, rcu); 110 } 111 112 /** 113 * psp_dev_unregister() - unregister PSP device 114 * @psd: PSP device structure 115 */ 116 void psp_dev_unregister(struct psp_dev *psd) 117 { 118 struct psp_assoc *pas, *next; 119 120 mutex_lock(&psp_devs_lock); 121 mutex_lock(&psd->lock); 122 123 psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF); 124 125 /* Wait until psp_dev_free() to call xa_erase() to prevent a 126 * different psd from being added to the xarray with this id, while 127 * there are still references to this psd being held. 128 */ 129 xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); 130 mutex_unlock(&psp_devs_lock); 131 132 list_splice_init(&psd->active_assocs, &psd->prev_assocs); 133 list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 134 list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list) 135 psp_dev_tx_key_del(psd, pas); 136 137 rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); 138 139 psd->ops = NULL; 140 psd->drv_priv = NULL; 141 142 mutex_unlock(&psd->lock); 143 144 psp_dev_put(psd); 145 } 146 EXPORT_SYMBOL(psp_dev_unregister); 147 148 unsigned int psp_key_size(u32 version) 149 { 150 switch (version) { 151 case PSP_VERSION_HDR0_AES_GCM_128: 152 case PSP_VERSION_HDR0_AES_GMAC_128: 153 return 16; 154 case PSP_VERSION_HDR0_AES_GCM_256: 155 case PSP_VERSION_HDR0_AES_GMAC_256: 156 return 32; 157 default: 158 return 0; 159 } 160 } 161 EXPORT_SYMBOL(psp_key_size); 162 163 static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi, 164 u8 ver, unsigned int udp_len, __be16 sport) 165 { 166 struct udphdr *uh = udp_hdr(skb); 167 struct psphdr *psph = (struct psphdr *)(uh + 1); 168 169 uh->dest = htons(PSP_DEFAULT_UDP_PORT); 170 uh->source = udp_flow_src_port(net, skb, 0, 0, false); 171 uh->check = 0; 172 uh->len = htons(udp_len); 173 174 psph->nexthdr = IPPROTO_TCP; 175 psph->hdrlen = PSP_HDRLEN_NOOPT; 176 psph->crypt_offset = 0; 177 psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) | 178 FIELD_PREP(PSPHDR_VERFL_ONE, 1); 179 psph->spi = spi; 180 memset(&psph->iv, 0, sizeof(psph->iv)); 181 } 182 183 /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling 184 * them in. 185 */ 186 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 187 u8 ver, __be16 sport) 188 { 189 u32 network_len = skb_network_header_len(skb); 190 u32 ethr_len = skb_mac_header_len(skb); 191 u32 bufflen = ethr_len + network_len; 192 193 if (skb_cow_head(skb, PSP_ENCAP_HLEN)) 194 return false; 195 196 skb_push(skb, PSP_ENCAP_HLEN); 197 skb->mac_header -= PSP_ENCAP_HLEN; 198 skb->network_header -= PSP_ENCAP_HLEN; 199 skb->transport_header -= PSP_ENCAP_HLEN; 200 memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen); 201 202 if (skb->protocol == htons(ETH_P_IP)) { 203 ip_hdr(skb)->protocol = IPPROTO_UDP; 204 be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN); 205 ip_hdr(skb)->check = 0; 206 ip_hdr(skb)->check = 207 ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl); 208 } else if (skb->protocol == htons(ETH_P_IPV6)) { 209 ipv6_hdr(skb)->nexthdr = IPPROTO_UDP; 210 be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN); 211 } else { 212 return false; 213 } 214 215 skb_set_inner_ipproto(skb, IPPROTO_TCP); 216 skb_set_inner_transport_header(skb, skb_transport_offset(skb) + 217 PSP_ENCAP_HLEN); 218 skb->encapsulation = 1; 219 psp_write_headers(net, skb, spi, ver, 220 skb->len - skb_transport_offset(skb), sport); 221 222 return true; 223 } 224 EXPORT_SYMBOL(psp_dev_encapsulate); 225 226 /* Receive handler for PSP packets. 227 * 228 * Presently it accepts only already-authenticated packets and does not 229 * support optional fields, such as virtualization cookies. The caller should 230 * ensure that skb->data is pointing to the mac header, and that skb->mac_len 231 * is set. 232 */ 233 int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv) 234 { 235 int l2_hlen = 0, l3_hlen, encap; 236 struct psp_skb_ext *pse; 237 struct psphdr *psph; 238 struct ethhdr *eth; 239 struct udphdr *uh; 240 __be16 proto; 241 bool is_udp; 242 243 eth = (struct ethhdr *)skb->data; 244 proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen); 245 if (proto == htons(ETH_P_IP)) 246 l3_hlen = sizeof(struct iphdr); 247 else if (proto == htons(ETH_P_IPV6)) 248 l3_hlen = sizeof(struct ipv6hdr); 249 else 250 return -EINVAL; 251 252 if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))) 253 return -EINVAL; 254 255 if (proto == htons(ETH_P_IP)) { 256 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 257 258 is_udp = iph->protocol == IPPROTO_UDP; 259 l3_hlen = iph->ihl * 4; 260 if (l3_hlen != sizeof(struct iphdr) && 261 !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)) 262 return -EINVAL; 263 } else { 264 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 265 266 is_udp = ipv6h->nexthdr == IPPROTO_UDP; 267 } 268 269 if (unlikely(!is_udp)) 270 return -EINVAL; 271 272 uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen); 273 if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT))) 274 return -EINVAL; 275 276 pse = skb_ext_add(skb, SKB_EXT_PSP); 277 if (!pse) 278 return -EINVAL; 279 280 psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen + 281 sizeof(struct udphdr)); 282 pse->spi = psph->spi; 283 pse->dev_id = dev_id; 284 pse->generation = generation; 285 pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl); 286 287 encap = PSP_ENCAP_HLEN; 288 encap += strip_icv ? PSP_TRL_SIZE : 0; 289 290 if (proto == htons(ETH_P_IP)) { 291 struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); 292 293 iph->protocol = psph->nexthdr; 294 iph->tot_len = htons(ntohs(iph->tot_len) - encap); 295 iph->check = 0; 296 iph->check = ip_fast_csum((u8 *)iph, iph->ihl); 297 } else { 298 struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); 299 300 ipv6h->nexthdr = psph->nexthdr; 301 ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap); 302 } 303 304 memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen); 305 skb_pull(skb, PSP_ENCAP_HLEN); 306 307 if (strip_icv) 308 pskb_trim(skb, skb->len - PSP_TRL_SIZE); 309 310 return 0; 311 } 312 EXPORT_SYMBOL(psp_dev_rcv); 313 314 static int __init psp_init(void) 315 { 316 mutex_init(&psp_devs_lock); 317 318 return genl_register_family(&psp_nl_family); 319 } 320 321 subsys_initcall(psp_init); 322