1 // SPDX-License-Identifier: GPL-2.0-only
2
3 #include <linux/file.h>
4 #include <linux/net.h>
5 #include <linux/rcupdate.h>
6 #include <linux/tcp.h>
7
8 #include <net/ip.h>
9 #include <net/psp.h>
10 #include "psp.h"
11
psp_dev_get_for_sock(struct sock * sk)12 struct psp_dev *psp_dev_get_for_sock(struct sock *sk)
13 {
14 struct psp_dev *psd = NULL;
15 struct dst_entry *dst;
16
17 rcu_read_lock();
18 dst = __sk_dst_get(sk);
19 if (dst) {
20 psd = rcu_dereference(dst_dev_rcu(dst)->psp_dev);
21 if (psd && !psp_dev_tryget(psd))
22 psd = NULL;
23 }
24 rcu_read_unlock();
25
26 return psd;
27 }
28
29 static struct sk_buff *
psp_validate_xmit(struct sock * sk,struct net_device * dev,struct sk_buff * skb)30 psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
31 {
32 struct psp_assoc *pas;
33 bool good;
34
35 rcu_read_lock();
36 pas = psp_skb_get_assoc_rcu(skb);
37 good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd;
38 rcu_read_unlock();
39 if (!good) {
40 sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_PSP_OUTPUT);
41 return NULL;
42 }
43
44 return skb;
45 }
46
psp_assoc_create(struct psp_dev * psd)47 struct psp_assoc *psp_assoc_create(struct psp_dev *psd)
48 {
49 struct psp_assoc *pas;
50
51 lockdep_assert_held(&psd->lock);
52
53 pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc),
54 GFP_KERNEL_ACCOUNT);
55 if (!pas)
56 return NULL;
57
58 pas->psd = psd;
59 pas->dev_id = psd->id;
60 pas->generation = psd->generation;
61 psp_dev_get(psd);
62 refcount_set(&pas->refcnt, 1);
63
64 list_add_tail(&pas->assocs_list, &psd->active_assocs);
65
66 return pas;
67 }
68
psp_assoc_dummy(struct psp_assoc * pas)69 static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas)
70 {
71 struct psp_dev *psd = pas->psd;
72 size_t sz;
73
74 lockdep_assert_held(&psd->lock);
75
76 sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc);
77 return kmemdup(pas, sz, GFP_KERNEL);
78 }
79
psp_dev_tx_key_add(struct psp_dev * psd,struct psp_assoc * pas,struct netlink_ext_ack * extack)80 static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas,
81 struct netlink_ext_ack *extack)
82 {
83 return psd->ops->tx_key_add(psd, pas, extack);
84 }
85
psp_dev_tx_key_del(struct psp_dev * psd,struct psp_assoc * pas)86 void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas)
87 {
88 if (pas->tx.spi)
89 psd->ops->tx_key_del(psd, pas);
90 list_del(&pas->assocs_list);
91 }
92
psp_assoc_free(struct work_struct * work)93 static void psp_assoc_free(struct work_struct *work)
94 {
95 struct psp_assoc *pas = container_of(work, struct psp_assoc, work);
96 struct psp_dev *psd = pas->psd;
97
98 mutex_lock(&psd->lock);
99 if (psd->ops)
100 psp_dev_tx_key_del(psd, pas);
101 mutex_unlock(&psd->lock);
102 psp_dev_put(psd);
103 kfree(pas);
104 }
105
psp_assoc_free_queue(struct rcu_head * head)106 static void psp_assoc_free_queue(struct rcu_head *head)
107 {
108 struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu);
109
110 INIT_WORK(&pas->work, psp_assoc_free);
111 schedule_work(&pas->work);
112 }
113
114 /**
115 * psp_assoc_put() - release a reference on a PSP association
116 * @pas: association to release
117 */
psp_assoc_put(struct psp_assoc * pas)118 void psp_assoc_put(struct psp_assoc *pas)
119 {
120 if (pas && refcount_dec_and_test(&pas->refcnt))
121 call_rcu(&pas->rcu, psp_assoc_free_queue);
122 }
123
psp_sk_assoc_free(struct sock * sk)124 void psp_sk_assoc_free(struct sock *sk)
125 {
126 struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1);
127
128 rcu_assign_pointer(sk->psp_assoc, NULL);
129 psp_assoc_put(pas);
130 }
131
psp_sock_assoc_set_rx(struct sock * sk,struct psp_assoc * pas,struct psp_key_parsed * key,struct netlink_ext_ack * extack)132 int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
133 struct psp_key_parsed *key,
134 struct netlink_ext_ack *extack)
135 {
136 int err;
137
138 memcpy(&pas->rx, key, sizeof(*key));
139
140 lock_sock(sk);
141
142 if (psp_sk_assoc(sk)) {
143 NL_SET_ERR_MSG(extack, "Socket already has PSP state");
144 err = -EBUSY;
145 goto exit_unlock;
146 }
147
148 refcount_inc(&pas->refcnt);
149 rcu_assign_pointer(sk->psp_assoc, pas);
150 err = 0;
151
152 exit_unlock:
153 release_sock(sk);
154
155 return err;
156 }
157
psp_sock_recv_queue_check(struct sock * sk,struct psp_assoc * pas)158 static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas)
159 {
160 struct psp_skb_ext *pse;
161 struct sk_buff *skb;
162
163 skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) {
164 pse = skb_ext_find(skb, SKB_EXT_PSP);
165 if (!psp_pse_matches_pas(pse, pas))
166 return -EBUSY;
167 }
168
169 skb_queue_walk(&sk->sk_receive_queue, skb) {
170 pse = skb_ext_find(skb, SKB_EXT_PSP);
171 if (!psp_pse_matches_pas(pse, pas))
172 return -EBUSY;
173 }
174 return 0;
175 }
176
/**
 * psp_sock_assoc_set_tx() - install a Tx key on a socket's PSP association
 * @sk: socket which must already carry an association (with an Rx key)
 * @psd: PSP device the Tx key is for; must be the association's device
 * @version: PSP version; must equal the association's version
 * @key: parsed Tx key
 * @extack: netlink extended ack used for error reporting
 *
 * The caller must hold @psd->lock (psp_assoc_dummy() asserts it on the
 * association's device, which is checked to be @psd). The whole operation
 * runs under lock_sock(@sk).
 *
 * After validating the existing association and the queued Rx data, the key
 * is handed to the driver through a temporary copy of the association. On
 * success the driver data and Tx key are copied back into the real
 * association, the socket's xmit validation hook is set to
 * psp_validate_xmit(), the current rcv_nxt is recorded as upgrade_seq, and
 * the MSS is re-synced to include the PSP header overhead.
 *
 * Return: 0 on success; -EINVAL if there is no association or it does not
 * match @psd or @version; -EBUSY if a Tx key is already set or queued skbs
 * are incompatible; -ENOMEM if the temporary copy cannot be allocated; or
 * the error returned by the driver's tx_key_add op.
 */
