xref: /linux/net/vmw_vsock/vsock_bpf.c (revision a634dda26186cf9a51567020fcce52bcba5e1e59)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2022 Bobby Eshleman <bobby.eshleman@bytedance.com>
3  *
4  * Based off of net/unix/unix_bpf.c
5  */
6 
#include <linux/bpf.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/skmsg.h>
#include <linux/socket.h>
#include <linux/wait.h>
#include <net/af_vsock.h>
#include <net/sock.h>
13 #include <net/sock.h>
14 
/* Non-zero when skbs sit on the socket's own receive queue, or when the
 * psock holds ingress skbs or ingress sk_msgs.
 */
#define vsock_sk_has_data(__sk, __psock)				\
	(!skb_queue_empty(&(__sk)->sk_receive_queue) ||			\
	 !skb_queue_empty(&(__psock)->ingress_skb) ||			\
	 !list_empty(&(__psock)->ingress_msg))
20 
21 static struct proto *vsock_prot_saved __read_mostly;
22 static DEFINE_SPINLOCK(vsock_prot_lock);
23 static struct proto vsock_bpf_prot;
24 
25 static bool vsock_has_data(struct sock *sk, struct sk_psock *psock)
26 {
27 	struct vsock_sock *vsk = vsock_sk(sk);
28 	s64 ret;
29 
30 	ret = vsock_connectible_has_data(vsk);
31 	if (ret > 0)
32 		return true;
33 
34 	return vsock_sk_has_data(sk, psock);
35 }
36 
37 static bool vsock_msg_wait_data(struct sock *sk, struct sk_psock *psock, long timeo)
38 {
39 	bool ret;
40 
41 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
42 
43 	if (sk->sk_shutdown & RCV_SHUTDOWN)
44 		return true;
45 
46 	if (!timeo)
47 		return false;
48 
49 	add_wait_queue(sk_sleep(sk), &wait);
50 	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
51 	ret = vsock_has_data(sk, psock);
52 	if (!ret) {
53 		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
54 		ret = vsock_has_data(sk, psock);
55 	}
56 	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
57 	remove_wait_queue(sk_sleep(sk), &wait);
58 	return ret;
59 }
60 
61 static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags)
62 {
63 	struct socket *sock = sk->sk_socket;
64 	int err;
65 
66 	if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
67 		err = __vsock_connectible_recvmsg(sock, msg, len, flags);
68 	else if (sk->sk_type == SOCK_DGRAM)
69 		err = __vsock_dgram_recvmsg(sock, msg, len, flags);
70 	else
71 		err = -EPROTOTYPE;
72 
73 	return err;
74 }
75 
76 static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
77 			     size_t len, int flags, int *addr_len)
78 {
79 	struct sk_psock *psock;
80 	struct vsock_sock *vsk;
81 	int copied;
82 
83 	psock = sk_psock_get(sk);
84 	if (unlikely(!psock))
85 		return __vsock_recvmsg(sk, msg, len, flags);
86 
87 	lock_sock(sk);
88 	vsk = vsock_sk(sk);
89 
90 	if (!vsk->transport) {
91 		copied = -ENODEV;
92 		goto out;
93 	}
94 
95 	if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock)) {
96 		release_sock(sk);
97 		sk_psock_put(sk, psock);
98 		return __vsock_recvmsg(sk, msg, len, flags);
99 	}
100 
101 	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
102 	while (copied == 0) {
103 		long timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
104 
105 		if (!vsock_msg_wait_data(sk, psock, timeo)) {
106 			copied = -EAGAIN;
107 			break;
108 		}
109 
110 		if (sk_psock_queue_empty(psock)) {
111 			release_sock(sk);
112 			sk_psock_put(sk, psock);
113 			return __vsock_recvmsg(sk, msg, len, flags);
114 		}
115 
116 		copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
117 	}
118 
119 out:
120 	release_sock(sk);
121 	sk_psock_put(sk, psock);
122 
123 	return copied;
124 }
125 
126 static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
127 {
128 	*prot        = *base;
129 	prot->close  = sock_map_close;
130 	prot->recvmsg = vsock_bpf_recvmsg;
131 	prot->sock_is_readable = sk_msg_is_readable;
132 }
133 
134 static void vsock_bpf_check_needs_rebuild(struct proto *ops)
135 {
136 	/* Paired with the smp_store_release() below. */
137 	if (unlikely(ops != smp_load_acquire(&vsock_prot_saved))) {
138 		spin_lock_bh(&vsock_prot_lock);
139 		if (likely(ops != vsock_prot_saved)) {
140 			vsock_bpf_rebuild_protos(&vsock_bpf_prot, ops);
141 			/* Make sure proto function pointers are updated before publishing the
142 			 * pointer to the struct.
143 			 */
144 			smp_store_release(&vsock_prot_saved, ops);
145 		}
146 		spin_unlock_bh(&vsock_prot_lock);
147 	}
148 }
149 
150 int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
151 {
152 	struct vsock_sock *vsk;
153 
154 	if (restore) {
155 		sk->sk_write_space = psock->saved_write_space;
156 		sock_replace_proto(sk, psock->sk_proto);
157 		return 0;
158 	}
159 
160 	vsk = vsock_sk(sk);
161 	if (!vsk->transport)
162 		return -ENODEV;
163 
164 	if (!vsk->transport->read_skb)
165 		return -EOPNOTSUPP;
166 
167 	vsock_bpf_check_needs_rebuild(psock->sk_proto);
168 	sock_replace_proto(sk, &vsock_bpf_prot);
169 	return 0;
170 }
171 
/*
 * vsock_bpf_build_proto - initial build of vsock_bpf_prot from vsock_proto.
 *
 * Marked __init. vsock_bpf_check_needs_rebuild() rebuilds the table later
 * if a socket's base proto differs.
 */
void __init vsock_bpf_build_proto(void)
{
	vsock_bpf_rebuild_protos(&vsock_bpf_prot, &vsock_proto);
}
175 }
176