xref: /linux/net/psp/psp_main.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
100c94ca2SJakub Kicinski // SPDX-License-Identifier: GPL-2.0-only
200c94ca2SJakub Kicinski 
3fc724515SRaed Salem #include <linux/bitfield.h>
400c94ca2SJakub Kicinski #include <linux/list.h>
500c94ca2SJakub Kicinski #include <linux/netdevice.h>
600c94ca2SJakub Kicinski #include <linux/xarray.h>
700c94ca2SJakub Kicinski #include <net/net_namespace.h>
800c94ca2SJakub Kicinski #include <net/psp.h>
9fc724515SRaed Salem #include <net/udp.h>
1000c94ca2SJakub Kicinski 
1100c94ca2SJakub Kicinski #include "psp.h"
1200c94ca2SJakub Kicinski #include "psp-nl-gen.h"
1300c94ca2SJakub Kicinski 
/* Registry of all PSP devices, indexed by psd->id (allocating xarray, IDs start at 1). */
DEFINE_XARRAY_ALLOC1(psp_devs);
/* Serializes modifications of psp_devs; initialized in psp_init(). */
struct mutex psp_devs_lock;
1600c94ca2SJakub Kicinski 
/**
 * DOC: PSP locking
 *
 * psp_devs_lock protects the psp_devs xarray.
 * Lock ordering: take psp_devs_lock first, then the instance lock (psd->lock).
 * Each instance is protected by RCU and has a refcount.
 * When the driver unregisters, the instance is flushed, but the struct itself
 * stays allocated while references to it are still held.
 */