int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
			  u32 version, struct psp_key_parsed *key,
			  struct netlink_ext_ack *extack)
{
	struct inet_connection_sock *icsk;
	struct psp_assoc *pas, *dummy;
	int err;

	lock_sock(sk);

	/* The Rx side must have been set up first (psp_sock_assoc_set_rx). */
	pas = psp_sk_assoc(sk);
	if (!pas) {
		NL_SET_ERR_MSG(extack, "Socket has no Rx key");
		err = -EINVAL;
		goto exit_unlock;
	}
	if (pas->psd != psd) {
		NL_SET_ERR_MSG(extack, "Rx key from different device");
		err = -EINVAL;
		goto exit_unlock;
	}
	if (pas->version != version) {
		NL_SET_ERR_MSG(extack,
			       "PSP version mismatch with existing state");
		err = -EINVAL;
		goto exit_unlock;
	}
	/* A non-zero Tx SPI means a Tx key was already installed. */
	if (pas->tx.spi) {
		NL_SET_ERR_MSG(extack, "Tx key already set");
		err = -EBUSY;
		goto exit_unlock;
	}

	err = psp_sock_recv_queue_check(sk, pas);
	if (err) {
		NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue");
		goto exit_unlock;
	}

	/* Pass a fake association to drivers to make sure they don't
	 * try to store pointers to it. For re-keying we'll need to
	 * re-allocate the assoc structures.
	 */
	dummy = psp_assoc_dummy(pas);
	if (!dummy) {
		err = -ENOMEM;
		goto exit_unlock;
	}

	memcpy(&dummy->tx, key, sizeof(*key));
	err = psp_dev_tx_key_add(psd, dummy, extack);
	if (err)
		goto exit_free_dummy;

	/* Driver accepted the key: adopt its private state and the key. */
	memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc);
	memcpy(&pas->tx, key, sizeof(*key));

	WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit);
	/* NOTE(review): presumably keeps segments queued before this point
	 * from being collapsed with later (PSP-protected) ones - TODO confirm.
	 */
	tcp_write_collapse_fence(sk);
	/* Remember the receive sequence at which Tx was switched to PSP. */
	pas->upgrade_seq = tcp_sk(sk)->rcv_nxt;

	/* Account for the per-packet PSP overhead and recompute the MSS. */
	icsk = inet_csk(sk);
	icsk->icsk_ext_hdr_len += psp_sk_overhead(sk);
	icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie);

	/* Success also falls through here: the dummy is never kept. */
exit_free_dummy:
	kfree(dummy);
exit_unlock:
	release_sock(sk);
	return err;
}
248
psp_assocs_key_rotated(struct psp_dev * psd)249 void psp_assocs_key_rotated(struct psp_dev *psd)
250 {
251 struct psp_assoc *pas, *next;
252
253 /* Mark the stale associations as invalid, they will no longer
254 * be able to Rx any traffic.
255 */
256 list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list)
257 pas->generation |= ~PSP_GEN_VALID_MASK;
258 list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
259 list_splice_init(&psd->active_assocs, &psd->prev_assocs);
260
261 /* TODO: we should inform the sockets that got shut down */
262 }
263
psp_twsk_init(struct inet_timewait_sock * tw,const struct sock * sk)264 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
265 {
266 struct psp_assoc *pas = psp_sk_assoc(sk);
267
268 if (pas)
269 refcount_inc(&pas->refcnt);
270 rcu_assign_pointer(tw->psp_assoc, pas);
271 tw->tw_validate_xmit_skb = psp_validate_xmit;
272 }
273
psp_twsk_assoc_free(struct inet_timewait_sock * tw)274 void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
275 {
276 struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1);
277
278 rcu_assign_pointer(tw->psp_assoc, NULL);
279 psp_assoc_put(pas);
280 }
281
psp_reply_set_decrypted(const struct sock * sk,struct sk_buff * skb)282 void psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb)
283 {
284 struct psp_assoc *pas;
285
286 rcu_read_lock();
287 pas = psp_sk_get_assoc_rcu(sk);
288 if (pas && pas->tx.spi)
289 skb->decrypted = 1;
290 rcu_read_unlock();
291 }
292 EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);
293