xref: /linux/net/ipv4/tcp_bpf.c (revision b7d3826c2ed6c3e626e7ae796c5df2c0d2551c6a)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3 
4 #include <linux/skmsg.h>
5 #include <linux/filter.h>
6 #include <linux/bpf.h>
7 #include <linux/init.h>
8 #include <linux/wait.h>
9 
10 #include <net/inet_common.h>
11 
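/* Installed as the ->stream_memory_read callback (see
 * tcp_bpf_rebuild_protos() below): reports whether the psock has queued
 * ingress messages, so the TCP poll path can signal readable data that
 * lives on psock->ingress_msg rather than on sk_receive_queue.
 */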
12 static bool tcp_bpf_stream_read(const struct sock *sk)
13 {
14 	struct sk_psock *psock;
15 	bool empty = true;
16 
17 	rcu_read_lock();
18 	psock = sk_psock(sk);
19 	if (likely(psock))
20 		empty = list_empty(&psock->ingress_msg);
21 	rcu_read_unlock();
22 	return !empty;
23 }
24 
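/* Sleep until either psock->ingress_msg or the regular receive queue has
 * data, or @timeo expires. The @flags and @err arguments are unused in
 * this version; the return value is the sk_wait_event() result.
 */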
25 static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
26 			     int flags, long timeo, int *err)
27 {
28 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
29 	int ret;
30 
31 	add_wait_queue(sk_sleep(sk), &wait);
32 	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
33 	ret = sk_wait_event(sk, &timeo,
34 			    !list_empty(&psock->ingress_msg) ||
35 			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
36 	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
37 	remove_wait_queue(sk_sleep(sk), &wait);
38 	return ret;
39 }
40 
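/* Copy up to @len bytes of queued ingress data into @msg. Messages are
 * consumed from psock->ingress_msg in order; within a message the
 * scatterlist ring is walked from sg.start towards sg.end, wrapping at
 * MAX_SKB_FRAGS. Partially read elements keep their remaining
 * offset/length so the next read continues where this one stopped. Pages
 * are released here only for messages not backed by an skb; skb-backed
 * messages drop their reference via consume_skb() once fully consumed.
 * Callers are expected to hold the socket lock.
 */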
41 int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
42 		      struct msghdr *msg, int len)
43 {
44 	struct iov_iter *iter = &msg->msg_iter;
45 	int i, ret, copied = 0;
46 
47 	while (copied != len) {
48 		struct scatterlist *sge;
49 		struct sk_msg *msg_rx;
50 
51 		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
52 						  struct sk_msg, list);
53 		if (unlikely(!msg_rx))
54 			break;
55 
56 		i = msg_rx->sg.start;
57 		do {
58 			struct page *page;
59 			int copy;
60 
61 			sge = sk_msg_elem(msg_rx, i);
62 			copy = sge->length;
63 			page = sg_page(sge);
64 			if (copied + copy > len)
65 				copy = len - copied;
66 			ret = copy_page_to_iter(page, sge->offset, copy, iter);
67 			if (ret != copy) {
68 				msg_rx->sg.start = i;
69 				return -EFAULT;
70 			}
71 
72 			copied += copy;
73 			sge->offset += copy;
74 			sge->length -= copy;
75 			sk_mem_uncharge(sk, copy);
76 			if (!sge->length) {
77 				i++;
78 				if (i == MAX_SKB_FRAGS)
79 					i = 0;
80 				if (!msg_rx->skb)
81 					put_page(page);
82 			}
83 
84 			if (copied == len)
85 				break;
86 		} while (i != msg_rx->sg.end);
87 
88 		msg_rx->sg.start = i;
89 		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
90 			list_del(&msg_rx->list);
91 			if (msg_rx->skb)
92 				consume_skb(msg_rx->skb);
93 			kfree(msg_rx);
94 		}
95 	}
96 
97 	return copied;
98 }
99 EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
100 
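/* recvmsg replacement used while a psock is attached. Data queued by the
 * BPF redirect/ingress path is served from the psock first; if skbs are
 * sitting on the regular receive queue, or no psock is attached (anymore),
 * fall back to tcp_recvmsg(). With nothing available, wait up to the
 * socket's receive timeout and retry.
 */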
101 int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
102 		    int nonblock, int flags, int *addr_len)
103 {
104 	struct sk_psock *psock;
105 	int copied, ret;
106 
107 	if (unlikely(flags & MSG_ERRQUEUE))
108 		return inet_recv_error(sk, msg, len, addr_len);
109 	if (!skb_queue_empty(&sk->sk_receive_queue))
110 		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
111 
112 	psock = sk_psock_get(sk);
113 	if (unlikely(!psock))
114 		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
115 	lock_sock(sk);
116 msg_bytes_ready:
117 	copied = __tcp_bpf_recvmsg(sk, psock, msg, len);
118 	if (!copied) {
119 		int data, err = 0;
120 		long timeo;
121 
122 		timeo = sock_rcvtimeo(sk, nonblock);
123 		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
124 		if (data) {
125 			if (skb_queue_empty(&sk->sk_receive_queue))
126 				goto msg_bytes_ready;
127 			release_sock(sk);
128 			sk_psock_put(sk, psock);
129 			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
130 		}
131 		if (err) {
132 			ret = err;
133 			goto out;
134 		}
135 	}
136 	ret = copied;
137 out:
138 	release_sock(sk);
139 	sk_psock_put(sk, psock);
140 	return ret;
141 }
142 
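/* Ingress redirect: transfer up to @apply_bytes of scatterlist entries
 * from @msg into a freshly allocated sk_msg and queue it on the receiving
 * socket's psock->ingress_msg list, charging the receiver's memory
 * accounting along the way. An extra page reference is taken for elements
 * that are only partially transferred. sk_data_ready() wakes any reader
 * blocked in tcp_bpf_wait_data().
 */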
143 static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
144 			   struct sk_msg *msg, u32 apply_bytes, int flags)
145 {
146 	bool apply = apply_bytes;
147 	struct scatterlist *sge;
148 	u32 size, copied = 0;
149 	struct sk_msg *tmp;
150 	int i, ret = 0;
151 
152 	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
153 	if (unlikely(!tmp))
154 		return -ENOMEM;
155 
156 	lock_sock(sk);
157 	tmp->sg.start = msg->sg.start;
158 	i = msg->sg.start;
159 	do {
160 		sge = sk_msg_elem(msg, i);
161 		size = (apply && apply_bytes < sge->length) ?
162 			apply_bytes : sge->length;
163 		if (!sk_wmem_schedule(sk, size)) {
164 			if (!copied)
165 				ret = -ENOMEM;
166 			break;
167 		}
168 
169 		sk_mem_charge(sk, size);
170 		sk_msg_xfer(tmp, msg, i, size);
171 		copied += size;
172 		if (sge->length)
173 			get_page(sk_msg_page(tmp, i));
174 		sk_msg_iter_var_next(i);
175 		tmp->sg.end = i;
176 		if (apply) {
177 			apply_bytes -= size;
178 			if (!apply_bytes)
179 				break;
180 		}
181 	} while (i != msg->sg.end);
182 
183 	if (!ret) {
184 		msg->sg.start = i;
185 		msg->sg.size -= apply_bytes;
186 		sk_psock_queue_msg(psock, tmp);
187 		sk->sk_data_ready(sk);
188 	} else {
189 		sk_msg_free(sk, tmp);
190 		kfree(tmp);
191 	}
192 
193 	release_sock(sk);
194 	return ret;
195 }
196 
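/* Egress: push the sk_msg out through do_tcp_sendpages(), starting at
 * sg.start, until @apply_bytes have been sent or the ring is empty. Short
 * sends are retried from the updated offset. @uncharge selects whether the
 * bytes sent are also uncharged from @sk's memory accounting here.
 * Callers must hold the socket lock (see tcp_bpf_push_locked() below).
 */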
197 static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
198 			int flags, bool uncharge)
199 {
200 	bool apply = apply_bytes;
201 	struct scatterlist *sge;
202 	struct page *page;
203 	int size, ret = 0;
204 	u32 off;
205 
206 	while (1) {
207 		sge = sk_msg_elem(msg, msg->sg.start);
208 		size = (apply && apply_bytes < sge->length) ?
209 			apply_bytes : sge->length;
210 		off  = sge->offset;
211 		page = sg_page(sge);
212 
213 		tcp_rate_check_app_limited(sk);
214 retry:
215 		ret = do_tcp_sendpages(sk, page, off, size, flags);
216 		if (ret <= 0)
217 			return ret;
218 		if (apply)
219 			apply_bytes -= ret;
220 		msg->sg.size -= ret;
221 		sge->offset += ret;
222 		sge->length -= ret;
223 		if (uncharge)
224 			sk_mem_uncharge(sk, ret);
225 		if (ret != size) {
226 			size -= ret;
227 			off  += ret;
228 			goto retry;
229 		}
230 		if (!sge->length) {
231 			put_page(page);
232 			sk_msg_iter_next(msg, start);
233 			sg_init_table(sge, 1);
234 			if (msg->sg.start == msg->sg.end)
235 				break;
236 		}
237 		if (apply && !apply_bytes)
238 			break;
239 	}
240 
241 	return 0;
242 }
243 
244 static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
245 			       u32 apply_bytes, int flags, bool uncharge)
246 {
247 	int ret;
248 
249 	lock_sock(sk);
250 	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
251 	release_sock(sk);
252 	return ret;
253 }
254 
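/* Entry point for SK_MSG redirects: @sk is the target socket chosen by
 * the BPF program. Ingress-flagged messages are queued on the target's
 * psock via bpf_tcp_ingress(); otherwise the data is transmitted with
 * tcp_bpf_push_locked(). If the target no longer has a psock, the message
 * is simply freed.
 */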
255 int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
256 			  u32 bytes, int flags)
257 {
258 	bool ingress = sk_msg_to_ingress(msg);
259 	struct sk_psock *psock = sk_psock_get(sk);
260 	int ret;
261 
262 	if (unlikely(!psock)) {
263 		sk_msg_free(sk, msg);
264 		return 0;
265 	}
266 	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
267 			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
268 	sk_psock_put(sk, psock);
269 	return ret;
270 }
271 EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
272 
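/* Run the SK_MSG verdict program (unless a previous verdict still applies
 * because apply_bytes is not yet exhausted) and act on the result:
 *   __SK_PASS:     transmit on this socket via tcp_bpf_push().
 *   __SK_REDIRECT: hand the data to another socket via
 *                  tcp_bpf_sendmsg_redir(), dropping the socket lock
 *                  around the call.
 *   __SK_DROP:     free the data and return -EACCES.
 * msg->cork_bytes lets the program ask for more data before a verdict is
 * taken; the partial message is then parked in psock->cork. Loops via
 * more_data while the sk_msg still has data left to act on.
 */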
273 static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
274 				struct sk_msg *msg, int *copied, int flags)
275 {
276 	bool cork = false, enospc = msg->sg.start == msg->sg.end;
277 	struct sock *sk_redir;
278 	u32 tosend;
279 	int ret;
280 
281 more_data:
282 	if (psock->eval == __SK_NONE)
283 		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
284 
285 	if (msg->cork_bytes &&
286 	    msg->cork_bytes > msg->sg.size && !enospc) {
287 		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
288 		if (!psock->cork) {
289 			psock->cork = kzalloc(sizeof(*psock->cork),
290 					      GFP_ATOMIC | __GFP_NOWARN);
291 			if (!psock->cork)
292 				return -ENOMEM;
293 		}
294 		memcpy(psock->cork, msg, sizeof(*msg));
295 		return 0;
296 	}
297 
298 	tosend = msg->sg.size;
299 	if (psock->apply_bytes && psock->apply_bytes < tosend)
300 		tosend = psock->apply_bytes;
301 
302 	switch (psock->eval) {
303 	case __SK_PASS:
304 		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
305 		if (unlikely(ret)) {
306 			*copied -= sk_msg_free(sk, msg);
307 			break;
308 		}
309 		sk_msg_apply_bytes(psock, tosend);
310 		break;
311 	case __SK_REDIRECT:
312 		sk_redir = psock->sk_redir;
313 		sk_msg_apply_bytes(psock, tosend);
314 		if (psock->cork) {
315 			cork = true;
316 			psock->cork = NULL;
317 		}
318 		sk_msg_return(sk, msg, tosend);
319 		release_sock(sk);
320 		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
321 		lock_sock(sk);
322 		if (unlikely(ret < 0)) {
323 			int free = sk_msg_free_nocharge(sk, msg);
324 
325 			if (!cork)
326 				*copied -= free;
327 		}
328 		if (cork) {
329 			sk_msg_free(sk, msg);
330 			kfree(msg);
331 			msg = NULL;
332 			ret = 0;
333 		}
334 		break;
335 	case __SK_DROP:
336 	default:
337 		sk_msg_free_partial(sk, msg, tosend);
338 		sk_msg_apply_bytes(psock, tosend);
339 		*copied -= tosend;
340 		return -EACCES;
341 	}
342 
343 	if (likely(!ret)) {
344 		if (!psock->apply_bytes) {
345 			psock->eval =  __SK_NONE;
346 			if (psock->sk_redir) {
347 				sock_put(psock->sk_redir);
348 				psock->sk_redir = NULL;
349 			}
350 		}
351 		if (msg &&
352 		    msg->sg.data[msg->sg.start].page_link &&
353 		    msg->sg.data[msg->sg.start].length)
354 			goto more_data;
355 	}
356 	return ret;
357 }
358 
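/* sendmsg replacement while a psock with an SK_MSG program is attached.
 * User data is copied from the iov iterator into an sk_msg (or into the
 * pending cork buffer), cork_bytes accounting is applied, and the message
 * is handed to tcp_bpf_send_verdict(). Falls back to tcp_sendmsg() when no
 * psock is attached. Note the flags passed down always include
 * MSG_NO_SHARED_FRAGS.
 */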
359 static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
360 {
361 	struct sk_msg tmp, *msg_tx = NULL;
362 	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
363 	int copied = 0, err = 0;
364 	struct sk_psock *psock;
365 	long timeo;
366 
367 	psock = sk_psock_get(sk);
368 	if (unlikely(!psock))
369 		return tcp_sendmsg(sk, msg, size);
370 
371 	lock_sock(sk);
372 	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
373 	while (msg_data_left(msg)) {
374 		bool enospc = false;
375 		u32 copy, osize;
376 
377 		if (sk->sk_err) {
378 			err = -sk->sk_err;
379 			goto out_err;
380 		}
381 
382 		copy = msg_data_left(msg);
383 		if (!sk_stream_memory_free(sk))
384 			goto wait_for_sndbuf;
385 		if (psock->cork) {
386 			msg_tx = psock->cork;
387 		} else {
388 			msg_tx = &tmp;
389 			sk_msg_init(msg_tx);
390 		}
391 
392 		osize = msg_tx->sg.size;
393 		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
394 		if (err) {
395 			if (err != -ENOSPC)
396 				goto wait_for_memory;
397 			enospc = true;
398 			copy = msg_tx->sg.size - osize;
399 		}
400 
401 		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
402 					       copy);
403 		if (err < 0) {
404 			sk_msg_trim(sk, msg_tx, osize);
405 			goto out_err;
406 		}
407 
408 		copied += copy;
409 		if (psock->cork_bytes) {
410 			if (size > psock->cork_bytes)
411 				psock->cork_bytes = 0;
412 			else
413 				psock->cork_bytes -= size;
414 			if (psock->cork_bytes && !enospc)
415 				goto out_err;
416 			/* All cork bytes are accounted, rerun the prog. */
417 			psock->eval = __SK_NONE;
418 			psock->cork_bytes = 0;
419 		}
420 
421 		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
422 		if (unlikely(err < 0))
423 			goto out_err;
424 		continue;
425 wait_for_sndbuf:
426 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
427 wait_for_memory:
428 		err = sk_stream_wait_memory(sk, &timeo);
429 		if (err) {
430 			if (msg_tx && msg_tx != psock->cork)
431 				sk_msg_free(sk, msg_tx);
432 			goto out_err;
433 		}
434 	}
435 out_err:
436 	if (err < 0)
437 		err = sk_stream_error(sk, msg->msg_flags, err);
438 	release_sock(sk);
439 	sk_psock_put(sk, psock);
440 	return copied ? copied : err;
441 }
442 
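/* sendpage counterpart of tcp_bpf_sendmsg(): the page is appended to the
 * current (possibly corked) sk_msg without copying, charged to the socket,
 * and then run through the same verdict logic. Falls back to
 * tcp_sendpage() when no psock is attached.
 */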
443 static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
444 			    size_t size, int flags)
445 {
446 	struct sk_msg tmp, *msg = NULL;
447 	int err = 0, copied = 0;
448 	struct sk_psock *psock;
449 	bool enospc = false;
450 
451 	psock = sk_psock_get(sk);
452 	if (unlikely(!psock))
453 		return tcp_sendpage(sk, page, offset, size, flags);
454 
455 	lock_sock(sk);
456 	if (psock->cork) {
457 		msg = psock->cork;
458 	} else {
459 		msg = &tmp;
460 		sk_msg_init(msg);
461 	}
462 
463 	/* Catch case where ring is full and sendpage is stalled. */
464 	if (unlikely(sk_msg_full(msg)))
465 		goto out_err;
466 
467 	sk_msg_page_add(msg, page, size, offset);
468 	sk_mem_charge(sk, size);
469 	copied = size;
470 	if (sk_msg_full(msg))
471 		enospc = true;
472 	if (psock->cork_bytes) {
473 		if (size > psock->cork_bytes)
474 			psock->cork_bytes = 0;
475 		else
476 			psock->cork_bytes -= size;
477 		if (psock->cork_bytes && !enospc)
478 			goto out_err;
479 		/* All cork bytes are accounted, rerun the prog. */
480 		psock->eval = __SK_NONE;
481 		psock->cork_bytes = 0;
482 	}
483 
484 	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
485 out_err:
486 	release_sock(sk);
487 	sk_psock_put(sk, psock);
488 	return copied ? copied : err;
489 }
490 
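/* Teardown: drop any corked data, purge queued ingress messages and unlink
 * the psock from the maps it was added to. tcp_bpf_unhash() and
 * tcp_bpf_close() below perform this cleanup and then call the proto
 * callbacks originally saved in the psock (saved_unhash/saved_close).
 */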
491 static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
492 {
493 	struct sk_psock_link *link;
494 
495 	sk_psock_cork_free(psock);
496 	__sk_psock_purge_ingress_msg(psock);
497 	while ((link = sk_psock_link_pop(psock))) {
498 		sk_psock_unlink(sk, link);
499 		sk_psock_free_link(link);
500 	}
501 }
502 
503 static void tcp_bpf_unhash(struct sock *sk)
504 {
505 	void (*saved_unhash)(struct sock *sk);
506 	struct sk_psock *psock;
507 
508 	rcu_read_lock();
509 	psock = sk_psock(sk);
510 	if (unlikely(!psock)) {
511 		rcu_read_unlock();
512 		if (sk->sk_prot->unhash)
513 			sk->sk_prot->unhash(sk);
514 		return;
515 	}
516 
517 	saved_unhash = psock->saved_unhash;
518 	tcp_bpf_remove(sk, psock);
519 	rcu_read_unlock();
520 	saved_unhash(sk);
521 }
522 
523 static void tcp_bpf_close(struct sock *sk, long timeout)
524 {
525 	void (*saved_close)(struct sock *sk, long timeout);
526 	struct sk_psock *psock;
527 
528 	lock_sock(sk);
529 	rcu_read_lock();
530 	psock = sk_psock(sk);
531 	if (unlikely(!psock)) {
532 		rcu_read_unlock();
533 		release_sock(sk);
534 		return sk->sk_prot->close(sk, timeout);
535 	}
536 
537 	saved_close = psock->saved_close;
538 	tcp_bpf_remove(sk, psock);
539 	rcu_read_unlock();
540 	release_sock(sk);
541 	saved_close(sk, timeout);
542 }
543 
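/* Per-family (IPv4/IPv6) and per-config (BASE vs. TX) copies of the TCP
 * proto ops, built once and swapped into sk->sk_prot while a psock is
 * attached. Roughly:
 *
 *                  TCP_BPF_BASE (no msg_parser)   TCP_BPF_TX (msg_parser)
 *   recvmsg        tcp_bpf_recvmsg                tcp_bpf_recvmsg
 *   sendmsg        inherited (tcp_sendmsg)        tcp_bpf_sendmsg
 *   sendpage       inherited (tcp_sendpage)       tcp_bpf_sendpage
 *   unhash/close   tcp_bpf_unhash/_close          tcp_bpf_unhash/_close
 */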
544 enum {
545 	TCP_BPF_IPV4,
546 	TCP_BPF_IPV6,
547 	TCP_BPF_NUM_PROTS,
548 };
549 
550 enum {
551 	TCP_BPF_BASE,
552 	TCP_BPF_TX,
553 	TCP_BPF_NUM_CFGS,
554 };
555 
556 static struct proto *tcpv6_prot_saved __read_mostly;
557 static DEFINE_SPINLOCK(tcpv6_prot_lock);
558 static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
559 
560 static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
561 				   struct proto *base)
562 {
563 	prot[TCP_BPF_BASE]			= *base;
564 	prot[TCP_BPF_BASE].unhash		= tcp_bpf_unhash;
565 	prot[TCP_BPF_BASE].close		= tcp_bpf_close;
566 	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
567 	prot[TCP_BPF_BASE].stream_memory_read	= tcp_bpf_stream_read;
568 
569 	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
570 	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;
571 	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
572 }
573 
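/* The IPv4 table is filled at boot from tcp_prot (tcp_bpf_v4_build_proto()
 * below). tcpv6_prot is not referenced directly here (presumably to avoid
 * a hard dependency on the ipv6 side), so the IPv6 table is built lazily
 * from the proto of the first IPv6 socket seen, under tcpv6_prot_lock, and
 * rebuilt if that proto ever changes.
 */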
574 static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
575 {
576 	if (sk->sk_family == AF_INET6 &&
577 	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
578 		spin_lock_bh(&tcpv6_prot_lock);
579 		if (likely(ops != tcpv6_prot_saved)) {
580 			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
581 			smp_store_release(&tcpv6_prot_saved, ops);
582 		}
583 		spin_unlock_bh(&tcpv6_prot_lock);
584 	}
585 }
586 
587 static int __init tcp_bpf_v4_build_proto(void)
588 {
589 	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
590 	return 0;
591 }
592 core_initcall(tcp_bpf_v4_build_proto);
593 
594 static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
595 {
596 	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
597 	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
598 
599 	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
600 }
601 
602 static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
603 {
604 	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
605 	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
606 
607 	/* Reinit occurs when program types change, e.g. TCP_BPF_TX is
608 	 * removed or added, requiring sk_prot hook updates. We keep the
609 	 * original saved hooks in this case.
610 	 */
611 	sk->sk_prot = &tcp_bpf_prots[family][config];
612 }
613 
614 static int tcp_bpf_assert_proto_ops(struct proto *ops)
615 {
616 	/* To avoid retpolines, we make assumptions about which ops we call
617 	 * into directly when e.g. a psock is not present. Make sure those
618 	 * assumptions are indeed valid.
619 	 */
620 	return ops->recvmsg  == tcp_recvmsg &&
621 	       ops->sendmsg  == tcp_sendmsg &&
622 	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
623 }
624 
625 void tcp_bpf_reinit(struct sock *sk)
626 {
627 	struct sk_psock *psock;
628 
629 	sock_owned_by_me(sk);
630 
631 	rcu_read_lock();
632 	psock = sk_psock(sk);
633 	tcp_bpf_reinit_sk_prot(sk, psock);
634 	rcu_read_unlock();
635 }
636 
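/* Attach-time hook, called with the socket owned once a psock has been
 * linked to it. Verifies via tcp_bpf_assert_proto_ops() that the current
 * proto is plain TCP (the fallback paths above call tcp_recvmsg() etc.
 * directly to avoid retpolines), builds the IPv6 proto table if needed,
 * and points sk->sk_prot at the matching tcp_bpf_prots[] entry while
 * saving the original proto in the psock.
 */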
637 int tcp_bpf_init(struct sock *sk)
638 {
639 	struct proto *ops = READ_ONCE(sk->sk_prot);
640 	struct sk_psock *psock;
641 
642 	sock_owned_by_me(sk);
643 
644 	rcu_read_lock();
645 	psock = sk_psock(sk);
646 	if (unlikely(!psock || psock->sk_proto ||
647 		     tcp_bpf_assert_proto_ops(ops))) {
648 		rcu_read_unlock();
649 		return -EINVAL;
650 	}
651 	tcp_bpf_check_v6_needs_rebuild(sk, ops);
652 	tcp_bpf_update_sk_prot(sk, psock);
653 	rcu_read_unlock();
654 	return 0;
655 }
656