2500c94ca2SJakub Kicinski 
2600c94ca2SJakub Kicinski /**
2700c94ca2SJakub Kicinski  * psp_dev_check_access() - check if user in a given net ns can access PSP dev
2800c94ca2SJakub Kicinski  * @psd:	PSP device structure user is trying to access
2900c94ca2SJakub Kicinski  * @net:	net namespace user is in
3000c94ca2SJakub Kicinski  *
3100c94ca2SJakub Kicinski  * Return: 0 if PSP device should be visible in @net, errno otherwise.
3200c94ca2SJakub Kicinski  */
3300c94ca2SJakub Kicinski int psp_dev_check_access(struct psp_dev *psd, struct net *net)
3400c94ca2SJakub Kicinski {
3500c94ca2SJakub Kicinski 	if (dev_net(psd->main_netdev) == net)
3600c94ca2SJakub Kicinski 		return 0;
3700c94ca2SJakub Kicinski 	return -ENOENT;
3800c94ca2SJakub Kicinski }
3900c94ca2SJakub Kicinski 
4000c94ca2SJakub Kicinski /**
4100c94ca2SJakub Kicinski  * psp_dev_create() - create and register PSP device
4200c94ca2SJakub Kicinski  * @netdev:	main netdevice
4300c94ca2SJakub Kicinski  * @psd_ops:	driver callbacks
4400c94ca2SJakub Kicinski  * @psd_caps:	device capabilities
4500c94ca2SJakub Kicinski  * @priv_ptr:	back-pointer to driver private data
4600c94ca2SJakub Kicinski  *
4700c94ca2SJakub Kicinski  * Return: pointer to allocated PSP device, or ERR_PTR.
4800c94ca2SJakub Kicinski  */
4900c94ca2SJakub Kicinski struct psp_dev *
5000c94ca2SJakub Kicinski psp_dev_create(struct net_device *netdev,
5100c94ca2SJakub Kicinski 	       struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps,
5200c94ca2SJakub Kicinski 	       void *priv_ptr)
5300c94ca2SJakub Kicinski {
5400c94ca2SJakub Kicinski 	struct psp_dev *psd;
5500c94ca2SJakub Kicinski 	static u32 last_id;
5600c94ca2SJakub Kicinski 	int err;
5700c94ca2SJakub Kicinski 
5800c94ca2SJakub Kicinski 	if (WARN_ON(!psd_caps->versions ||
59117f02a4SJakub Kicinski 		    !psd_ops->set_config ||
606b46ca26SJakub Kicinski 		    !psd_ops->key_rotate ||
616b46ca26SJakub Kicinski 		    !psd_ops->rx_spi_alloc ||
626b46ca26SJakub Kicinski 		    !psd_ops->tx_key_add ||
636b46ca26SJakub Kicinski 		    !psd_ops->tx_key_del))
6400c94ca2SJakub Kicinski 		return ERR_PTR(-EINVAL);
6500c94ca2SJakub Kicinski 
6600c94ca2SJakub Kicinski 	psd = kzalloc(sizeof(*psd), GFP_KERNEL);
6700c94ca2SJakub Kicinski 	if (!psd)
6800c94ca2SJakub Kicinski 		return ERR_PTR(-ENOMEM);
6900c94ca2SJakub Kicinski 
7000c94ca2SJakub Kicinski 	psd->main_netdev = netdev;
7100c94ca2SJakub Kicinski 	psd->ops = psd_ops;
7200c94ca2SJakub Kicinski 	psd->caps = psd_caps;
7300c94ca2SJakub Kicinski 	psd->drv_priv = priv_ptr;
7400c94ca2SJakub Kicinski 
7500c94ca2SJakub Kicinski 	mutex_init(&psd->lock);
766b46ca26SJakub Kicinski 	INIT_LIST_HEAD(&psd->active_assocs);
77e7885105SJakub Kicinski 	INIT_LIST_HEAD(&psd->prev_assocs);
78e7885105SJakub Kicinski 	INIT_LIST_HEAD(&psd->stale_assocs);
7900c94ca2SJakub Kicinski 	refcount_set(&psd->refcnt, 1);
8000c94ca2SJakub Kicinski 
8100c94ca2SJakub Kicinski 	mutex_lock(&psp_devs_lock);
8200c94ca2SJakub Kicinski 	err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b,
8300c94ca2SJakub Kicinski 			      &last_id, GFP_KERNEL);
8400c94ca2SJakub Kicinski 	if (err) {
8500c94ca2SJakub Kicinski 		mutex_unlock(&psp_devs_lock);
8600c94ca2SJakub Kicinski 		kfree(psd);
8700c94ca2SJakub Kicinski 		return ERR_PTR(err);
8800c94ca2SJakub Kicinski 	}
8900c94ca2SJakub Kicinski 	mutex_lock(&psd->lock);
9000c94ca2SJakub Kicinski 	mutex_unlock(&psp_devs_lock);
9100c94ca2SJakub Kicinski 
9200c94ca2SJakub Kicinski 	psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF);
9300c94ca2SJakub Kicinski 
9400c94ca2SJakub Kicinski 	rcu_assign_pointer(netdev->psp_dev, psd);
9500c94ca2SJakub Kicinski 
9600c94ca2SJakub Kicinski 	mutex_unlock(&psd->lock);
9700c94ca2SJakub Kicinski 
9800c94ca2SJakub Kicinski 	return psd;
9900c94ca2SJakub Kicinski }
10000c94ca2SJakub Kicinski EXPORT_SYMBOL(psp_dev_create);
10100c94ca2SJakub Kicinski 
102672beab0SEric Dumazet void psp_dev_free(struct psp_dev *psd)
10300c94ca2SJakub Kicinski {
10400c94ca2SJakub Kicinski 	mutex_lock(&psp_devs_lock);
10500c94ca2SJakub Kicinski 	xa_erase(&psp_devs, psd->id);
10600c94ca2SJakub Kicinski 	mutex_unlock(&psp_devs_lock);
10700c94ca2SJakub Kicinski 
10800c94ca2SJakub Kicinski 	mutex_destroy(&psd->lock);
10900c94ca2SJakub Kicinski 	kfree_rcu(psd, rcu);
11000c94ca2SJakub Kicinski }
11100c94ca2SJakub Kicinski 
/**
 * psp_dev_unregister() - unregister PSP device
 * @psd:	PSP device structure
 *
 * Sends a PSP_CMD_DEV_DEL_NTF notification, hides @psd from psp_devs lookups
 * while keeping its ID reserved, and calls psp_dev_tx_key_del() on every
 * association. It then unpublishes @psd from the main netdevice, clears the
 * driver ops/private pointer and drops a reference via psp_dev_put().
 */
