1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/file.h> 4 #include <linux/net.h> 5 #include <linux/rcupdate.h> 6 #include <linux/tcp.h> 7 8 #include <net/ip.h> 9 #include <net/psp.h> 10 #include "psp.h" 11 12 struct psp_dev *psp_dev_get_for_sock(struct sock *sk) 13 { 14 struct dst_entry *dst; 15 struct psp_dev *psd; 16 17 dst = sk_dst_get(sk); 18 if (!dst) 19 return NULL; 20 21 rcu_read_lock(); 22 psd = rcu_dereference(dst->dev->psp_dev); 23 if (psd && !psp_dev_tryget(psd)) 24 psd = NULL; 25 rcu_read_unlock(); 26 27 dst_release(dst); 28 29 return psd; 30 } 31 32 static struct sk_buff * 33 psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb) 34 { 35 struct psp_assoc *pas; 36 bool good; 37 38 rcu_read_lock(); 39 pas = psp_skb_get_assoc_rcu(skb); 40 good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd; 41 rcu_read_unlock(); 42 if (!good) { 43 kfree_skb_reason(skb, SKB_DROP_REASON_PSP_OUTPUT); 44 return NULL; 45 } 46 47 return skb; 48 } 49 50 struct psp_assoc *psp_assoc_create(struct psp_dev *psd) 51 { 52 struct psp_assoc *pas; 53 54 lockdep_assert_held(&psd->lock); 55 56 pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc), 57 GFP_KERNEL_ACCOUNT); 58 if (!pas) 59 return NULL; 60 61 pas->psd = psd; 62 pas->dev_id = psd->id; 63 psp_dev_get(psd); 64 refcount_set(&pas->refcnt, 1); 65 66 list_add_tail(&pas->assocs_list, &psd->active_assocs); 67 68 return pas; 69 } 70 71 static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas) 72 { 73 struct psp_dev *psd = pas->psd; 74 size_t sz; 75 76 lockdep_assert_held(&psd->lock); 77 78 sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc); 79 return kmemdup(pas, sz, GFP_KERNEL); 80 } 81 82 static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas, 83 struct netlink_ext_ack *extack) 84 { 85 return psd->ops->tx_key_add(psd, pas, extack); 86 } 87 88 void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas) 89 { 90 if (pas->tx.spi) 91 psd->ops->tx_key_del(psd, pas); 92 list_del(&pas->assocs_list); 93 } 94 95 static void psp_assoc_free(struct work_struct *work) 96 { 97 struct psp_assoc *pas = container_of(work, struct psp_assoc, work); 98 struct psp_dev *psd = pas->psd; 99 100 mutex_lock(&psd->lock); 101 if (psd->ops) 102 psp_dev_tx_key_del(psd, pas); 103 mutex_unlock(&psd->lock); 104 psp_dev_put(psd); 105 kfree(pas); 106 } 107 108 static void psp_assoc_free_queue(struct rcu_head *head) 109 { 110 struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu); 111 112 INIT_WORK(&pas->work, psp_assoc_free); 113 schedule_work(&pas->work); 114 } 115 116 /** 117 * psp_assoc_put() - release a reference on a PSP association 118 * @pas: association to release 119 */ 120 void psp_assoc_put(struct psp_assoc *pas) 121 { 122 if (pas && refcount_dec_and_test(&pas->refcnt)) 123 call_rcu(&pas->rcu, psp_assoc_free_queue); 124 } 125 126 void psp_sk_assoc_free(struct sock *sk) 127 { 128 struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1); 129 130 rcu_assign_pointer(sk->psp_assoc, NULL); 131 psp_assoc_put(pas); 132 } 133 134 int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, 135 struct psp_key_parsed *key, 136 struct netlink_ext_ack *extack) 137 { 138 int err; 139 140 memcpy(&pas->rx, key, sizeof(*key)); 141 142 lock_sock(sk); 143 144 if (psp_sk_assoc(sk)) { 145 NL_SET_ERR_MSG(extack, "Socket already has PSP state"); 146 err = -EBUSY; 147 goto exit_unlock; 148 } 149 150 refcount_inc(&pas->refcnt); 151 rcu_assign_pointer(sk->psp_assoc, pas); 152 err = 0; 153 154 exit_unlock: 155 release_sock(sk); 156 157 return err; 158 } 159 160 static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas) 161 { 162 struct psp_skb_ext *pse; 163 struct sk_buff *skb; 164 165 skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) { 166 pse = skb_ext_find(skb, SKB_EXT_PSP); 167 if (!psp_pse_matches_pas(pse, pas)) 168 return -EBUSY; 169 } 170 171 skb_queue_walk(&sk->sk_receive_queue, skb) { 172 pse = skb_ext_find(skb, SKB_EXT_PSP); 173 if (!psp_pse_matches_pas(pse, pas)) 174 return -EBUSY; 175 } 176 return 0; 177 } 178 179 int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, 180 u32 version, struct psp_key_parsed *key, 181 struct netlink_ext_ack *extack) 182 { 183 struct psp_assoc *pas, *dummy; 184 int err; 185 186 lock_sock(sk); 187 188 pas = psp_sk_assoc(sk); 189 if (!pas) { 190 NL_SET_ERR_MSG(extack, "Socket has no Rx key"); 191 err = -EINVAL; 192 goto exit_unlock; 193 } 194 if (pas->psd != psd) { 195 NL_SET_ERR_MSG(extack, "Rx key from different device"); 196 err = -EINVAL; 197 goto exit_unlock; 198 } 199 if (pas->version != version) { 200 NL_SET_ERR_MSG(extack, 201 "PSP version mismatch with existing state"); 202 err = -EINVAL; 203 goto exit_unlock; 204 } 205 if (pas->tx.spi) { 206 NL_SET_ERR_MSG(extack, "Tx key already set"); 207 err = -EBUSY; 208 goto exit_unlock; 209 } 210 211 err = psp_sock_recv_queue_check(sk, pas); 212 if (err) { 213 NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue"); 214 goto exit_unlock; 215 } 216 217 /* Pass a fake association to drivers to make sure they don't 218 * try to store pointers to it. For re-keying we'll need to 219 * re-allocate the assoc structures. 220 */ 221 dummy = psp_assoc_dummy(pas); 222 if (!dummy) { 223 err = -ENOMEM; 224 goto exit_unlock; 225 } 226 227 memcpy(&dummy->tx, key, sizeof(*key)); 228 err = psp_dev_tx_key_add(psd, dummy, extack); 229 if (err) 230 goto exit_free_dummy; 231 232 memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc); 233 memcpy(&pas->tx, key, sizeof(*key)); 234 235 WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit); 236 tcp_write_collapse_fence(sk); 237 pas->upgrade_seq = tcp_sk(sk)->rcv_nxt; 238 239 exit_free_dummy: 240 kfree(dummy); 241 exit_unlock: 242 release_sock(sk); 243 return err; 244 } 245 246 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) 247 { 248 struct psp_assoc *pas = psp_sk_assoc(sk); 249 250 if (pas) 251 refcount_inc(&pas->refcnt); 252 rcu_assign_pointer(tw->psp_assoc, pas); 253 tw->tw_validate_xmit_skb = psp_validate_xmit; 254 } 255 256 void psp_twsk_assoc_free(struct inet_timewait_sock *tw) 257 { 258 struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1); 259 260 rcu_assign_pointer(tw->psp_assoc, NULL); 261 psp_assoc_put(pas); 262 } 263 264 void psp_reply_set_decrypted(struct sk_buff *skb) 265 { 266 struct psp_assoc *pas; 267 268 rcu_read_lock(); 269 pas = psp_sk_get_assoc_rcu(skb->sk); 270 if (pas && pas->tx.spi) 271 skb->decrypted = 1; 272 rcu_read_unlock(); 273 } 274 EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted); 275