xref: /linux/net/tls/tls_strp.c (revision 393fc2f5948fd340d016a9557eea6e1ac2f6c60c)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */
3 
4 #include <linux/skbuff.h>
5 #include <linux/workqueue.h>
6 #include <net/strparser.h>
7 #include <net/tcp.h>
8 #include <net/sock.h>
9 #include <net/tls.h>
10 
11 #include "tls.h"
12 
/* Workqueue for deferred parsing: work is queued on it when the lower socket
 * is owned by a process context at data_ready time, and to retry parsing
 * after -ENOMEM. Created in tls_strp_dev_init(), destroyed in
 * tls_strp_dev_exit().
 */
static struct workqueue_struct *tls_strp_wq;
14 
15 static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
16 {
17 	if (strp->stopped)
18 		return;
19 
20 	strp->stopped = 1;
21 
22 	/* Report an error on the lower socket */
23 	strp->sk->sk_err = -err;
24 	sk_error_report(strp->sk);
25 }
26 
27 static void tls_strp_anchor_free(struct tls_strparser *strp)
28 {
29 	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
30 
31 	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
32 	shinfo->frag_list = NULL;
33 	consume_skb(strp->anchor);
34 	strp->anchor = NULL;
35 }
36 
37 /* Create a new skb with the contents of input copied to its page frags */
38 static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
39 {
40 	struct strp_msg *rxm;
41 	struct sk_buff *skb;
42 	int i, err, offset;
43 
44 	skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER,
45 				   &err, strp->sk->sk_allocation);
46 	if (!skb)
47 		return NULL;
48 
49 	offset = strp->stm.offset;
50 	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
51 		skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
52 
53 		WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
54 					   skb_frag_address(frag),
55 					   skb_frag_size(frag)));
56 		offset += skb_frag_size(frag);
57 	}
58 
59 	skb_copy_header(skb, strp->anchor);
60 	rxm = strp_msg(skb);
61 	rxm->offset = 0;
62 	return skb;
63 }
64 
65 /* Steal the input skb, input msg is invalid after calling this function */
66 struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
67 {
68 	struct tls_strparser *strp = &ctx->strp;
69 
70 #ifdef CONFIG_TLS_DEVICE
71 	DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
72 #else
73 	/* This function turns an input into an output,
74 	 * that can only happen if we have offload.
75 	 */
76 	WARN_ON(1);
77 #endif
78 
79 	if (strp->copy_mode) {
80 		struct sk_buff *skb;
81 
82 		/* Replace anchor with an empty skb, this is a little
83 		 * dangerous but __tls_cur_msg() warns on empty skbs
84 		 * so hopefully we'll catch abuses.
85 		 */
86 		skb = alloc_skb(0, strp->sk->sk_allocation);
87 		if (!skb)
88 			return NULL;
89 
90 		swap(strp->anchor, skb);
91 		return skb;
92 	}
93 
94 	return tls_strp_msg_make_copy(strp);
95 }
96 
97 /* Force the input skb to be in copy mode. The data ownership remains
98  * with the input skb itself (meaning unpause will wipe it) but it can
99  * be modified.
100  */
101 int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
102 {
103 	struct tls_strparser *strp = &ctx->strp;
104 	struct sk_buff *skb;
105 
106 	if (strp->copy_mode)
107 		return 0;
108 
109 	skb = tls_strp_msg_make_copy(strp);
110 	if (!skb)
111 		return -ENOMEM;
112 
113 	tls_strp_anchor_free(strp);
114 	strp->anchor = skb;
115 
116 	tcp_read_done(strp->sk, strp->stm.full_len);
117 	strp->copy_mode = 1;
118 
119 	return 0;
120 }
121 
122 /* Make a clone (in the skb sense) of the input msg to keep a reference
123  * to the underlying data. The reference-holding skbs get placed on
124  * @dst.
125  */
126 int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
127 {
128 	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
129 
130 	if (strp->copy_mode) {
131 		struct sk_buff *skb;
132 
133 		WARN_ON_ONCE(!shinfo->nr_frags);
134 
135 		/* We can't skb_clone() the anchor, it gets wiped by unpause */
136 		skb = alloc_skb(0, strp->sk->sk_allocation);
137 		if (!skb)
138 			return -ENOMEM;
139 
140 		__skb_queue_tail(dst, strp->anchor);
141 		strp->anchor = skb;
142 	} else {
143 		struct sk_buff *iter, *clone;
144 		int chunk, len, offset;
145 
146 		offset = strp->stm.offset;
147 		len = strp->stm.full_len;
148 		iter = shinfo->frag_list;
149 
150 		while (len > 0) {
151 			if (iter->len <= offset) {
152 				offset -= iter->len;
153 				goto next;
154 			}
155 
156 			chunk = iter->len - offset;
157 			offset = 0;
158 
159 			clone = skb_clone(iter, strp->sk->sk_allocation);
160 			if (!clone)
161 				return -ENOMEM;
162 			__skb_queue_tail(dst, clone);
163 
164 			len -= chunk;
165 next:
166 			iter = iter->next;
167 		}
168 	}
169 
170 	return 0;
171 }
172 
173 static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
174 {
175 	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
176 	int i;
177 
178 	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
179 
180 	for (i = 0; i < shinfo->nr_frags; i++)
181 		__skb_frag_unref(&shinfo->frags[i], false);
182 	shinfo->nr_frags = 0;
183 	strp->copy_mode = 0;
184 }
185 
186 static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
187 			   unsigned int offset, size_t in_len)
188 {
189 	struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
190 	struct sk_buff *skb;
191 	skb_frag_t *frag;
192 	size_t len, chunk;
193 	int sz;
194 
195 	if (strp->msg_ready)
196 		return 0;
197 
198 	skb = strp->anchor;
199 	frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
200 
201 	len = in_len;
202 	/* First make sure we got the header */
203 	if (!strp->stm.full_len) {
204 		/* Assume one page is more than enough for headers */
205 		chunk =	min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
206 		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
207 					   skb_frag_address(frag) +
208 					   skb_frag_size(frag),
209 					   chunk));
210 
211 		sz = tls_rx_msg_size(strp, strp->anchor);
212 		if (sz < 0) {
213 			desc->error = sz;
214 			return 0;
215 		}
216 
217 		/* We may have over-read, sz == 0 is guaranteed under-read */
218 		if (sz > 0)
219 			chunk =	min_t(size_t, chunk, sz - skb->len);
220 
221 		skb->len += chunk;
222 		skb->data_len += chunk;
223 		skb_frag_size_add(frag, chunk);
224 		frag++;
225 		len -= chunk;
226 		offset += chunk;
227 
228 		strp->stm.full_len = sz;
229 		if (!strp->stm.full_len)
230 			goto read_done;
231 	}
232 
233 	/* Load up more data */
234 	while (len && strp->stm.full_len > skb->len) {
235 		chunk =	min_t(size_t, len, strp->stm.full_len - skb->len);
236 		chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
237 		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
238 					   skb_frag_address(frag) +
239 					   skb_frag_size(frag),
240 					   chunk));
241 
242 		skb->len += chunk;
243 		skb->data_len += chunk;
244 		skb_frag_size_add(frag, chunk);
245 		frag++;
246 		len -= chunk;
247 		offset += chunk;
248 	}
249 
250 	if (strp->stm.full_len == skb->len) {
251 		desc->count = 0;
252 
253 		strp->msg_ready = 1;
254 		tls_rx_msg_ready(strp);
255 	}
256 
257 read_done:
258 	return in_len - len;
259 }
260 
261 static int tls_strp_read_copyin(struct tls_strparser *strp)
262 {
263 	struct socket *sock = strp->sk->sk_socket;
264 	read_descriptor_t desc;
265 
266 	desc.arg.data = strp;
267 	desc.error = 0;
268 	desc.count = 1; /* give more than one skb per call */
269 
270 	/* sk should be locked here, so okay to do read_sock */
271 	sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin);
272 
273 	return desc.error;
274 }
275 
276 static int tls_strp_read_short(struct tls_strparser *strp)
277 {
278 	struct skb_shared_info *shinfo;
279 	struct page *page;
280 	int need_spc, len;
281 
282 	/* If the rbuf is small or rcv window has collapsed to 0 we need
283 	 * to read the data out. Otherwise the connection will stall.
284 	 * Without pressure threshold of INT_MAX will never be ready.
285 	 */
286 	if (likely(!tcp_epollin_ready(strp->sk, INT_MAX)))
287 		return 0;
288 
289 	shinfo = skb_shinfo(strp->anchor);
290 	shinfo->frag_list = NULL;
291 
292 	/* If we don't know the length go max plus page for cipher overhead */
293 	need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
294 
295 	for (len = need_spc; len > 0; len -= PAGE_SIZE) {
296 		page = alloc_page(strp->sk->sk_allocation);
297 		if (!page) {
298 			tls_strp_flush_anchor_copy(strp);
299 			return -ENOMEM;
300 		}
301 
302 		skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
303 				   page, 0, 0);
304 	}
305 
306 	strp->copy_mode = 1;
307 	strp->stm.offset = 0;
308 
309 	strp->anchor->len = 0;
310 	strp->anchor->data_len = 0;
311 	strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);
312 
313 	tls_strp_read_copyin(strp);
314 
315 	return 0;
316 }
317 
318 static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
319 {
320 	struct tcp_sock *tp = tcp_sk(strp->sk);
321 	struct sk_buff *first;
322 	u32 offset;
323 
324 	first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
325 	if (WARN_ON_ONCE(!first))
326 		return;
327 
328 	/* Bestow the state onto the anchor */
329 	strp->anchor->len = offset + len;
330 	strp->anchor->data_len = offset + len;
331 	strp->anchor->truesize = offset + len;
332 
333 	skb_shinfo(strp->anchor)->frag_list = first;
334 
335 	skb_copy_header(strp->anchor, first);
336 	strp->anchor->destructor = NULL;
337 
338 	strp->stm.offset = offset;
339 }
340 
341 void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
342 {
343 	struct strp_msg *rxm;
344 	struct tls_msg *tlm;
345 
346 	DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
347 	DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);
348 
349 	if (!strp->copy_mode && force_refresh) {
350 		if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
351 			return;
352 
353 		tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
354 	}
355 
356 	rxm = strp_msg(strp->anchor);
357 	rxm->full_len	= strp->stm.full_len;
358 	rxm->offset	= strp->stm.offset;
359 	tlm = tls_msg(strp->anchor);
360 	tlm->control	= strp->mark;
361 }
362 
363 /* Called with lock held on lower socket */
364 static int tls_strp_read_sock(struct tls_strparser *strp)
365 {
366 	int sz, inq;
367 
368 	inq = tcp_inq(strp->sk);
369 	if (inq < 1)
370 		return 0;
371 
372 	if (unlikely(strp->copy_mode))
373 		return tls_strp_read_copyin(strp);
374 
375 	if (inq < strp->stm.full_len)
376 		return tls_strp_read_short(strp);
377 
378 	if (!strp->stm.full_len) {
379 		tls_strp_load_anchor_with_queue(strp, inq);
380 
381 		sz = tls_rx_msg_size(strp, strp->anchor);
382 		if (sz < 0) {
383 			tls_strp_abort_strp(strp, sz);
384 			return sz;
385 		}
386 
387 		strp->stm.full_len = sz;
388 
389 		if (!strp->stm.full_len || inq < strp->stm.full_len)
390 			return tls_strp_read_short(strp);
391 	}
392 
393 	strp->msg_ready = 1;
394 	tls_rx_msg_ready(strp);
395 
396 	return 0;
397 }
398 
399 void tls_strp_check_rcv(struct tls_strparser *strp)
400 {
401 	if (unlikely(strp->stopped) || strp->msg_ready)
402 		return;
403 
404 	if (tls_strp_read_sock(strp) == -ENOMEM)
405 		queue_work(tls_strp_wq, &strp->work);
406 }
407 
408 /* Lower sock lock held */
409 void tls_strp_data_ready(struct tls_strparser *strp)
410 {
411 	/* This check is needed to synchronize with do_tls_strp_work.
412 	 * do_tls_strp_work acquires a process lock (lock_sock) whereas
413 	 * the lock held here is bh_lock_sock. The two locks can be
414 	 * held by different threads at the same time, but bh_lock_sock
415 	 * allows a thread in BH context to safely check if the process
416 	 * lock is held. In this case, if the lock is held, queue work.
417 	 */
418 	if (sock_owned_by_user_nocheck(strp->sk)) {
419 		queue_work(tls_strp_wq, &strp->work);
420 		return;
421 	}
422 
423 	tls_strp_check_rcv(strp);
424 }
425 
426 static void tls_strp_work(struct work_struct *w)
427 {
428 	struct tls_strparser *strp =
429 		container_of(w, struct tls_strparser, work);
430 
431 	lock_sock(strp->sk);
432 	tls_strp_check_rcv(strp);
433 	release_sock(strp->sk);
434 }
435 
436 void tls_strp_msg_done(struct tls_strparser *strp)
437 {
438 	WARN_ON(!strp->stm.full_len);
439 
440 	if (likely(!strp->copy_mode))
441 		tcp_read_done(strp->sk, strp->stm.full_len);
442 	else
443 		tls_strp_flush_anchor_copy(strp);
444 
445 	strp->msg_ready = 0;
446 	memset(&strp->stm, 0, sizeof(strp->stm));
447 
448 	tls_strp_check_rcv(strp);
449 }
450 
451 void tls_strp_stop(struct tls_strparser *strp)
452 {
453 	strp->stopped = 1;
454 }
455 
456 int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
457 {
458 	memset(strp, 0, sizeof(*strp));
459 
460 	strp->sk = sk;
461 
462 	strp->anchor = alloc_skb(0, GFP_KERNEL);
463 	if (!strp->anchor)
464 		return -ENOMEM;
465 
466 	INIT_WORK(&strp->work, tls_strp_work);
467 
468 	return 0;
469 }
470 
471 /* strp must already be stopped so that tls_strp_recv will no longer be called.
472  * Note that tls_strp_done is not called with the lower socket held.
473  */
474 void tls_strp_done(struct tls_strparser *strp)
475 {
476 	WARN_ON(!strp->stopped);
477 
478 	cancel_work_sync(&strp->work);
479 	tls_strp_anchor_free(strp);
480 }
481 
482 int __init tls_strp_dev_init(void)
483 {
484 	tls_strp_wq = create_workqueue("tls-strp");
485 	if (unlikely(!tls_strp_wq))
486 		return -ENOMEM;
487 
488 	return 0;
489 }
490 
491 void tls_strp_dev_exit(void)
492 {
493 	destroy_workqueue(tls_strp_wq);
494 }
495