1 // SPDX-License-Identifier: GPL-2.0-only 2 3 #include <linux/file.h> 4 #include <linux/net.h> 5 #include <linux/rcupdate.h> 6 #include <linux/tcp.h> 7 8 #include <net/ip.h> 9 #include <net/psp.h> 10 #include "psp.h" 11 12 struct psp_dev *psp_dev_get_for_sock(struct sock *sk) 13 { 14 struct psp_dev *psd = NULL; 15 struct dst_entry *dst; 16 17 rcu_read_lock(); 18 dst = __sk_dst_get(sk); 19 if (dst) { 20 psd = rcu_dereference(dst_dev_rcu(dst)->psp_dev); 21 if (psd && !psp_dev_tryget(psd)) 22 psd = NULL; 23 } 24 rcu_read_unlock(); 25 26 return psd; 27 } 28 29 static struct sk_buff * 30 psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb) 31 { 32 struct psp_assoc *pas; 33 bool good; 34 35 rcu_read_lock(); 36 pas = psp_skb_get_assoc_rcu(skb); 37 good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd; 38 rcu_read_unlock(); 39 if (!good) { 40 sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_PSP_OUTPUT); 41 return NULL; 42 } 43 44 return skb; 45 } 46 47 struct psp_assoc *psp_assoc_create(struct psp_dev *psd) 48 { 49 struct psp_assoc *pas; 50 51 lockdep_assert_held(&psd->lock); 52 53 pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc), 54 GFP_KERNEL_ACCOUNT); 55 if (!pas) 56 return NULL; 57 58 pas->psd = psd; 59 pas->dev_id = psd->id; 60 pas->generation = psd->generation; 61 psp_dev_get(psd); 62 refcount_set(&pas->refcnt, 1); 63 64 list_add_tail(&pas->assocs_list, &psd->active_assocs); 65 66 return pas; 67 } 68 69 static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas) 70 { 71 struct psp_dev *psd = pas->psd; 72 size_t sz; 73 74 lockdep_assert_held(&psd->lock); 75 76 sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc); 77 return kmemdup(pas, sz, GFP_KERNEL); 78 } 79 80 static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas, 81 struct netlink_ext_ack *extack) 82 { 83 return psd->ops->tx_key_add(psd, pas, extack); 84 } 85 86 void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas) 87 { 88 if (pas->tx.spi) 89 psd->ops->tx_key_del(psd, pas); 90 list_del(&pas->assocs_list); 91 } 92 93 static void psp_assoc_free(struct work_struct *work) 94 { 95 struct psp_assoc *pas = container_of(work, struct psp_assoc, work); 96 struct psp_dev *psd = pas->psd; 97 98 mutex_lock(&psd->lock); 99 if (psd->ops) 100 psp_dev_tx_key_del(psd, pas); 101 mutex_unlock(&psd->lock); 102 psp_dev_put(psd); 103 kfree(pas); 104 } 105 106 static void psp_assoc_free_queue(struct rcu_head *head) 107 { 108 struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu); 109 110 INIT_WORK(&pas->work, psp_assoc_free); 111 schedule_work(&pas->work); 112 } 113 114 /** 115 * psp_assoc_put() - release a reference on a PSP association 116 * @pas: association to release 117 */ 118 void psp_assoc_put(struct psp_assoc *pas) 119 { 120 if (pas && refcount_dec_and_test(&pas->refcnt)) 121 call_rcu(&pas->rcu, psp_assoc_free_queue); 122 } 123 124 void psp_sk_assoc_free(struct sock *sk) 125 { 126 struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1); 127 128 rcu_assign_pointer(sk->psp_assoc, NULL); 129 psp_assoc_put(pas); 130 } 131 132 int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, 133 struct psp_key_parsed *key, 134 struct netlink_ext_ack *extack) 135 { 136 int err; 137 138 memcpy(&pas->rx, key, sizeof(*key)); 139 140 lock_sock(sk); 141 142 if (psp_sk_assoc(sk)) { 143 NL_SET_ERR_MSG(extack, "Socket already has PSP state"); 144 err = -EBUSY; 145 goto exit_unlock; 146 } 147 148 refcount_inc(&pas->refcnt); 149 rcu_assign_pointer(sk->psp_assoc, pas); 150 err = 0; 151 152 exit_unlock: 153 release_sock(sk); 154 155 return err; 156 } 157 158 static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas) 159 { 160 struct psp_skb_ext *pse; 161 struct sk_buff *skb; 162 163 skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) { 164 pse = skb_ext_find(skb, SKB_EXT_PSP); 165 if (!psp_pse_matches_pas(pse, pas)) 166 return -EBUSY; 167 } 168 169 skb_queue_walk(&sk->sk_receive_queue, skb) { 170 pse = skb_ext_find(skb, SKB_EXT_PSP); 171 if (!psp_pse_matches_pas(pse, pas)) 172 return -EBUSY; 173 } 174 return 0; 175 } 176 177 int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, 178 u32 version, struct psp_key_parsed *key, 179 struct netlink_ext_ack *extack) 180 { 181 struct inet_connection_sock *icsk; 182 struct psp_assoc *pas, *dummy; 183 int err; 184 185 lock_sock(sk); 186 187 pas = psp_sk_assoc(sk); 188 if (!pas) { 189 NL_SET_ERR_MSG(extack, "Socket has no Rx key"); 190 err = -EINVAL; 191 goto exit_unlock; 192 } 193 if (pas->psd != psd) { 194 NL_SET_ERR_MSG(extack, "Rx key from different device"); 195 err = -EINVAL; 196 goto exit_unlock; 197 } 198 if (pas->version != version) { 199 NL_SET_ERR_MSG(extack, 200 "PSP version mismatch with existing state"); 201 err = -EINVAL; 202 goto exit_unlock; 203 } 204 if (pas->tx.spi) { 205 NL_SET_ERR_MSG(extack, "Tx key already set"); 206 err = -EBUSY; 207 goto exit_unlock; 208 } 209 210 err = psp_sock_recv_queue_check(sk, pas); 211 if (err) { 212 NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue"); 213 goto exit_unlock; 214 } 215 216 /* Pass a fake association to drivers to make sure they don't 217 * try to store pointers to it. For re-keying we'll need to 218 * re-allocate the assoc structures. 219 */ 220 dummy = psp_assoc_dummy(pas); 221 if (!dummy) { 222 err = -ENOMEM; 223 goto exit_unlock; 224 } 225 226 memcpy(&dummy->tx, key, sizeof(*key)); 227 err = psp_dev_tx_key_add(psd, dummy, extack); 228 if (err) 229 goto exit_free_dummy; 230 231 memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc); 232 memcpy(&pas->tx, key, sizeof(*key)); 233 234 WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit); 235 tcp_write_collapse_fence(sk); 236 pas->upgrade_seq = tcp_sk(sk)->rcv_nxt; 237 238 icsk = inet_csk(sk); 239 icsk->icsk_ext_hdr_len += psp_sk_overhead(sk); 240 icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie); 241 242 exit_free_dummy: 243 kfree(dummy); 244 exit_unlock: 245 release_sock(sk); 246 return err; 247 } 248 249 void psp_assocs_key_rotated(struct psp_dev *psd) 250 { 251 struct psp_assoc *pas, *next; 252 253 /* Mark the stale associations as invalid, they will no longer 254 * be able to Rx any traffic. 255 */ 256 list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list) 257 pas->generation |= ~PSP_GEN_VALID_MASK; 258 list_splice_init(&psd->prev_assocs, &psd->stale_assocs); 259 list_splice_init(&psd->active_assocs, &psd->prev_assocs); 260 261 /* TODO: we should inform the sockets that got shut down */ 262 } 263 264 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) 265 { 266 struct psp_assoc *pas = psp_sk_assoc(sk); 267 268 if (pas) 269 refcount_inc(&pas->refcnt); 270 rcu_assign_pointer(tw->psp_assoc, pas); 271 tw->tw_validate_xmit_skb = psp_validate_xmit; 272 } 273 274 void psp_twsk_assoc_free(struct inet_timewait_sock *tw) 275 { 276 struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1); 277 278 rcu_assign_pointer(tw->psp_assoc, NULL); 279 psp_assoc_put(pas); 280 } 281 282 void psp_reply_set_decrypted(struct sk_buff *skb) 283 { 284 struct psp_assoc *pas; 285 286 rcu_read_lock(); 287 pas = psp_sk_get_assoc_rcu(skb->sk); 288 if (pas && pas->tx.spi) 289 skb->decrypted = 1; 290 rcu_read_unlock(); 291 } 292 EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted); 293