xref: /linux/net/vmw_vsock/vsock_bpf.c (revision 566ab427f827b0256d3e8ce0235d088e6a9c28bd)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2022 Bobby Eshleman <bobby.eshleman@bytedance.com>
3  *
4  * Based off of net/unix/unix_bpf.c
5  */
6 
7 #include <linux/bpf.h>
8 #include <linux/module.h>
9 #include <linux/skmsg.h>
10 #include <linux/socket.h>
11 #include <linux/wait.h>
12 #include <net/af_vsock.h>
13 #include <net/sock.h>
14 
/* Nonzero when @sk's receive queue, or either of @psock's ingress queues
 * (ingress_skb and ingress_msg), holds something.
 */
#define vsock_sk_has_data(sk, psock)					\
	(!skb_queue_empty(&(sk)->sk_receive_queue) ||			\
	 !skb_queue_empty(&(psock)->ingress_skb) ||			\
	 !list_empty(&(psock)->ingress_msg))
20 
21 static struct proto *vsock_prot_saved __read_mostly;
22 static DEFINE_SPINLOCK(vsock_prot_lock);
23 static struct proto vsock_bpf_prot;
24 
25 static bool vsock_has_data(struct sock *sk, struct sk_psock *psock)
26 {
27 	struct vsock_sock *vsk = vsock_sk(sk);
28 	s64 ret;
29 
30 	ret = vsock_connectible_has_data(vsk);
31 	if (ret > 0)
32 		return true;
33 
34 	return vsock_sk_has_data(sk, psock);
35 }
36 
37 static bool vsock_msg_wait_data(struct sock *sk, struct sk_psock *psock, long timeo)
38 {
39 	bool ret;
40 
41 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
42 
43 	if (sk->sk_shutdown & RCV_SHUTDOWN)
44 		return true;
45 
46 	if (!timeo)
47 		return false;
48 
49 	add_wait_queue(sk_sleep(sk), &wait);
50 	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
51 	ret = vsock_has_data(sk, psock);
52 	if (!ret) {
53 		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
54 		ret = vsock_has_data(sk, psock);
55 	}
56 	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
57 	remove_wait_queue(sk_sleep(sk), &wait);
58 	return ret;
59 }
60 
61 static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags)
62 {
63 	struct socket *sock = sk->sk_socket;
64 	int err;
65 
66 	if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
67 		err = __vsock_connectible_recvmsg(sock, msg, len, flags);
68 	else if (sk->sk_type == SOCK_DGRAM)
69 		err = __vsock_dgram_recvmsg(sock, msg, len, flags);
70 	else
71 		err = -EPROTOTYPE;
72 
73 	return err;
74 }
75 
76 static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
77 			     size_t len, int flags, int *addr_len)
78 {
79 	struct sk_psock *psock;
80 	int copied;
81 
82 	psock = sk_psock_get(sk);
83 	if (unlikely(!psock))
84 		return __vsock_recvmsg(sk, msg, len, flags);
85 
86 	lock_sock(sk);
87 	if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock)) {
88 		release_sock(sk);
89 		sk_psock_put(sk, psock);
90 		return __vsock_recvmsg(sk, msg, len, flags);
91 	}
92 
93 	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
94 	while (copied == 0) {
95 		long timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
96 
97 		if (!vsock_msg_wait_data(sk, psock, timeo)) {
98 			copied = -EAGAIN;
99 			break;
100 		}
101 
102 		if (sk_psock_queue_empty(psock)) {
103 			release_sock(sk);
104 			sk_psock_put(sk, psock);
105 			return __vsock_recvmsg(sk, msg, len, flags);
106 		}
107 
108 		copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
109 	}
110 
111 	release_sock(sk);
112 	sk_psock_put(sk, psock);
113 
114 	return copied;
115 }
116 
117 /* Copy of original proto with updated sock_map methods */
118 static struct proto vsock_bpf_prot = {
119 	.close = sock_map_close,
120 	.recvmsg = vsock_bpf_recvmsg,
121 	.sock_is_readable = sk_msg_is_readable,
122 	.unhash = sock_map_unhash,
123 };
124 
125 static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
126 {
127 	*prot        = *base;
128 	prot->close  = sock_map_close;
129 	prot->recvmsg = vsock_bpf_recvmsg;
130 	prot->sock_is_readable = sk_msg_is_readable;
131 }
132 
133 static void vsock_bpf_check_needs_rebuild(struct proto *ops)
134 {
135 	/* Paired with the smp_store_release() below. */
136 	if (unlikely(ops != smp_load_acquire(&vsock_prot_saved))) {
137 		spin_lock_bh(&vsock_prot_lock);
138 		if (likely(ops != vsock_prot_saved)) {
139 			vsock_bpf_rebuild_protos(&vsock_bpf_prot, ops);
140 			/* Make sure proto function pointers are updated before publishing the
141 			 * pointer to the struct.
142 			 */
143 			smp_store_release(&vsock_prot_saved, ops);
144 		}
145 		spin_unlock_bh(&vsock_prot_lock);
146 	}
147 }
148 
149 int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
150 {
151 	struct vsock_sock *vsk;
152 
153 	if (restore) {
154 		sk->sk_write_space = psock->saved_write_space;
155 		sock_replace_proto(sk, psock->sk_proto);
156 		return 0;
157 	}
158 
159 	vsk = vsock_sk(sk);
160 	if (!vsk->transport)
161 		return -ENODEV;
162 
163 	if (!vsk->transport->read_skb)
164 		return -EOPNOTSUPP;
165 
166 	vsock_bpf_check_needs_rebuild(psock->sk_proto);
167 	sock_replace_proto(sk, &vsock_bpf_prot);
168 	return 0;
169 }
170 
171 void __init vsock_bpf_build_proto(void)
172 {
173 	vsock_bpf_rebuild_protos(&vsock_bpf_prot, &vsock_proto);
174 }
175