void psp_dev_unregister(struct psp_dev *psd)
{
	struct psp_assoc *pas, *next;

	/* Lock order: psp_devs_lock, then the instance lock. */
	mutex_lock(&psp_devs_lock);
	mutex_lock(&psd->lock);

	psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF);

	/* Wait until psp_dev_free() to call xa_erase() to prevent a
	 * different psd from being added to the xarray with this id, while
	 * there are still references to this psd being held.
	 */
	xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL);
	mutex_unlock(&psp_devs_lock);

	/* Funnel every association (active -> prev -> stale) onto the stale
	 * list and remove its TX key from the device.
	 */
	list_splice_init(&psd->active_assocs, &psd->prev_assocs);
	list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
	list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list)
		psp_dev_tx_key_del(psd, pas);

	/* Detach from the netdevice; RCU readers may still see the old value. */
	rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);

	/* The driver is going away: forget its callbacks and private data. */
	psd->ops = NULL;
	psd->drv_priv = NULL;

	mutex_unlock(&psd->lock);

	/* NOTE(review): presumably drops the initial reference set by
	 * refcount_set(..., 1) in psp_dev_create() — confirm in psp_dev_put().
	 */
	psp_dev_put(psd);
}
EXPORT_SYMBOL(psp_dev_unregister);
14700c94ca2SJakub Kicinski 
1486b46ca26SJakub Kicinski unsigned int psp_key_size(u32 version)
1496b46ca26SJakub Kicinski {
1506b46ca26SJakub Kicinski 	switch (version) {
1516b46ca26SJakub Kicinski 	case PSP_VERSION_HDR0_AES_GCM_128:
1526b46ca26SJakub Kicinski 	case PSP_VERSION_HDR0_AES_GMAC_128:
1536b46ca26SJakub Kicinski 		return 16;
1546b46ca26SJakub Kicinski 	case PSP_VERSION_HDR0_AES_GCM_256:
1556b46ca26SJakub Kicinski 	case PSP_VERSION_HDR0_AES_GMAC_256:
1566b46ca26SJakub Kicinski 		return 32;
1576b46ca26SJakub Kicinski 	default:
1586b46ca26SJakub Kicinski 		return 0;
1596b46ca26SJakub Kicinski 	}
1606b46ca26SJakub Kicinski }
1616b46ca26SJakub Kicinski EXPORT_SYMBOL(psp_key_size);
1626b46ca26SJakub Kicinski 
163fc724515SRaed Salem static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi,
164fc724515SRaed Salem 			      u8 ver, unsigned int udp_len, __be16 sport)
165fc724515SRaed Salem {
166fc724515SRaed Salem 	struct udphdr *uh = udp_hdr(skb);
167fc724515SRaed Salem 	struct psphdr *psph = (struct psphdr *)(uh + 1);
168fc724515SRaed Salem 
169fc724515SRaed Salem 	uh->dest = htons(PSP_DEFAULT_UDP_PORT);
170fc724515SRaed Salem 	uh->source = udp_flow_src_port(net, skb, 0, 0, false);
171fc724515SRaed Salem 	uh->check = 0;
172fc724515SRaed Salem 	uh->len = htons(udp_len);
173fc724515SRaed Salem 
174fc724515SRaed Salem 	psph->nexthdr = IPPROTO_TCP;
175fc724515SRaed Salem 	psph->hdrlen = PSP_HDRLEN_NOOPT;
176fc724515SRaed Salem 	psph->crypt_offset = 0;
177fc724515SRaed Salem 	psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) |
178fc724515SRaed Salem 		      FIELD_PREP(PSPHDR_VERFL_ONE, 1);
179fc724515SRaed Salem 	psph->spi = spi;
180fc724515SRaed Salem 	memset(&psph->iv, 0, sizeof(psph->iv));
181fc724515SRaed Salem }
182fc724515SRaed Salem 
183fc724515SRaed Salem /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling
184fc724515SRaed Salem  * them in.
185fc724515SRaed Salem  */
186fc724515SRaed Salem bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
187fc724515SRaed Salem 			 u8 ver, __be16 sport)
188fc724515SRaed Salem {
189fc724515SRaed Salem 	u32 network_len = skb_network_header_len(skb);
190fc724515SRaed Salem 	u32 ethr_len = skb_mac_header_len(skb);
191fc724515SRaed Salem 	u32 bufflen = ethr_len + network_len;
192fc724515SRaed Salem 
193fc724515SRaed Salem 	if (skb_cow_head(skb, PSP_ENCAP_HLEN))
194fc724515SRaed Salem 		return false;
195fc724515SRaed Salem 
196fc724515SRaed Salem 	skb_push(skb, PSP_ENCAP_HLEN);
197fc724515SRaed Salem 	skb->mac_header		-= PSP_ENCAP_HLEN;
198fc724515SRaed Salem 	skb->network_header	-= PSP_ENCAP_HLEN;
199fc724515SRaed Salem 	skb->transport_header	-= PSP_ENCAP_HLEN;
200fc724515SRaed Salem 	memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen);
201fc724515SRaed Salem 
202fc724515SRaed Salem 	if (skb->protocol == htons(ETH_P_IP)) {
203fc724515SRaed Salem 		ip_hdr(skb)->protocol = IPPROTO_UDP;
204fc724515SRaed Salem 		be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN);
205fc724515SRaed Salem 		ip_hdr(skb)->check = 0;
206fc724515SRaed Salem 		ip_hdr(skb)->check =
207fc724515SRaed Salem 			ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl);
208fc724515SRaed Salem 	} else if (skb->protocol == htons(ETH_P_IPV6)) {
209fc724515SRaed Salem 		ipv6_hdr(skb)->nexthdr = IPPROTO_UDP;
210fc724515SRaed Salem 		be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN);
211fc724515SRaed Salem 	} else {
212fc724515SRaed Salem 		return false;
213fc724515SRaed Salem 	}
214fc724515SRaed Salem 
215fc724515SRaed Salem 	skb_set_inner_ipproto(skb, IPPROTO_TCP);
216fc724515SRaed Salem 	skb_set_inner_transport_header(skb, skb_transport_offset(skb) +
217fc724515SRaed Salem 						    PSP_ENCAP_HLEN);
218fc724515SRaed Salem 	skb->encapsulation = 1;
219fc724515SRaed Salem 	psp_write_headers(net, skb, spi, ver,
220fc724515SRaed Salem 			  skb->len - skb_transport_offset(skb), sport);
221fc724515SRaed Salem 
222fc724515SRaed Salem 	return true;
223fc724515SRaed Salem }
224fc724515SRaed Salem EXPORT_SYMBOL(psp_dev_encapsulate);
225fc724515SRaed Salem 
2260eddb802SRaed Salem /* Receive handler for PSP packets.
2270eddb802SRaed Salem  *
2280eddb802SRaed Salem  * Presently it accepts only already-authenticated packets and does not
2290eddb802SRaed Salem  * support optional fields, such as virtualization cookies. The caller should
2300eddb802SRaed Salem  * ensure that skb->data is pointing to the mac header, and that skb->mac_len
231*85c7333cSDaniel Zahka  * is set. This function does not currently adjust skb->csum (CHECKSUM_COMPLETE
232*85c7333cSDaniel Zahka  * is not supported).
2330eddb802SRaed Salem  */
2340eddb802SRaed Salem int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
2350eddb802SRaed Salem {
2360eddb802SRaed Salem 	int l2_hlen = 0, l3_hlen, encap;
2370eddb802SRaed Salem 	struct psp_skb_ext *pse;
2380eddb802SRaed Salem 	struct psphdr *psph;
2390eddb802SRaed Salem 	struct ethhdr *eth;
2400eddb802SRaed Salem 	struct udphdr *uh;
2410eddb802SRaed Salem 	__be16 proto;
2420eddb802SRaed Salem 	bool is_udp;
2430eddb802SRaed Salem 
2440eddb802SRaed Salem 	eth = (struct ethhdr *)skb->data;
2450eddb802SRaed Salem 	proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
2460eddb802SRaed Salem 	if (proto == htons(ETH_P_IP))
2470eddb802SRaed Salem 		l3_hlen = sizeof(struct iphdr);
2480eddb802SRaed Salem 	else if (proto == htons(ETH_P_IPV6))
2490eddb802SRaed Salem 		l3_hlen = sizeof(struct ipv6hdr);
2500eddb802SRaed Salem 	else
2510eddb802SRaed Salem 		return -EINVAL;
2520eddb802SRaed Salem 
2530eddb802SRaed Salem 	if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
2540eddb802SRaed Salem 		return -EINVAL;
2550eddb802SRaed Salem 
2560eddb802SRaed Salem 	if (proto == htons(ETH_P_IP)) {
2570eddb802SRaed Salem 		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
2580eddb802SRaed Salem 
2590eddb802SRaed Salem 		is_udp = iph->protocol == IPPROTO_UDP;
2600eddb802SRaed Salem 		l3_hlen = iph->ihl * 4;
2610eddb802SRaed Salem 		if (l3_hlen != sizeof(struct iphdr) &&
2620eddb802SRaed Salem 		    !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
2630eddb802SRaed Salem 			return -EINVAL;
2640eddb802SRaed Salem 	} else {
2650eddb802SRaed Salem 		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
2660eddb802SRaed Salem 
2670eddb802SRaed Salem 		is_udp = ipv6h->nexthdr == IPPROTO_UDP;
2680eddb802SRaed Salem 	}
2690eddb802SRaed Salem 
2700eddb802SRaed Salem 	if (unlikely(!is_udp))
2710eddb802SRaed Salem 		return -EINVAL;
2720eddb802SRaed Salem 
2730eddb802SRaed Salem 	uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
2740eddb802SRaed Salem 	if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
2750eddb802SRaed Salem 		return -EINVAL;
2760eddb802SRaed Salem 
2770eddb802SRaed Salem 	pse = skb_ext_add(skb, SKB_EXT_PSP);
2780eddb802SRaed Salem 	if (!pse)
2790eddb802SRaed Salem 		return -EINVAL;
2800eddb802SRaed Salem 
2810eddb802SRaed Salem 	psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
2820eddb802SRaed Salem 				 sizeof(struct udphdr));
2830eddb802SRaed Salem 	pse->spi = psph->spi;
2840eddb802SRaed Salem 	pse->dev_id = dev_id;
2850eddb802SRaed Salem 	pse->generation = generation;
2860eddb802SRaed Salem 	pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);
2870eddb802SRaed Salem 
2880eddb802SRaed Salem 	encap = PSP_ENCAP_HLEN;
2890eddb802SRaed Salem 	encap += strip_icv ? PSP_TRL_SIZE : 0;
2900eddb802SRaed Salem 
2910eddb802SRaed Salem 	if (proto == htons(ETH_P_IP)) {
2920eddb802SRaed Salem 		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
2930eddb802SRaed Salem 
2940eddb802SRaed Salem 		iph->protocol = psph->nexthdr;
2950eddb802SRaed Salem 		iph->tot_len = htons(ntohs(iph->tot_len) - encap);
2960eddb802SRaed Salem 		iph->check = 0;
2970eddb802SRaed Salem 		iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
2980eddb802SRaed Salem 	} else {
2990eddb802SRaed Salem 		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
3000eddb802SRaed Salem 
3010eddb802SRaed Salem 		ipv6h->nexthdr = psph->nexthdr;
3020eddb802SRaed Salem 		ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
3030eddb802SRaed Salem 	}
3040eddb802SRaed Salem 
3050eddb802SRaed Salem 	memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen);
3060eddb802SRaed Salem 	skb_pull(skb, PSP_ENCAP_HLEN);
3070eddb802SRaed Salem 
3080eddb802SRaed Salem 	if (strip_icv)
3090eddb802SRaed Salem 		pskb_trim(skb, skb->len - PSP_TRL_SIZE);
3100eddb802SRaed Salem 
3110eddb802SRaed Salem 	return 0;
3120eddb802SRaed Salem }
3130eddb802SRaed Salem EXPORT_SYMBOL(psp_dev_rcv);
3140eddb802SRaed Salem 
31500c94ca2SJakub Kicinski static int __init psp_init(void)
31600c94ca2SJakub Kicinski {
31700c94ca2SJakub Kicinski 	mutex_init(&psp_devs_lock);
31800c94ca2SJakub Kicinski 
31900c94ca2SJakub Kicinski 	return genl_register_family(&psp_nl_family);
32000c94ca2SJakub Kicinski }
32100c94ca2SJakub Kicinski 
32200c94ca2SJakub Kicinski subsys_initcall(psp_init);
323