// SPDX-License-Identifier: GPL-2.0
/*  OpenVPN data channel offload
 *
 *  Copyright (C) 2019-2025 OpenVPN, Inc.
 *
 *  Author:	Antonio Quartulli <antonio@openvpn.net>
 */

#include <linux/skbuff.h>
#include <net/hotdata.h>
#include <net/inet_common.h>
#include <net/ipv6.h>
#include <net/tcp.h>
#include <net/transp_v6.h>
#include <net/route.h>
#include <trace/events/sock.h>

#include "ovpnpriv.h"
#include "main.h"
#include "io.h"
#include "peer.h"
#include "proto.h"
#include "skb.h"
#include "tcp.h"

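/* ovpn_tcp_send_skb() takes the transport socket lock via
 * spin_lock_nested() while the caller may already hold another socket
 * lock, hence a dedicated lockdep subclass: reusing
 * SINGLE_DEPTH_NESTING would make this nesting indistinguishable from a
 * recursive acquisition and trigger false lockdep reports
 */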
#define OVPN_TCP_DEPTH_NESTING	2
#if OVPN_TCP_DEPTH_NESTING == SINGLE_DEPTH_NESTING
#error "OVPN TCP requires its own lockdep subclass"
#endif

static struct proto ovpn_tcp_prot __ro_after_init;
static struct proto_ops ovpn_tcp_ops __ro_after_init;
static struct proto ovpn_tcp6_prot __ro_after_init;
static struct proto_ops ovpn_tcp6_ops __ro_after_init;

static int ovpn_tcp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	__be16 blen;
	u16 len;
	int err;

	/* when packets are written to the TCP stream, they are prepended with
	 * two bytes indicating the actual packet size.
	 * Parse accordingly and return the actual size (including the size
	 * header)
	 */
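	/* for example (hypothetical sizes): a 1400-byte OpenVPN packet
	 * appears on the wire as 0x05 0x78 (1400 in network byte order)
	 * followed by the 1400 payload bytes, so this parser reports a
	 * 1402-byte record to the strparser
	 */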

	if (skb->len < rxm->offset + 2)
		return 0;

	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
	if (err < 0)
		return err;

	len = be16_to_cpu(blen);
	if (len < 2)
		return -EINVAL;

	return len + 2;
}

/* queue skb for sending to userspace via recvmsg on the socket */
static void ovpn_tcp_to_userspace(struct ovpn_peer *peer, struct sock *sk,
				  struct sk_buff *skb)
{
	skb_set_owner_r(skb, sk);
	memset(skb->cb, 0, sizeof(skb->cb));
	skb_queue_tail(&peer->tcp.user_queue, skb);
	peer->tcp.sk_cb.sk_data_ready(sk);
}

static void ovpn_tcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
	struct ovpn_peer *peer = container_of(strp, struct ovpn_peer, tcp.strp);
	struct strp_msg *msg = strp_msg(skb);
	size_t pkt_len = msg->full_len - 2;
	size_t off = msg->offset + 2;
	u8 opcode;

	/* ensure skb->data points to the beginning of the openvpn packet */
	if (!pskb_pull(skb, off)) {
		net_warn_ratelimited("%s: packet too small for peer %u\n",
				     netdev_name(peer->ovpn->dev), peer->id);
		goto err;
	}

	/* strparser does not trim the skb for us, therefore we do it now */
	if (pskb_trim(skb, pkt_len) != 0) {
		net_warn_ratelimited("%s: trimming skb failed for peer %u\n",
				     netdev_name(peer->ovpn->dev), peer->id);
		goto err;
	}

	/* we need the first 4 bytes of data to be accessible
	 * to extract the opcode and the key ID later on
	 */
	if (!pskb_may_pull(skb, OVPN_OPCODE_SIZE)) {
		net_warn_ratelimited("%s: packet too small to fetch opcode for peer %u\n",
				     netdev_name(peer->ovpn->dev), peer->id);
		goto err;
	}

	/* DATA_V2 packets are handled in kernel, the rest goes to user space */
	opcode = ovpn_opcode_from_skb(skb, 0);
	if (unlikely(opcode != OVPN_DATA_V2)) {
		if (opcode == OVPN_DATA_V1) {
			net_warn_ratelimited("%s: DATA_V1 detected on the TCP stream\n",
					     netdev_name(peer->ovpn->dev));
			goto err;
		}

		/* The packet size header must be there when sending the packet
		 * to userspace, therefore we put it back
		 */
		skb_push(skb, 2);
		ovpn_tcp_to_userspace(peer, strp->sk, skb);
		return;
	}

	/* hold reference to peer as required by ovpn_recv().
	 *
	 * NOTE: in this context we should already be holding a reference to
	 * this peer, therefore ovpn_peer_hold() is not expected to fail
	 */
	if (WARN_ON(!ovpn_peer_hold(peer)))
		goto err;

	ovpn_recv(peer, skb);
	return;
err:
	dev_dstats_rx_dropped(peer->ovpn->dev);
	kfree_skb(skb);
	ovpn_peer_del(peer, OVPN_DEL_PEER_REASON_TRANSPORT_ERROR);
}

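/* handler for recvmsg on an attached socket: deliver one packet queued
 * by ovpn_tcp_to_userspace() (i.e. a non-DATA_V2 message) to userspace,
 * with datagram semantics
 */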
static int ovpn_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			    int flags, int *addr_len)
{
	int err = 0, off, copied = 0, ret;
	struct ovpn_socket *sock;
	struct ovpn_peer *peer;
	struct sk_buff *skb;

	rcu_read_lock();
	sock = rcu_dereference_sk_user_data(sk);
	if (unlikely(!sock || !sock->peer || !ovpn_peer_hold(sock->peer))) {
		rcu_read_unlock();
		return -EBADF;
	}
	peer = sock->peer;
	rcu_read_unlock();

	skb = __skb_recv_datagram(sk, &peer->tcp.user_queue, flags, &off, &err);
	if (!skb) {
		if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN) {
			ret = 0;
			goto out;
		}
		ret = err;
		goto out;
	}

	copied = len;
	if (copied > skb->len)
		copied = skb->len;
	else if (copied < skb->len)
		msg->msg_flags |= MSG_TRUNC;

	err = skb_copy_datagram_msg(skb, 0, msg, copied);
	if (unlikely(err)) {
		kfree_skb(skb);
		ret = err;
		goto out;
	}

	if (flags & MSG_TRUNC)
		copied = skb->len;
	kfree_skb(skb);
	ret = copied;
out:
	ovpn_peer_put(peer);
	return ret;
}

void ovpn_tcp_socket_detach(struct ovpn_socket *ovpn_sock)
{
	struct ovpn_peer *peer = ovpn_sock->peer;
	struct socket *sock = ovpn_sock->sock;

	strp_stop(&peer->tcp.strp);
	skb_queue_purge(&peer->tcp.user_queue);

	/* restore CBs that were saved in ovpn_tcp_socket_attach() */
	sock->sk->sk_data_ready = peer->tcp.sk_cb.sk_data_ready;
	sock->sk->sk_write_space = peer->tcp.sk_cb.sk_write_space;
	sock->sk->sk_prot = peer->tcp.sk_cb.prot;
	sock->sk->sk_socket->ops = peer->tcp.sk_cb.ops;

	rcu_assign_sk_user_data(sock->sk, NULL);
}

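/* wait for all in-flight TCP work to complete and drop any packet still
 * pending transmission; invoked during peer/socket teardown
 */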
void ovpn_tcp_socket_wait_finish(struct ovpn_socket *sock)
{
	struct ovpn_peer *peer = sock->peer;

	/* NOTE: we don't wait for peer->tcp.defer_del_work to finish:
	 * either the worker is not running or this function
	 * was invoked by that worker.
	 */

	cancel_work_sync(&sock->tcp_tx_work);
	strp_done(&peer->tcp.strp);

	skb_queue_purge(&peer->tcp.out_queue);
	kfree_skb(peer->tcp.out_msg.skb);
	peer->tcp.out_msg.skb = NULL;
}

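/* transmit the pending out_msg skb over the TCP stream, tracking partial
 * writes; on -EAGAIN the remainder stays queued so that the write_space
 * callback can resume it later. Caller must hold the socket lock.
 */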
static void ovpn_tcp_send_sock(struct ovpn_peer *peer, struct sock *sk)
{
	struct sk_buff *skb = peer->tcp.out_msg.skb;
	int ret, flags;

	if (!skb)
		return;

	if (peer->tcp.tx_in_progress)
		return;

	peer->tcp.tx_in_progress = true;

	do {
		flags = ovpn_skb_cb(skb)->nosignal ? MSG_NOSIGNAL : 0;
		ret = skb_send_sock_locked_with_flags(sk, skb,
						      peer->tcp.out_msg.offset,
						      peer->tcp.out_msg.len,
						      flags);
		if (unlikely(ret < 0)) {
			if (ret == -EAGAIN)
				goto out;

			net_warn_ratelimited("%s: TCP error to peer %u: %d\n",
					     netdev_name(peer->ovpn->dev),
					     peer->id, ret);

			/* in case of TCP error we can't recover the VPN
			 * stream therefore we abort the connection
			 */
			ovpn_peer_hold(peer);
			schedule_work(&peer->tcp.defer_del_work);

			/* we bail out immediately and keep tx_in_progress set
			 * to true. This way we prevent more TX attempts
			 * which would lead to more invocations of
			 * schedule_work()
			 */
			return;
		}

		peer->tcp.out_msg.len -= ret;
		peer->tcp.out_msg.offset += ret;
	} while (peer->tcp.out_msg.len > 0);

	if (!peer->tcp.out_msg.len) {
		preempt_disable();
		dev_dstats_tx_add(peer->ovpn->dev, skb->len);
		preempt_enable();
	}

	kfree_skb(peer->tcp.out_msg.skb);
	peer->tcp.out_msg.skb = NULL;
	peer->tcp.out_msg.len = 0;
	peer->tcp.out_msg.offset = 0;

out:
	peer->tcp.tx_in_progress = false;
}

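/* work handler scheduled by ovpn_tcp_write_space(): resume transmission
 * of the pending message now that the socket has room again
 */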
void ovpn_tcp_tx_work(struct work_struct *work)
{
	struct ovpn_socket *sock;

	sock = container_of(work, struct ovpn_socket, tcp_tx_work);

	lock_sock(sock->sock->sk);
	if (sock->peer)
		ovpn_tcp_send_sock(sock->peer, sock->sock->sk);
	release_sock(sock->sock->sk);
}

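/* hand a framed skb over to the transmit path: flush any pending message
 * first and drop the new skb if the previous one is still in flight
 */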
static void ovpn_tcp_send_sock_skb(struct ovpn_peer *peer, struct sock *sk,
				   struct sk_buff *skb)
{
	if (peer->tcp.out_msg.skb)
		ovpn_tcp_send_sock(peer, sk);

	if (peer->tcp.out_msg.skb) {
		dev_dstats_tx_dropped(peer->ovpn->dev);
		kfree_skb(skb);
		return;
	}

	peer->tcp.out_msg.skb = skb;
	peer->tcp.out_msg.len = skb->len;
	peer->tcp.out_msg.offset = 0;
	ovpn_tcp_send_sock(peer, sk);
}

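/* entry point of the ovpn TX path for TCP peers: prepend the 2-byte size
 * header expected by the OpenVPN TCP framing and either send the packet
 * right away or, if the socket is currently owned by a user context,
 * queue it for ovpn_tcp_release() to flush
 */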
void ovpn_tcp_send_skb(struct ovpn_peer *peer, struct socket *sock,
		       struct sk_buff *skb)
{
	u16 len = skb->len;

	*(__be16 *)__skb_push(skb, sizeof(u16)) = htons(len);

	spin_lock_nested(&sock->sk->sk_lock.slock, OVPN_TCP_DEPTH_NESTING);
	if (sock_owned_by_user(sock->sk)) {
		if (skb_queue_len(&peer->tcp.out_queue) >=
		    READ_ONCE(net_hotdata.max_backlog)) {
			dev_dstats_tx_dropped(peer->ovpn->dev);
			kfree_skb(skb);
			goto unlock;
		}
		__skb_queue_tail(&peer->tcp.out_queue, skb);
	} else {
		ovpn_tcp_send_sock_skb(peer, sock->sk, skb);
	}
unlock:
	spin_unlock(&sock->sk->sk_lock.slock);
}

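/* release_cb replacement: flush packets that ovpn_tcp_send_skb() queued
 * while the socket was owned by a user context, then chain to the
 * original callback
 */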
static void ovpn_tcp_release(struct sock *sk)
{
	struct sk_buff_head queue;
	struct ovpn_socket *sock;
	struct ovpn_peer *peer;
	struct sk_buff *skb;

	rcu_read_lock();
	sock = rcu_dereference_sk_user_data(sk);
	if (!sock) {
		rcu_read_unlock();
		return;
	}

	peer = sock->peer;

	/* during initialization this function is called before
	 * assigning sock->peer
	 */
	if (unlikely(!peer || !ovpn_peer_hold(peer))) {
		rcu_read_unlock();
		return;
	}
	rcu_read_unlock();

	__skb_queue_head_init(&queue);
	skb_queue_splice_init(&peer->tcp.out_queue, &queue);

	while ((skb = __skb_dequeue(&queue)))
		ovpn_tcp_send_sock_skb(peer, sk, skb);

	peer->tcp.sk_cb.prot->release_cb(sk);
	ovpn_peer_put(peer);
}

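/* handler for sendmsg on an attached socket: userspace provides a packet
 * already carrying the 2-byte size header (the counterpart of the header
 * pushed back in ovpn_tcp_rcv()), which is forwarded to the peer as-is
 */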
static int ovpn_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct ovpn_socket *sock;
	int ret, linear = PAGE_SIZE;
	struct ovpn_peer *peer;
	struct sk_buff *skb;

	lock_sock(sk);
	rcu_read_lock();
	sock = rcu_dereference_sk_user_data(sk);
	if (unlikely(!sock || !sock->peer || !ovpn_peer_hold(sock->peer))) {
		rcu_read_unlock();
		release_sock(sk);
		return -EIO;
	}
	rcu_read_unlock();
	peer = sock->peer;

	if (msg->msg_flags & ~(MSG_DONTWAIT | MSG_NOSIGNAL)) {
		ret = -EOPNOTSUPP;
		goto peer_free;
	}

	if (peer->tcp.out_msg.skb) {
		ret = -EAGAIN;
		goto peer_free;
	}

	if (size < linear)
		linear = size;

	skb = sock_alloc_send_pskb(sk, linear, size - linear,
				   msg->msg_flags & MSG_DONTWAIT, &ret, 0);
	if (!skb) {
		net_err_ratelimited("%s: skb alloc failed: %d\n",
				    netdev_name(peer->ovpn->dev), ret);
		goto peer_free;
	}

	skb_put(skb, linear);
	skb->len = size;
	skb->data_len = size - linear;

	ret = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, size);
	if (ret) {
		kfree_skb(skb);
		net_err_ratelimited("%s: skb copy from iter failed: %d\n",
				    netdev_name(peer->ovpn->dev), ret);
		goto peer_free;
	}

	ovpn_skb_cb(skb)->nosignal = msg->msg_flags & MSG_NOSIGNAL;
	ovpn_tcp_send_sock_skb(peer, sk, skb);
	ret = size;
peer_free:
	release_sock(sk);
	ovpn_peer_put(peer);
	return ret;
}

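/* disconnecting a socket while it is attached to ovpn is not supported */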
static int ovpn_tcp_disconnect(struct sock *sk, int flags)
{
	return -EBUSY;
}

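/* sk_data_ready replacement: feed the incoming TCP data to the strparser */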
static void ovpn_tcp_data_ready(struct sock *sk)
{
	struct ovpn_socket *sock;

	trace_sk_data_ready(sk);

	rcu_read_lock();
	sock = rcu_dereference_sk_user_data(sk);
	if (likely(sock && sock->peer))
		strp_data_ready(&sock->peer->tcp.strp);
	rcu_read_unlock();
}

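/* sk_write_space replacement: kick the TX worker to resume a partial
 * send, then wake userspace writers through the original callback
 */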
static void ovpn_tcp_write_space(struct sock *sk)
{
	struct ovpn_socket *sock;

	rcu_read_lock();
	sock = rcu_dereference_sk_user_data(sk);
	if (likely(sock && sock->peer)) {
		schedule_work(&sock->tcp_tx_work);
		sock->peer->tcp.sk_cb.sk_write_space(sk);
	}
	rcu_read_unlock();
}

static void ovpn_tcp_build_protos(struct proto *new_prot,
				  struct proto_ops *new_ops,
				  const struct proto *orig_prot,
				  const struct proto_ops *orig_ops);

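/* deferred peer deletion, scheduled by ovpn_tcp_send_sock() on a fatal
 * TCP error; releases the reference acquired before scheduling
 */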
static void ovpn_tcp_peer_del_work(struct work_struct *work)
{
	struct ovpn_peer *peer = container_of(work, struct ovpn_peer,
					      tcp.defer_del_work);

	ovpn_peer_del(peer, OVPN_DEL_PEER_REASON_TRANSPORT_ERROR);
	ovpn_peer_put(peer);
}

/* Set TCP encapsulation callbacks */
int ovpn_tcp_socket_attach(struct ovpn_socket *ovpn_sock,
			   struct ovpn_peer *peer)
{
	struct socket *sock = ovpn_sock->sock;
	struct strp_callbacks cb = {
		.rcv_msg = ovpn_tcp_rcv,
		.parse_msg = ovpn_tcp_parse,
	};
	int ret;

	/* make sure no pre-existing encapsulation handler exists */
	if (sock->sk->sk_user_data)
		return -EBUSY;

	/* only a fully connected socket is expected. Connection should be
	 * handled in userspace
	 */
	if (sock->sk->sk_state != TCP_ESTABLISHED) {
		net_err_ratelimited("%s: provided TCP socket is not in ESTABLISHED state: %d\n",
				    netdev_name(peer->ovpn->dev),
				    sock->sk->sk_state);
		return -EINVAL;
	}

	ret = strp_init(&peer->tcp.strp, sock->sk, &cb);
	if (ret < 0) {
		DEBUG_NET_WARN_ON_ONCE(1);
		return ret;
	}

	INIT_WORK(&peer->tcp.defer_del_work, ovpn_tcp_peer_del_work);

	__sk_dst_reset(sock->sk);
	skb_queue_head_init(&peer->tcp.user_queue);
	skb_queue_head_init(&peer->tcp.out_queue);

	/* save current CBs so that they can be restored upon socket release */
	peer->tcp.sk_cb.sk_data_ready = sock->sk->sk_data_ready;
	peer->tcp.sk_cb.sk_write_space = sock->sk->sk_write_space;
	peer->tcp.sk_cb.prot = sock->sk->sk_prot;
	peer->tcp.sk_cb.ops = sock->sk->sk_socket->ops;

	/* assign our static CBs and prot/ops */
	sock->sk->sk_data_ready = ovpn_tcp_data_ready;
	sock->sk->sk_write_space = ovpn_tcp_write_space;

	if (sock->sk->sk_family == AF_INET) {
		sock->sk->sk_prot = &ovpn_tcp_prot;
		sock->sk->sk_socket->ops = &ovpn_tcp_ops;
	} else {
		sock->sk->sk_prot = &ovpn_tcp6_prot;
		sock->sk->sk_socket->ops = &ovpn_tcp6_ops;
	}

	/* avoid using task_frag */
	sock->sk->sk_allocation = GFP_ATOMIC;
	sock->sk->sk_use_task_frag = false;

	/* enqueue the RX worker */
	strp_check_rcv(&peer->tcp.strp);

	return 0;
}

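/* close replacement: delete the peer attached to this socket before
 * running the original protocol close handler
 */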
static void ovpn_tcp_close(struct sock *sk, long timeout)
{
	struct ovpn_socket *sock;
	struct ovpn_peer *peer;

	rcu_read_lock();
	sock = rcu_dereference_sk_user_data(sk);
	if (!sock || !sock->peer || !ovpn_peer_hold(sock->peer)) {
		rcu_read_unlock();
		return;
	}
	peer = sock->peer;
	rcu_read_unlock();

	ovpn_peer_del(peer, OVPN_DEL_PEER_REASON_TRANSPORT_DISCONNECT);
	peer->tcp.sk_cb.prot->close(sk, timeout);
	ovpn_peer_put(peer);
}

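/* poll replacement: additionally report readability when packets are
 * waiting in the peer's user_queue, which datagram_poll() does not see
 */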
static __poll_t ovpn_tcp_poll(struct file *file, struct socket *sock,
			      poll_table *wait)
{
	__poll_t mask = datagram_poll(file, sock, wait);
	struct ovpn_socket *ovpn_sock;

	rcu_read_lock();
	ovpn_sock = rcu_dereference_sk_user_data(sock->sk);
	if (ovpn_sock && ovpn_sock->peer &&
	    !skb_queue_empty(&ovpn_sock->peer->tcp.user_queue))
		mask |= EPOLLIN | EPOLLRDNORM;
	rcu_read_unlock();

	return mask;
}

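/* clone the original proto/proto_ops and override the handlers that ovpn
 * needs to intercept
 */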
static void ovpn_tcp_build_protos(struct proto *new_prot,
				  struct proto_ops *new_ops,
				  const struct proto *orig_prot,
				  const struct proto_ops *orig_ops)
{
	memcpy(new_prot, orig_prot, sizeof(*new_prot));
	memcpy(new_ops, orig_ops, sizeof(*new_ops));
	new_prot->recvmsg = ovpn_tcp_recvmsg;
	new_prot->sendmsg = ovpn_tcp_sendmsg;
	new_prot->disconnect = ovpn_tcp_disconnect;
	new_prot->close = ovpn_tcp_close;
	new_prot->release_cb = ovpn_tcp_release;
	new_ops->poll = ovpn_tcp_poll;
}

/* Initialize TCP static objects */
void __init ovpn_tcp_init(void)
{
	ovpn_tcp_build_protos(&ovpn_tcp_prot, &ovpn_tcp_ops, &tcp_prot,
			      &inet_stream_ops);

#if IS_ENABLED(CONFIG_IPV6)
	ovpn_tcp_build_protos(&ovpn_tcp6_prot, &ovpn_tcp6_ops, &tcpv6_prot,
			      &inet6_stream_ops);
#endif
}
599