xref: /linux/net/psp/psp_sock.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
16b46ca26SJakub Kicinski // SPDX-License-Identifier: GPL-2.0-only
26b46ca26SJakub Kicinski 
36b46ca26SJakub Kicinski #include <linux/file.h>
46b46ca26SJakub Kicinski #include <linux/net.h>
56b46ca26SJakub Kicinski #include <linux/rcupdate.h>
66b46ca26SJakub Kicinski #include <linux/tcp.h>
76b46ca26SJakub Kicinski 
86b46ca26SJakub Kicinski #include <net/ip.h>
96b46ca26SJakub Kicinski #include <net/psp.h>
106b46ca26SJakub Kicinski #include "psp.h"
116b46ca26SJakub Kicinski 
126b46ca26SJakub Kicinski struct psp_dev *psp_dev_get_for_sock(struct sock *sk)
136b46ca26SJakub Kicinski {
1417f1b771SEric Dumazet 	struct psp_dev *psd = NULL;
156b46ca26SJakub Kicinski 	struct dst_entry *dst;
166b46ca26SJakub Kicinski 
176b46ca26SJakub Kicinski 	rcu_read_lock();
1817f1b771SEric Dumazet 	dst = __sk_dst_get(sk);
1917f1b771SEric Dumazet 	if (dst) {
2017f1b771SEric Dumazet 		psd = rcu_dereference(dst_dev_rcu(dst)->psp_dev);
216b46ca26SJakub Kicinski 		if (psd && !psp_dev_tryget(psd))
226b46ca26SJakub Kicinski 			psd = NULL;
2317f1b771SEric Dumazet 	}
246b46ca26SJakub Kicinski 	rcu_read_unlock();
256b46ca26SJakub Kicinski 
266b46ca26SJakub Kicinski 	return psd;
276b46ca26SJakub Kicinski }
286b46ca26SJakub Kicinski 
296b46ca26SJakub Kicinski static struct sk_buff *
306b46ca26SJakub Kicinski psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
316b46ca26SJakub Kicinski {
326b46ca26SJakub Kicinski 	struct psp_assoc *pas;
336b46ca26SJakub Kicinski 	bool good;
346b46ca26SJakub Kicinski 
356b46ca26SJakub Kicinski 	rcu_read_lock();
366b46ca26SJakub Kicinski 	pas = psp_skb_get_assoc_rcu(skb);
376b46ca26SJakub Kicinski 	good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd;
386b46ca26SJakub Kicinski 	rcu_read_unlock();
396b46ca26SJakub Kicinski 	if (!good) {
40*b02c1230SEric Dumazet 		sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_PSP_OUTPUT);
416b46ca26SJakub Kicinski 		return NULL;
426b46ca26SJakub Kicinski 	}
436b46ca26SJakub Kicinski 
446b46ca26SJakub Kicinski 	return skb;
456b46ca26SJakub Kicinski }
466b46ca26SJakub Kicinski 
476b46ca26SJakub Kicinski struct psp_assoc *psp_assoc_create(struct psp_dev *psd)
486b46ca26SJakub Kicinski {
496b46ca26SJakub Kicinski 	struct psp_assoc *pas;
506b46ca26SJakub Kicinski 
516b46ca26SJakub Kicinski 	lockdep_assert_held(&psd->lock);
526b46ca26SJakub Kicinski 
536b46ca26SJakub Kicinski 	pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc),
546b46ca26SJakub Kicinski 		      GFP_KERNEL_ACCOUNT);
556b46ca26SJakub Kicinski 	if (!pas)
566b46ca26SJakub Kicinski 		return NULL;
576b46ca26SJakub Kicinski 
586b46ca26SJakub Kicinski 	pas->psd = psd;
596b46ca26SJakub Kicinski 	pas->dev_id = psd->id;
60e7885105SJakub Kicinski 	pas->generation = psd->generation;
616b46ca26SJakub Kicinski 	psp_dev_get(psd);
626b46ca26SJakub Kicinski 	refcount_set(&pas->refcnt, 1);
636b46ca26SJakub Kicinski 
646b46ca26SJakub Kicinski 	list_add_tail(&pas->assocs_list, &psd->active_assocs);
656b46ca26SJakub Kicinski 
666b46ca26SJakub Kicinski 	return pas;
676b46ca26SJakub Kicinski }
686b46ca26SJakub Kicinski 
696b46ca26SJakub Kicinski static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas)
706b46ca26SJakub Kicinski {
716b46ca26SJakub Kicinski 	struct psp_dev *psd = pas->psd;
726b46ca26SJakub Kicinski 	size_t sz;
736b46ca26SJakub Kicinski 
746b46ca26SJakub Kicinski 	lockdep_assert_held(&psd->lock);
756b46ca26SJakub Kicinski 
766b46ca26SJakub Kicinski 	sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc);
776b46ca26SJakub Kicinski 	return kmemdup(pas, sz, GFP_KERNEL);
786b46ca26SJakub Kicinski }
796b46ca26SJakub Kicinski 
806b46ca26SJakub Kicinski static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas,
816b46ca26SJakub Kicinski 			      struct netlink_ext_ack *extack)
826b46ca26SJakub Kicinski {
836b46ca26SJakub Kicinski 	return psd->ops->tx_key_add(psd, pas, extack);
846b46ca26SJakub Kicinski }
856b46ca26SJakub Kicinski 
866b46ca26SJakub Kicinski void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas)
876b46ca26SJakub Kicinski {
886b46ca26SJakub Kicinski 	if (pas->tx.spi)
896b46ca26SJakub Kicinski 		psd->ops->tx_key_del(psd, pas);
906b46ca26SJakub Kicinski 	list_del(&pas->assocs_list);
916b46ca26SJakub Kicinski }
926b46ca26SJakub Kicinski 
936b46ca26SJakub Kicinski static void psp_assoc_free(struct work_struct *work)
946b46ca26SJakub Kicinski {
956b46ca26SJakub Kicinski 	struct psp_assoc *pas = container_of(work, struct psp_assoc, work);
966b46ca26SJakub Kicinski 	struct psp_dev *psd = pas->psd;
976b46ca26SJakub Kicinski 
986b46ca26SJakub Kicinski 	mutex_lock(&psd->lock);
996b46ca26SJakub Kicinski 	if (psd->ops)
1006b46ca26SJakub Kicinski 		psp_dev_tx_key_del(psd, pas);
1016b46ca26SJakub Kicinski 	mutex_unlock(&psd->lock);
1026b46ca26SJakub Kicinski 	psp_dev_put(psd);
1036b46ca26SJakub Kicinski 	kfree(pas);
1046b46ca26SJakub Kicinski }
1056b46ca26SJakub Kicinski 
1066b46ca26SJakub Kicinski static void psp_assoc_free_queue(struct rcu_head *head)
1076b46ca26SJakub Kicinski {
1086b46ca26SJakub Kicinski 	struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu);
1096b46ca26SJakub Kicinski 
1106b46ca26SJakub Kicinski 	INIT_WORK(&pas->work, psp_assoc_free);
1116b46ca26SJakub Kicinski 	schedule_work(&pas->work);
1126b46ca26SJakub Kicinski }
1136b46ca26SJakub Kicinski 
1146b46ca26SJakub Kicinski /**
1156b46ca26SJakub Kicinski  * psp_assoc_put() - release a reference on a PSP association
1166b46ca26SJakub Kicinski  * @pas: association to release
1176b46ca26SJakub Kicinski  */
1186b46ca26SJakub Kicinski void psp_assoc_put(struct psp_assoc *pas)
1196b46ca26SJakub Kicinski {
1206b46ca26SJakub Kicinski 	if (pas && refcount_dec_and_test(&pas->refcnt))
1216b46ca26SJakub Kicinski 		call_rcu(&pas->rcu, psp_assoc_free_queue);
1226b46ca26SJakub Kicinski }
1236b46ca26SJakub Kicinski 
1246b46ca26SJakub Kicinski void psp_sk_assoc_free(struct sock *sk)
1256b46ca26SJakub Kicinski {
1266b46ca26SJakub Kicinski 	struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1);
1276b46ca26SJakub Kicinski 
1286b46ca26SJakub Kicinski 	rcu_assign_pointer(sk->psp_assoc, NULL);
1296b46ca26SJakub Kicinski 	psp_assoc_put(pas);
1306b46ca26SJakub Kicinski }
1316b46ca26SJakub Kicinski 
1326b46ca26SJakub Kicinski int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
1336b46ca26SJakub Kicinski 			  struct psp_key_parsed *key,
1346b46ca26SJakub Kicinski 			  struct netlink_ext_ack *extack)
1356b46ca26SJakub Kicinski {
1366b46ca26SJakub Kicinski 	int err;
1376b46ca26SJakub Kicinski 
1386b46ca26SJakub Kicinski 	memcpy(&pas->rx, key, sizeof(*key));
1396b46ca26SJakub Kicinski 
1406b46ca26SJakub Kicinski 	lock_sock(sk);
1416b46ca26SJakub Kicinski 
1426b46ca26SJakub Kicinski 	if (psp_sk_assoc(sk)) {
1436b46ca26SJakub Kicinski 		NL_SET_ERR_MSG(extack, "Socket already has PSP state");
1446b46ca26SJakub Kicinski 		err = -EBUSY;
1456b46ca26SJakub Kicinski 		goto exit_unlock;
1466b46ca26SJakub Kicinski 	}
1476b46ca26SJakub Kicinski 
1486b46ca26SJakub Kicinski 	refcount_inc(&pas->refcnt);
1496b46ca26SJakub Kicinski 	rcu_assign_pointer(sk->psp_assoc, pas);
1506b46ca26SJakub Kicinski 	err = 0;
1516b46ca26SJakub Kicinski 
1526b46ca26SJakub Kicinski exit_unlock:
1536b46ca26SJakub Kicinski 	release_sock(sk);
1546b46ca26SJakub Kicinski 
1556b46ca26SJakub Kicinski 	return err;
1566b46ca26SJakub Kicinski }
1576b46ca26SJakub Kicinski 
1586b46ca26SJakub Kicinski static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas)
1596b46ca26SJakub Kicinski {
1606b46ca26SJakub Kicinski 	struct psp_skb_ext *pse;
1616b46ca26SJakub Kicinski 	struct sk_buff *skb;
1626b46ca26SJakub Kicinski 
1636b46ca26SJakub Kicinski 	skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) {
1646b46ca26SJakub Kicinski 		pse = skb_ext_find(skb, SKB_EXT_PSP);
1656b46ca26SJakub Kicinski 		if (!psp_pse_matches_pas(pse, pas))
1666b46ca26SJakub Kicinski 			return -EBUSY;
1676b46ca26SJakub Kicinski 	}
1686b46ca26SJakub Kicinski 
1696b46ca26SJakub Kicinski 	skb_queue_walk(&sk->sk_receive_queue, skb) {
1706b46ca26SJakub Kicinski 		pse = skb_ext_find(skb, SKB_EXT_PSP);
1716b46ca26SJakub Kicinski 		if (!psp_pse_matches_pas(pse, pas))
1726b46ca26SJakub Kicinski 			return -EBUSY;
1736b46ca26SJakub Kicinski 	}
1746b46ca26SJakub Kicinski 	return 0;
1756b46ca26SJakub Kicinski }
1766b46ca26SJakub Kicinski 
1776b46ca26SJakub Kicinski int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
1786b46ca26SJakub Kicinski 			  u32 version, struct psp_key_parsed *key,
1796b46ca26SJakub Kicinski 			  struct netlink_ext_ack *extack)
1806b46ca26SJakub Kicinski {
181e9726925SJakub Kicinski 	struct inet_connection_sock *icsk;
1826b46ca26SJakub Kicinski 	struct psp_assoc *pas, *dummy;
1836b46ca26SJakub Kicinski 	int err;
1846b46ca26SJakub Kicinski 
1856b46ca26SJakub Kicinski 	lock_sock(sk);
1866b46ca26SJakub Kicinski 
1876b46ca26SJakub Kicinski 	pas = psp_sk_assoc(sk);
1886b46ca26SJakub Kicinski 	if (!pas) {
1896b46ca26SJakub Kicinski 		NL_SET_ERR_MSG(extack, "Socket has no Rx key");
1906b46ca26SJakub Kicinski 		err = -EINVAL;
1916b46ca26SJakub Kicinski 		goto exit_unlock;
1926b46ca26SJakub Kicinski 	}
1936b46ca26SJakub Kicinski 	if (pas->psd != psd) {
1946b46ca26SJakub Kicinski 		NL_SET_ERR_MSG(extack, "Rx key from different device");
1956b46ca26SJakub Kicinski 		err = -EINVAL;
1966b46ca26SJakub Kicinski 		goto exit_unlock;
1976b46ca26SJakub Kicinski 	}
1986b46ca26SJakub Kicinski 	if (pas->version != version) {
1996b46ca26SJakub Kicinski 		NL_SET_ERR_MSG(extack,
2006b46ca26SJakub Kicinski 			       "PSP version mismatch with existing state");
2016b46ca26SJakub Kicinski 		err = -EINVAL;
2026b46ca26SJakub Kicinski 		goto exit_unlock;
2036b46ca26SJakub Kicinski 	}
2046b46ca26SJakub Kicinski 	if (pas->tx.spi) {
2056b46ca26SJakub Kicinski 		NL_SET_ERR_MSG(extack, "Tx key already set");
2066b46ca26SJakub Kicinski 		err = -EBUSY;
2076b46ca26SJakub Kicinski 		goto exit_unlock;
2086b46ca26SJakub Kicinski 	}
2096b46ca26SJakub Kicinski 
2106b46ca26SJakub Kicinski 	err = psp_sock_recv_queue_check(sk, pas);
2116b46ca26SJakub Kicinski 	if (err) {
2126b46ca26SJakub Kicinski 		NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue");
2136b46ca26SJakub Kicinski 		goto exit_unlock;
2146b46ca26SJakub Kicinski 	}
2156b46ca26SJakub Kicinski 
2166b46ca26SJakub Kicinski 	/* Pass a fake association to drivers to make sure they don't
2176b46ca26SJakub Kicinski 	 * try to store pointers to it. For re-keying we'll need to
2186b46ca26SJakub Kicinski 	 * re-allocate the assoc structures.
2196b46ca26SJakub Kicinski 	 */
2206b46ca26SJakub Kicinski 	dummy = psp_assoc_dummy(pas);
2216b46ca26SJakub Kicinski 	if (!dummy) {
2226b46ca26SJakub Kicinski 		err = -ENOMEM;
2236b46ca26SJakub Kicinski 		goto exit_unlock;
2246b46ca26SJakub Kicinski 	}
2256b46ca26SJakub Kicinski 
2266b46ca26SJakub Kicinski 	memcpy(&dummy->tx, key, sizeof(*key));
2276b46ca26SJakub Kicinski 	err = psp_dev_tx_key_add(psd, dummy, extack);
2286b46ca26SJakub Kicinski 	if (err)
2296b46ca26SJakub Kicinski 		goto exit_free_dummy;
2306b46ca26SJakub Kicinski 
2316b46ca26SJakub Kicinski 	memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc);
2326b46ca26SJakub Kicinski 	memcpy(&pas->tx, key, sizeof(*key));
2336b46ca26SJakub Kicinski 
2346b46ca26SJakub Kicinski 	WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit);
2356b46ca26SJakub Kicinski 	tcp_write_collapse_fence(sk);
2366b46ca26SJakub Kicinski 	pas->upgrade_seq = tcp_sk(sk)->rcv_nxt;
2376b46ca26SJakub Kicinski 
238e9726925SJakub Kicinski 	icsk = inet_csk(sk);
239e9726925SJakub Kicinski 	icsk->icsk_ext_hdr_len += psp_sk_overhead(sk);
240e9726925SJakub Kicinski 	icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie);
241e9726925SJakub Kicinski 
2426b46ca26SJakub Kicinski exit_free_dummy:
2436b46ca26SJakub Kicinski 	kfree(dummy);
2446b46ca26SJakub Kicinski exit_unlock:
2456b46ca26SJakub Kicinski 	release_sock(sk);
2466b46ca26SJakub Kicinski 	return err;
2476b46ca26SJakub Kicinski }
2486b46ca26SJakub Kicinski 
249e7885105SJakub Kicinski void psp_assocs_key_rotated(struct psp_dev *psd)
250e7885105SJakub Kicinski {
251e7885105SJakub Kicinski 	struct psp_assoc *pas, *next;
252e7885105SJakub Kicinski 
253e7885105SJakub Kicinski 	/* Mark the stale associations as invalid, they will no longer
254e7885105SJakub Kicinski 	 * be able to Rx any traffic.
255e7885105SJakub Kicinski 	 */
256e7885105SJakub Kicinski 	list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list)
257e7885105SJakub Kicinski 		pas->generation |= ~PSP_GEN_VALID_MASK;
258e7885105SJakub Kicinski 	list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
259e7885105SJakub Kicinski 	list_splice_init(&psd->active_assocs, &psd->prev_assocs);
260e7885105SJakub Kicinski 
261e7885105SJakub Kicinski 	/* TODO: we should inform the sockets that got shut down */
262e7885105SJakub Kicinski }
263e7885105SJakub Kicinski 
2646b46ca26SJakub Kicinski void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
2656b46ca26SJakub Kicinski {
2666b46ca26SJakub Kicinski 	struct psp_assoc *pas = psp_sk_assoc(sk);
2676b46ca26SJakub Kicinski 
2686b46ca26SJakub Kicinski 	if (pas)
2696b46ca26SJakub Kicinski 		refcount_inc(&pas->refcnt);
2706b46ca26SJakub Kicinski 	rcu_assign_pointer(tw->psp_assoc, pas);
2716b46ca26SJakub Kicinski 	tw->tw_validate_xmit_skb = psp_validate_xmit;
2726b46ca26SJakub Kicinski }
2736b46ca26SJakub Kicinski 
2746b46ca26SJakub Kicinski void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
2756b46ca26SJakub Kicinski {
2766b46ca26SJakub Kicinski 	struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1);
2776b46ca26SJakub Kicinski 
2786b46ca26SJakub Kicinski 	rcu_assign_pointer(tw->psp_assoc, NULL);
2796b46ca26SJakub Kicinski 	psp_assoc_put(pas);
2806b46ca26SJakub Kicinski }
2816b46ca26SJakub Kicinski 
2826b46ca26SJakub Kicinski void psp_reply_set_decrypted(struct sk_buff *skb)
2836b46ca26SJakub Kicinski {
2846b46ca26SJakub Kicinski 	struct psp_assoc *pas;
2856b46ca26SJakub Kicinski 
2866b46ca26SJakub Kicinski 	rcu_read_lock();
2876b46ca26SJakub Kicinski 	pas = psp_sk_get_assoc_rcu(skb->sk);
2886b46ca26SJakub Kicinski 	if (pas && pas->tx.spi)
2896b46ca26SJakub Kicinski 		skb->decrypted = 1;
2906b46ca26SJakub Kicinski 	rcu_read_unlock();
2916b46ca26SJakub Kicinski }
2926b46ca26SJakub Kicinski EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);
293