16b46ca26SJakub Kicinski // SPDX-License-Identifier: GPL-2.0-only 26b46ca26SJakub Kicinski 36b46ca26SJakub Kicinski #include <linux/file.h> 46b46ca26SJakub Kicinski #include <linux/net.h> 56b46ca26SJakub Kicinski #include <linux/rcupdate.h> 66b46ca26SJakub Kicinski #include <linux/tcp.h> 76b46ca26SJakub Kicinski 86b46ca26SJakub Kicinski #include <net/ip.h> 96b46ca26SJakub Kicinski #include <net/psp.h> 106b46ca26SJakub Kicinski #include "psp.h" 116b46ca26SJakub Kicinski 126b46ca26SJakub Kicinski struct psp_dev *psp_dev_get_for_sock(struct sock *sk) 136b46ca26SJakub Kicinski { 1417f1b771SEric Dumazet struct psp_dev *psd = NULL; 156b46ca26SJakub Kicinski struct dst_entry *dst; 166b46ca26SJakub Kicinski 176b46ca26SJakub Kicinski rcu_read_lock(); 1817f1b771SEric Dumazet dst = __sk_dst_get(sk); 1917f1b771SEric Dumazet if (dst) { 2017f1b771SEric Dumazet psd = rcu_dereference(dst_dev_rcu(dst)->psp_dev); 216b46ca26SJakub Kicinski if (psd && !psp_dev_tryget(psd)) 226b46ca26SJakub Kicinski psd = NULL; 2317f1b771SEric Dumazet } 246b46ca26SJakub Kicinski rcu_read_unlock(); 256b46ca26SJakub Kicinski 266b46ca26SJakub Kicinski return psd; 276b46ca26SJakub Kicinski } 286b46ca26SJakub Kicinski 296b46ca26SJakub Kicinski static struct sk_buff * 306b46ca26SJakub Kicinski psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb) 316b46ca26SJakub Kicinski { 326b46ca26SJakub Kicinski struct psp_assoc *pas; 336b46ca26SJakub Kicinski bool good; 346b46ca26SJakub Kicinski 356b46ca26SJakub Kicinski rcu_read_lock(); 366b46ca26SJakub Kicinski pas = psp_skb_get_assoc_rcu(skb); 376b46ca26SJakub Kicinski good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd; 386b46ca26SJakub Kicinski rcu_read_unlock(); 396b46ca26SJakub Kicinski if (!good) { 40*b02c1230SEric Dumazet sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_PSP_OUTPUT); 416b46ca26SJakub Kicinski return NULL; 426b46ca26SJakub Kicinski } 436b46ca26SJakub Kicinski 446b46ca26SJakub Kicinski return skb; 456b46ca26SJakub Kicinski } 466b46ca26SJakub Kicinski 476b46ca26SJakub Kicinski struct psp_assoc *psp_assoc_create(struct psp_dev *psd) 486b46ca26SJakub Kicinski { 496b46ca26SJakub Kicinski struct psp_assoc *pas; 506b46ca26SJakub Kicinski 516b46ca26SJakub Kicinski lockdep_assert_held(&psd->lock); 526b46ca26SJakub Kicinski 536b46ca26SJakub Kicinski pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc), 546b46ca26SJakub Kicinski GFP_KERNEL_ACCOUNT); 556b46ca26SJakub Kicinski if (!pas) 566b46ca26SJakub Kicinski return NULL; 576b46ca26SJakub Kicinski 586b46ca26SJakub Kicinski pas->psd = psd; 596b46ca26SJakub Kicinski pas->dev_id = psd->id; 60e7885105SJakub Kicinski pas->generation = psd->generation; 616b46ca26SJakub Kicinski psp_dev_get(psd); 626b46ca26SJakub Kicinski refcount_set(&pas->refcnt, 1); 636b46ca26SJakub Kicinski 646b46ca26SJakub Kicinski list_add_tail(&pas->assocs_list, &psd->active_assocs); 656b46ca26SJakub Kicinski 666b46ca26SJakub Kicinski return pas; 676b46ca26SJakub Kicinski } 686b46ca26SJakub Kicinski 696b46ca26SJakub Kicinski static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas) 706b46ca26SJakub Kicinski { 716b46ca26SJakub Kicinski struct psp_dev *psd = pas->psd; 726b46ca26SJakub Kicinski size_t sz; 736b46ca26SJakub Kicinski 746b46ca26SJakub Kicinski lockdep_assert_held(&psd->lock); 756b46ca26SJakub Kicinski 766b46ca26SJakub Kicinski sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc); 776b46ca26SJakub Kicinski return kmemdup(pas, sz, GFP_KERNEL); 786b46ca26SJakub Kicinski } 796b46ca26SJakub Kicinski 806b46ca26SJakub Kicinski static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas, 816b46ca26SJakub Kicinski struct netlink_ext_ack *extack) 826b46ca26SJakub Kicinski { 836b46ca26SJakub Kicinski return psd->ops->tx_key_add(psd, pas, extack); 846b46ca26SJakub Kicinski } 856b46ca26SJakub Kicinski 866b46ca26SJakub Kicinski void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas) 876b46ca26SJakub Kicinski { 886b46ca26SJakub Kicinski if (pas->tx.spi) 896b46ca26SJakub Kicinski psd->ops->tx_key_del(psd, pas); 906b46ca26SJakub Kicinski list_del(&pas->assocs_list); 916b46ca26SJakub Kicinski } 926b46ca26SJakub Kicinski 936b46ca26SJakub Kicinski static void psp_assoc_free(struct work_struct *work) 946b46ca26SJakub Kicinski { 956b46ca26SJakub Kicinski struct psp_assoc *pas = container_of(work, struct psp_assoc, work); 966b46ca26SJakub Kicinski struct psp_dev *psd = pas->psd; 976b46ca26SJakub Kicinski 986b46ca26SJakub Kicinski mutex_lock(&psd->lock); 996b46ca26SJakub Kicinski if (psd->ops) 1006b46ca26SJakub Kicinski psp_dev_tx_key_del(psd, pas); 1016b46ca26SJakub Kicinski mutex_unlock(&psd->lock); 1026b46ca26SJakub Kicinski psp_dev_put(psd); 1036b46ca26SJakub Kicinski kfree(pas); 1046b46ca26SJakub Kicinski } 1056b46ca26SJakub Kicinski 1066b46ca26SJakub Kicinski static void psp_assoc_free_queue(struct rcu_head *head) 1076b46ca26SJakub Kicinski { 1086b46ca26SJakub Kicinski struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu); 1096b46ca26SJakub Kicinski 1106b46ca26SJakub Kicinski INIT_WORK(&pas->work, psp_assoc_free); 1116b46ca26SJakub Kicinski schedule_work(&pas->work); 1126b46ca26SJakub Kicinski } 1136b46ca26SJakub Kicinski 1146b46ca26SJakub Kicinski /** 1156b46ca26SJakub Kicinski * psp_assoc_put() - release a reference on a PSP association 1166b46ca26SJakub Kicinski * @pas: association to release 1176b46ca26SJakub Kicinski */ 1186b46ca26SJakub Kicinski void psp_assoc_put(struct psp_assoc *pas) 1196b46ca26SJakub Kicinski { 1206b46ca26SJakub Kicinski if (pas && refcount_dec_and_test(&pas->refcnt)) 1216b46ca26SJakub Kicinski call_rcu(&pas->rcu, psp_assoc_free_queue); 1226b46ca26SJakub Kicinski } 1236b46ca26SJakub Kicinski 1246b46ca26SJakub Kicinski void psp_sk_assoc_free(struct sock *sk) 1256b46ca26SJakub Kicinski { 1266b46ca26SJakub Kicinski struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1); 1276b46ca26SJakub Kicinski 1286b46ca26SJakub Kicinski rcu_assign_pointer(sk->psp_assoc, NULL); 1296b46ca26SJakub Kicinski psp_assoc_put(pas); 1306b46ca26SJakub Kicinski } 1316b46ca26SJakub Kicinski 1326b46ca26SJakub Kicinski int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, 1336b46ca26SJakub Kicinski struct psp_key_parsed *key, 1346b46ca26SJakub Kicinski struct netlink_ext_ack *extack) 1356b46ca26SJakub Kicinski { 1366b46ca26SJakub Kicinski int err; 1376b46ca26SJakub Kicinski 1386b46ca26SJakub Kicinski memcpy(&pas->rx, key, sizeof(*key)); 1396b46ca26SJakub Kicinski 1406b46ca26SJakub Kicinski lock_sock(sk); 1416b46ca26SJakub Kicinski 1426b46ca26SJakub Kicinski if (psp_sk_assoc(sk)) { 1436b46ca26SJakub Kicinski NL_SET_ERR_MSG(extack, "Socket already has PSP state"); 1446b46ca26SJakub Kicinski err = -EBUSY; 1456b46ca26SJakub Kicinski goto exit_unlock; 1466b46ca26SJakub Kicinski } 1476b46ca26SJakub Kicinski 1486b46ca26SJakub Kicinski refcount_inc(&pas->refcnt); 1496b46ca26SJakub Kicinski rcu_assign_pointer(sk->psp_assoc, pas); 1506b46ca26SJakub Kicinski err = 0; 1516b46ca26SJakub Kicinski 1526b46ca26SJakub Kicinski exit_unlock: 1536b46ca26SJakub Kicinski release_sock(sk); 1546b46ca26SJakub Kicinski 1556b46ca26SJakub Kicinski return err; 1566b46ca26SJakub Kicinski } 1576b46ca26SJakub Kicinski 1586b46ca26SJakub Kicinski static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas) 1596b46ca26SJakub Kicinski { 1606b46ca26SJakub Kicinski struct psp_skb_ext *pse; 1616b46ca26SJakub Kicinski struct sk_buff *skb; 1626b46ca26SJakub Kicinski 1636b46ca26SJakub Kicinski skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) { 1646b46ca26SJakub Kicinski pse = skb_ext_find(skb, SKB_EXT_PSP); 1656b46ca26SJakub Kicinski if (!psp_pse_matches_pas(pse, pas)) 1666b46ca26SJakub Kicinski return -EBUSY; 1676b46ca26SJakub Kicinski } 1686b46ca26SJakub Kicinski 1696b46ca26SJakub Kicinski skb_queue_walk(&sk->sk_receive_queue, skb) { 1706b46ca26SJakub Kicinski pse = skb_ext_find(skb, SKB_EXT_PSP); 1716b46ca26SJakub Kicinski if (!psp_pse_matches_pas(pse, pas)) 1726b46ca26SJakub Kicinski return -EBUSY; 1736b46ca26SJakub Kicinski } 1746b46ca26SJakub Kicinski return 0; 1756b46ca26SJakub Kicinski } 1766b46ca26SJakub Kicinski 1776b46ca26SJakub Kicinski int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, 1786b46ca26SJakub Kicinski u32 version, struct psp_key_parsed *key, 1796b46ca26SJakub Kicinski struct netlink_ext_ack *extack) 1806b46ca26SJakub Kicinski { 181e9726925SJakub Kicinski struct inet_connection_sock *icsk; 1826b46ca26SJakub Kicinski struct psp_assoc *pas, *dummy; 1836b46ca26SJakub Kicinski int err; 1846b46ca26SJakub Kicinski 1856b46ca26SJakub Kicinski lock_sock(sk); 1866b46ca26SJakub Kicinski 1876b46ca26SJakub Kicinski pas = psp_sk_assoc(sk); 1886b46ca26SJakub Kicinski if (!pas) { 1896b46ca26SJakub Kicinski NL_SET_ERR_MSG(extack, "Socket has no Rx key"); 1906b46ca26SJakub Kicinski err = -EINVAL; 1916b46ca26SJakub Kicinski goto exit_unlock; 1926b46ca26SJakub Kicinski } 1936b46ca26SJakub Kicinski if (pas->psd != psd) { 1946b46ca26SJakub Kicinski NL_SET_ERR_MSG(extack, "Rx key from different device"); 1956b46ca26SJakub Kicinski err = -EINVAL; 1966b46ca26SJakub Kicinski goto exit_unlock; 1976b46ca26SJakub Kicinski } 1986b46ca26SJakub Kicinski if (pas->version != version) { 1996b46ca26SJakub Kicinski NL_SET_ERR_MSG(extack, 2006b46ca26SJakub Kicinski "PSP version mismatch with existing state"); 2016b46ca26SJakub Kicinski err = -EINVAL; 2026b46ca26SJakub Kicinski goto exit_unlock; 2036b46ca26SJakub Kicinski } 2046b46ca26SJakub Kicinski if (pas->tx.spi) { 2056b46ca26SJakub Kicinski NL_SET_ERR_MSG(extack, "Tx key already set"); 2066b46ca26SJakub Kicinski err = -EBUSY; 2076b46ca26SJakub Kicinski goto exit_unlock; 2086b46ca26SJakub Kicinski } 2096b46ca26SJakub Kicinski 2106b46ca26SJakub Kicinski err = psp_sock_recv_queue_check(sk, pas); 2116b46ca26SJakub Kicinski if (err) { 2126b46ca26SJakub Kicinski NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue"); 2136b46ca26SJakub Kicinski goto exit_unlock; 2146b46ca26SJakub Kicinski } 2156b46ca26SJakub Kicinski 2166b46ca26SJakub Kicinski /* Pass a fake association to drivers to make sure they don't 2176b46ca26SJakub Kicinski * try to store pointers to it. For re-keying we'll need to 2186b46ca26SJakub Kicinski * re-allocate the assoc structures. 2196b46ca26SJakub Kicinski */ 2206b46ca26SJakub Kicinski dummy = psp_assoc_dummy(pas); 2216b46ca26SJakub Kicinski if (!dummy) { 2226b46ca26SJakub Kicinski err = -ENOMEM; 2236b46ca26SJakub Kicinski goto exit_unlock; 2246b46ca26SJakub Kicinski } 2256b46ca26SJakub Kicinski 2266b46ca26SJakub Kicinski memcpy(&dummy->tx, key, sizeof(*key)); 2276b46ca26SJakub Kicinski err = psp_dev_tx_key_add(psd, dummy, extack); 2286b46ca26SJakub Kicinski if (err) 2296b46ca26SJakub Kicinski goto exit_free_dummy; 2306b46ca26SJakub Kicinski 2316b46ca26SJakub Kicinski memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc); 2326b46ca26SJakub Kicinski memcpy(&pas->tx, key, sizeof(*key)); 2336b46ca26SJakub Kicinski 2346b46ca26SJakub Kicinski WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit); 2356b46ca26SJakub Kicinski tcp_write_collapse_fence(sk); 2366b46ca26SJakub Kicinski pas->upgrade_seq = tcp_sk(sk)->rcv_nxt; 2376b46ca26SJakub Kicinski 238e9726925SJakub Kicinski icsk = inet_csk(sk); 239e9726925SJakub Kicinski icsk->icsk_ext_hdr_len += psp_sk_overhead(sk); 240e9726925SJakub Kicinski icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie); 241e9726925SJakub Kicinski 2426b46ca26SJakub Kicinski exit_free_dummy: 2436b46ca26SJakub Kicinski kfree(dummy); 2446b46ca26SJakub Kicinski exit_unlock: 2456b46ca26SJakub Kicinski release_sock(sk); 2466b46ca26SJakub Kicinski return err; 2476b46ca26SJakub Kicinski } 2486b46ca26SJakub Kicinski 249e7885105SJakub Kicinski void psp_assocs_key_rotated(struct psp_dev *psd) 250e7885105SJakub Kicinski { 251e7885105SJakub Kicinski struct psp_assoc *pas, *next; 252e7885105SJakub Kicinski 253e7885105SJakub Kicinski /* Mark the stale associations as invalid, they will no longer 254e7885105SJakub Kicinski * be able to Rx any traffic. 255e7885105SJakub Kicinski */ 256e7885105SJakub Kicinski list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list) 257e7885105SJakub Kicinski pas->generation |= ~PSP_GEN_VALID_MASK; 258e7885105SJakub Kicinski list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 259e7885105SJakub Kicinski list_splice_init(&psd->active_assocs, &psd->prev_assocs); 260e7885105SJakub Kicinski 261e7885105SJakub Kicinski /* TODO: we should inform the sockets that got shut down */ 262e7885105SJakub Kicinski } 263e7885105SJakub Kicinski 2646b46ca26SJakub Kicinski void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) 2656b46ca26SJakub Kicinski { 2666b46ca26SJakub Kicinski struct psp_assoc *pas = psp_sk_assoc(sk); 2676b46ca26SJakub Kicinski 2686b46ca26SJakub Kicinski if (pas) 2696b46ca26SJakub Kicinski refcount_inc(&pas->refcnt); 2706b46ca26SJakub Kicinski rcu_assign_pointer(tw->psp_assoc, pas); 2716b46ca26SJakub Kicinski tw->tw_validate_xmit_skb = psp_validate_xmit; 2726b46ca26SJakub Kicinski } 2736b46ca26SJakub Kicinski 2746b46ca26SJakub Kicinski void psp_twsk_assoc_free(struct inet_timewait_sock *tw) 2756b46ca26SJakub Kicinski { 2766b46ca26SJakub Kicinski struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1); 2776b46ca26SJakub Kicinski 2786b46ca26SJakub Kicinski rcu_assign_pointer(tw->psp_assoc, NULL); 2796b46ca26SJakub Kicinski psp_assoc_put(pas); 2806b46ca26SJakub Kicinski } 2816b46ca26SJakub Kicinski 2826b46ca26SJakub Kicinski void psp_reply_set_decrypted(struct sk_buff *skb) 2836b46ca26SJakub Kicinski { 2846b46ca26SJakub Kicinski struct psp_assoc *pas; 2856b46ca26SJakub Kicinski 2866b46ca26SJakub Kicinski rcu_read_lock(); 2876b46ca26SJakub Kicinski pas = psp_sk_get_assoc_rcu(skb->sk); 2886b46ca26SJakub Kicinski if (pas && pas->tx.spi) 2896b46ca26SJakub Kicinski skb->decrypted = 1; 2906b46ca26SJakub Kicinski rcu_read_unlock(); 2916b46ca26SJakub Kicinski } 2926b46ca26SJakub Kicinski EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted); 293