xref: /linux/net/psp/psp_main.c (revision 0ddb69e2406eba0c2f6bee0d6084e7dd17333c2b)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <linux/bitfield.h>
4 #include <linux/list.h>
5 #include <linux/netdevice.h>
6 #include <linux/xarray.h>
7 #include <net/net_namespace.h>
8 #include <net/psp.h>
9 #include <net/udp.h>
10 
11 #include "psp.h"
12 #include "psp-nl-gen.h"
13 
/* Registry of PSP devices, keyed by psd->id. IDs are handed out by
 * xa_alloc_cyclic() in psp_dev_create(); the ALLOC1 flavour of the XArray
 * starts allocating at index 1.
 */
DEFINE_XARRAY_ALLOC1(psp_devs);
/* Serialises modifications of psp_devs; initialised in psp_init(). */
struct mutex psp_devs_lock;
16 
/**
 * DOC: PSP locking
 *
 * psp_devs_lock protects the psp_devs xarray.
 * Lock ordering: take psp_devs_lock first, then the instance lock.
 * Each instance is protected by RCU, and has a refcount.
 * When the driver unregisters, the instance gets flushed, but the struct
 * sticks around until the last reference is gone.
 */
25 
26 /**
27  * psp_dev_check_access() - check if user in a given net ns can access PSP dev
28  * @psd:	PSP device structure user is trying to access
29  * @net:	net namespace user is in
30  * @admin:	If true, only allow access from @psd's main device's netns,
31  *		for admin operations like config changes and key rotation.
32  *		If false, also allow access from network namespaces that have
33  *		an associated device with @psd, for read-only and association
34  *		management operations.
35  *
36  * Return: 0 if PSP device should be visible in @net, errno otherwise.
37  */
38 int psp_dev_check_access(struct psp_dev *psd, struct net *net, bool admin)
39 {
40 	if (dev_net(psd->main_netdev) == net)
41 		return 0;
42 
43 	if (!admin && psp_has_assoc_dev_in_ns(psd, net))
44 		return 0;
45 
46 	return -ENOENT;
47 }
48 
49 /**
50  * psp_dev_create() - create and register PSP device
51  * @netdev:	main netdevice
52  * @psd_ops:	driver callbacks
53  * @psd_caps:	device capabilities
54  * @priv_ptr:	back-pointer to driver private data
55  *
56  * Return: pointer to allocated PSP device, or ERR_PTR.
57  */
58 struct psp_dev *
59 psp_dev_create(struct net_device *netdev,
60 	       struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps,
61 	       void *priv_ptr)
62 {
63 	struct psp_dev *psd;
64 	static u32 last_id;
65 	int err;
66 
67 	if (WARN_ON(!psd_caps->versions ||
68 		    !psd_ops->set_config ||
69 		    !psd_ops->key_rotate ||
70 		    !psd_ops->rx_spi_alloc ||
71 		    !psd_ops->tx_key_add ||
72 		    !psd_ops->tx_key_del ||
73 		    !psd_ops->get_stats))
74 		return ERR_PTR(-EINVAL);
75 
76 	psd = kzalloc_obj(*psd);
77 	if (!psd)
78 		return ERR_PTR(-ENOMEM);
79 
80 	psd->main_netdev = netdev;
81 	INIT_LIST_HEAD(&psd->assoc_dev_list);
82 	psd->ops = psd_ops;
83 	psd->caps = psd_caps;
84 	psd->drv_priv = priv_ptr;
85 
86 	mutex_init(&psd->lock);
87 	INIT_LIST_HEAD(&psd->active_assocs);
88 	INIT_LIST_HEAD(&psd->prev_assocs);
89 	INIT_LIST_HEAD(&psd->stale_assocs);
90 	refcount_set(&psd->refcnt, 1);
91 
92 	mutex_lock(&psp_devs_lock);
93 	err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b,
94 			      &last_id, GFP_KERNEL);
95 	if (err) {
96 		mutex_unlock(&psp_devs_lock);
97 		kfree(psd);
98 		return ERR_PTR(err);
99 	}
100 	mutex_lock(&psd->lock);
101 	mutex_unlock(&psp_devs_lock);
102 
103 	/* notify before netdev assignment
104 	 * There's no strong reason for it, but thinking is to avoid creating
105 	 * implicit expectations about the PSP dev <> netdev relationship.
106 	 */
107 	psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF);
108 
109 	rcu_assign_pointer(netdev->psp_dev, psd);
110 
111 	mutex_unlock(&psd->lock);
112 
113 	return psd;
114 }
115 EXPORT_SYMBOL(psp_dev_create);
116 
117 void psp_dev_free(struct psp_dev *psd)
118 {
119 	mutex_lock(&psp_devs_lock);
120 	xa_erase(&psp_devs, psd->id);
121 	mutex_unlock(&psp_devs_lock);
122 
123 	mutex_destroy(&psd->lock);
124 	kfree_rcu(psd, rcu);
125 }
126 
127 /**
128  * psp_dev_unregister() - unregister PSP device
129  * @psd:	PSP device structure
130  */
131 void psp_dev_unregister(struct psp_dev *psd)
132 {
133 	struct psp_assoc_dev *entry, *entry_tmp;
134 	struct psp_assoc *pas, *next;
135 
136 	mutex_lock(&psp_devs_lock);
137 	mutex_lock(&psd->lock);
138 
139 	psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF);
140 
141 	/* Wait until psp_dev_free() to call xa_erase() to prevent a
142 	 * different psd from being added to the xarray with this id, while
143 	 * there are still references to this psd being held.
144 	 */
145 	xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL);
146 	mutex_unlock(&psp_devs_lock);
147 
148 	list_splice_init(&psd->active_assocs, &psd->prev_assocs);
149 	list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
150 	list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list)
151 		psp_dev_tx_key_del(psd, pas);
152 
153 	list_for_each_entry_safe(entry, entry_tmp, &psd->assoc_dev_list,
154 				 dev_list) {
155 		list_del(&entry->dev_list);
156 		rcu_assign_pointer(entry->assoc_dev->psp_dev, NULL);
157 		netdev_put(entry->assoc_dev, &entry->dev_tracker);
158 		kfree(entry);
159 	}
160 	psd->assoc_dev_cnt = 0;
161 
162 	rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);
163 
164 	psd->ops = NULL;
165 	psd->drv_priv = NULL;
166 
167 	mutex_unlock(&psd->lock);
168 
169 	psp_dev_put(psd);
170 }
171 EXPORT_SYMBOL(psp_dev_unregister);
172 
173 unsigned int psp_key_size(u32 version)
174 {
175 	switch (version) {
176 	case PSP_VERSION_HDR0_AES_GCM_128:
177 	case PSP_VERSION_HDR0_AES_GMAC_128:
178 		return 16;
179 	case PSP_VERSION_HDR0_AES_GCM_256:
180 	case PSP_VERSION_HDR0_AES_GMAC_256:
181 		return 32;
182 	default:
183 		return 0;
184 	}
185 }
186 EXPORT_SYMBOL(psp_key_size);
187 
188 static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi,
189 			      u8 ver, unsigned int udp_len, __be16 sport)
190 {
191 	struct udphdr *uh = udp_hdr(skb);
192 	struct psphdr *psph = (struct psphdr *)(uh + 1);
193 	const struct sock *sk = skb->sk;
194 
195 	uh->dest = htons(PSP_DEFAULT_UDP_PORT);
196 
197 	/* A bit of theory: Selection of the source port.
198 	 *
199 	 * We need some entropy, so that multiple flows use different
200 	 * source ports for better RSS spreading at the receiver.
201 	 *
202 	 * We also need that all packets belonging to one TCP flow
203 	 * use the same source port through their duration,
204 	 * so that all these packets land in the same receive queue.
205 	 *
206 	 * udp_flow_src_port() is using sk_txhash, inherited from
207 	 * skb_set_hash_from_sk() call in __tcp_transmit_skb().
208 	 * This field is subject to reshuffling, thanks to
209 	 * sk_rethink_txhash() calls in various TCP functions.
210 	 *
211 	 * Instead, use sk->sk_hash which is constant through
212 	 * the whole flow duration.
213 	 */
214 	if (likely(sk)) {
215 		u32 hash = sk->sk_hash;
216 		int min, max;
217 
218 		/* These operations are cheap, no need to cache the result
219 		 * in another socket field.
220 		 */
221 		inet_get_local_port_range(net, &min, &max);
222 		/* Since this is being sent on the wire obfuscate hash a bit
223 		 * to minimize possibility that any useful information to an
224 		 * attacker is leaked. Only upper 16 bits are relevant in the
225 		 * computation for 16 bit port value because we use a
226 		 * reciprocal divide.
227 		 */
228 		hash ^= hash << 16;
229 		uh->source = htons(reciprocal_scale(hash, max - min + 1) + min);
230 	} else {
231 		uh->source = udp_flow_src_port(net, skb, 0, 0, false);
232 	}
233 	uh->check = 0;
234 	uh->len = htons(udp_len);
235 
236 	psph->nexthdr = IPPROTO_TCP;
237 	psph->hdrlen = PSP_HDRLEN_NOOPT;
238 	psph->crypt_offset = 0;
239 	psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) |
240 		      FIELD_PREP(PSPHDR_VERFL_ONE, 1);
241 	psph->spi = spi;
242 	memset(&psph->iv, 0, sizeof(psph->iv));
243 }
244 
245 /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling
246  * them in.
247  */
248 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
249 			 u8 ver, __be16 sport)
250 {
251 	u32 network_len = skb_network_header_len(skb);
252 	u32 ethr_len = skb_mac_header_len(skb);
253 	u32 bufflen = ethr_len + network_len;
254 
255 	if (skb->protocol != htons(ETH_P_IP) &&
256 	    skb->protocol != htons(ETH_P_IPV6))
257 		return false;
258 
259 	if (skb_cow_head(skb, PSP_ENCAP_HLEN))
260 		return false;
261 
262 	skb_push(skb, PSP_ENCAP_HLEN);
263 	skb->mac_header		-= PSP_ENCAP_HLEN;
264 	skb->network_header	-= PSP_ENCAP_HLEN;
265 	skb->transport_header	-= PSP_ENCAP_HLEN;
266 	memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen);
267 
268 	if (skb->protocol == htons(ETH_P_IP)) {
269 		ip_hdr(skb)->protocol = IPPROTO_UDP;
270 		be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN);
271 		ip_hdr(skb)->check = 0;
272 		ip_hdr(skb)->check =
273 			ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl);
274 	} else {
275 		ipv6_hdr(skb)->nexthdr = IPPROTO_UDP;
276 		be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN);
277 	}
278 
279 	skb_set_inner_ipproto(skb, IPPROTO_TCP);
280 	skb_set_inner_transport_header(skb, skb_transport_offset(skb) +
281 						    PSP_ENCAP_HLEN);
282 	skb->encapsulation = 1;
283 	psp_write_headers(net, skb, spi, ver,
284 			  skb->len - skb_transport_offset(skb), sport);
285 
286 	return true;
287 }
288 EXPORT_SYMBOL(psp_dev_encapsulate);
289 
290 /* Receive handler for PSP packets.
291  *
292  * Accepts only already-authenticated packets. The full PSP header is
293  * stripped according to psph->hdrlen; any optional fields it advertises
294  * (virtualization cookies, etc.) are ignored and discarded along with the
295  * rest of the header. The caller should ensure that skb->data is pointing
296  * to the mac header, and that skb->mac_len is set. This function does not
297  * currently adjust skb->csum (CHECKSUM_COMPLETE is not supported).
298  */
299 int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
300 {
301 	int l2_hlen = 0, l3_hlen, encap, psp_hlen;
302 	struct psp_skb_ext *pse;
303 	struct psphdr *psph;
304 	struct ethhdr *eth;
305 	struct udphdr *uh;
306 	__be16 proto;
307 	bool is_udp;
308 
309 	eth = (struct ethhdr *)skb->data;
310 	proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
311 	if (proto == htons(ETH_P_IP))
312 		l3_hlen = sizeof(struct iphdr);
313 	else if (proto == htons(ETH_P_IPV6))
314 		l3_hlen = sizeof(struct ipv6hdr);
315 	else
316 		return -EINVAL;
317 
318 	if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
319 		return -EINVAL;
320 
321 	if (proto == htons(ETH_P_IP)) {
322 		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
323 
324 		if (unlikely(iph->ihl < 5))
325 			return -EINVAL;
326 
327 		is_udp = iph->protocol == IPPROTO_UDP;
328 		l3_hlen = iph->ihl * 4;
329 		if (l3_hlen != sizeof(struct iphdr) &&
330 		    !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
331 			return -EINVAL;
332 	} else {
333 		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
334 
335 		is_udp = ipv6h->nexthdr == IPPROTO_UDP;
336 	}
337 
338 	if (unlikely(!is_udp))
339 		return -EINVAL;
340 
341 	uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
342 	if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
343 		return -EINVAL;
344 
345 	psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
346 				 sizeof(struct udphdr));
347 
348 	/* Strip the full PSP header per psph->hdrlen; VC/options are pulled
349 	 * into the linear region only so they can be discarded with the
350 	 * rest of the header.
351 	 */
352 	psp_hlen = (psph->hdrlen + 1) * 8;
353 
354 	if (unlikely(psp_hlen < sizeof(struct psphdr)))
355 		return -EINVAL;
356 
357 	if (psp_hlen > sizeof(struct psphdr) &&
358 	    !pskb_may_pull(skb, l2_hlen + l3_hlen +
359 				sizeof(struct udphdr) + psp_hlen))
360 		return -EINVAL;
361 
362 	psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
363 				 sizeof(struct udphdr));
364 
365 	pse = skb_ext_add(skb, SKB_EXT_PSP);
366 	if (!pse)
367 		return -EINVAL;
368 
369 	pse->spi = psph->spi;
370 	pse->dev_id = dev_id;
371 	pse->generation = generation;
372 	pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);
373 
374 	encap = sizeof(struct udphdr) + psp_hlen;
375 	encap += strip_icv ? PSP_TRL_SIZE : 0;
376 
377 	if (proto == htons(ETH_P_IP)) {
378 		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
379 
380 		if (unlikely(ntohs(iph->tot_len) < l3_hlen + encap))
381 			return -EINVAL;
382 
383 		iph->protocol = psph->nexthdr;
384 		iph->tot_len = htons(ntohs(iph->tot_len) - encap);
385 		iph->check = 0;
386 		iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
387 	} else {
388 		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
389 
390 		if (unlikely(ntohs(ipv6h->payload_len) < encap))
391 			return -EINVAL;
392 
393 		ipv6h->nexthdr = psph->nexthdr;
394 		ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
395 	}
396 
397 	memmove(skb->data + sizeof(struct udphdr) + psp_hlen,
398 		skb->data, l2_hlen + l3_hlen);
399 	skb_pull(skb, sizeof(struct udphdr) + psp_hlen);
400 
401 	if (strip_icv)
402 		pskb_trim(skb, skb->len - PSP_TRL_SIZE);
403 
404 	return 0;
405 }
406 EXPORT_SYMBOL(psp_dev_rcv);
407 
408 static void psp_dev_disassoc_one(struct psp_dev *psd, struct net_device *dev)
409 {
410 	struct psp_assoc_dev *entry;
411 
412 	list_for_each_entry(entry, &psd->assoc_dev_list, dev_list) {
413 		if (entry->assoc_dev == dev) {
414 			list_del(&entry->dev_list);
415 			psd->assoc_dev_cnt--;
416 			rcu_assign_pointer(entry->assoc_dev->psp_dev, NULL);
417 			netdev_put(entry->assoc_dev, &entry->dev_tracker);
418 			kfree(entry);
419 			return;
420 		}
421 	}
422 }
423 
424 static int psp_netdev_event(struct notifier_block *nb, unsigned long event,
425 			    void *ptr)
426 {
427 	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
428 	struct psp_dev *psd;
429 
430 	if (event != NETDEV_UNREGISTER)
431 		return NOTIFY_DONE;
432 
433 	rcu_read_lock();
434 	psd = rcu_dereference(dev->psp_dev);
435 	if (psd && psp_dev_tryget(psd)) {
436 		rcu_read_unlock();
437 		mutex_lock(&psd->lock);
438 		if (psp_dev_is_registered(psd))
439 			psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
440 		psp_dev_disassoc_one(psd, dev);
441 		mutex_unlock(&psd->lock);
442 		psp_dev_put(psd);
443 	} else {
444 		rcu_read_unlock();
445 	}
446 
447 	return NOTIFY_DONE;
448 }
449 
450 static struct notifier_block psp_netdev_notifier = {
451 	.notifier_call = psp_netdev_event,
452 };
453 
454 static DEFINE_MUTEX(psp_notifier_lock);
455 static bool psp_notifier_registered;
456 
457 /* Register the netdevice notifier when the first device association
458  * is created. In many installations no associations will be created and
459  * the notifier won't be needed.
460  *
461  * Must be called without psd->lock held, due to lock ordering:
462  * rtnl_lock -> psd->lock (the notifier callback runs under rtnl_lock
463  * and takes psd->lock).
464  */
465 int psp_attach_netdev_notifier(void)
466 {
467 	int err = 0;
468 
469 	if (READ_ONCE(psp_notifier_registered))
470 		return 0;
471 
472 	mutex_lock(&psp_notifier_lock);
473 	if (!psp_notifier_registered) {
474 		err = register_netdevice_notifier(&psp_netdev_notifier);
475 		if (!err)
476 			WRITE_ONCE(psp_notifier_registered, true);
477 	}
478 	mutex_unlock(&psp_notifier_lock);
479 
480 	return err;
481 }
482 
483 static int __init psp_init(void)
484 {
485 	mutex_init(&psp_devs_lock);
486 
487 	return genl_register_family(&psp_nl_family);
488 }
489 
490 subsys_initcall(psp_init);
491