// SPDX-License-Identifier: GPL-2.0-only
/*
 * VMware vSockets Driver
 *
 * Copyright (C) 2007-2013 VMware, Inc. All rights reserved.
 */

/* Implementation notes:
 *
 * - There are two kinds of sockets: those created by user action (such as
 * calling socket(2)) and those created by incoming connection request packets.
 *
 * - There are two "global" tables, one for bound sockets (sockets that have
 * specified an address that they are responsible for) and one for connected
 * sockets (sockets that have established a connection with another socket).
 * These tables are "global" in that all sockets on the system are placed
 * within them.  Note, though, that the bound table contains an extra entry
 * for a list of unbound sockets, and SOCK_DGRAM sockets will always remain in
 * that list.  The bound table is used solely for lookup of sockets when
 * packets are received, and that's not necessary for SOCK_DGRAM sockets since
 * we create a datagram handle for each and need not perform a lookup.
 * Keeping SOCK_DGRAM sockets out of the bound hash buckets reduces the chance
 * of collisions when looking for SOCK_STREAM sockets and prevents us from
 * having to check the socket type in the hash table lookups.
 *
 * - Sockets created by user action will either be "client" sockets that
 * initiate a connection or "server" sockets that listen for connections; we
 * do not support simultaneous connects (two "client" sockets connecting to
 * each other).
 *
 * - "Server" sockets are referred to as listener sockets throughout this
 * implementation because they are in the TCP_LISTEN state.  When a
 * connection request is received (the second kind of socket mentioned above),
 * we create a new socket and refer to it as a pending socket.  These pending
 * sockets are placed on the pending connection list of the listener socket.
 * When future packets are received for the address the listener socket is
 * bound to, we check if the source of the packet is one that has an existing
 * pending connection.  If it is, we process the packet for the pending
 * socket.  When that socket reaches the connected state, it is removed
 * from the listener socket's pending list and enqueued in the listener
 * socket's accept queue.  Callers of accept(2) will accept connected sockets
 * from the listener socket's accept queue.  If the socket cannot be accepted
 * for some reason then it is marked rejected.  Once the connection is
 * accepted, it is owned by the user process and the responsibility for cleanup
 * falls with that user process.
 *
 * - It is possible that these pending sockets will never reach the connected
 * state; in fact, we may never receive another packet after the connection
 * request.  Because of this, we must schedule a cleanup function to run in the
 * future, after some amount of time passes where a connection should have been
 * established.  This function ensures that the socket is off all lists so it
 * cannot be retrieved, then drops all references to the socket so it is cleaned
 * up (sock_put() -> sk_free() -> our sk_destruct implementation).  Note this
 * function will also clean up rejected sockets, those that reach the connected
 * state but leave it before they have been accepted.
 *
 * - Lock ordering for pending or accept queue sockets is:
 *
 *     lock_sock(listener);
 *     lock_sock_nested(pending, SINGLE_DEPTH_NESTING);
 *
 * Using explicit nested locking keeps lockdep happy since normally only one
 * lock of a given class may be taken at a time.
 *
 * - Sockets created by user action will be cleaned up when the user process
 * calls close(2), causing our release implementation to be called. Our release
 * implementation will perform some cleanup then drop the last reference so our
 * sk_destruct implementation is invoked.  Our sk_destruct implementation will
 * perform additional cleanup that's common for both types of sockets.
 *
 * - A socket's reference count is what ensures that the structure won't be
 * freed.  Each entry in a list (such as the "global" bound and connected tables
 * and the listener socket's pending list and connected queue) ensures a
 * reference.  When we defer work until process context and pass a socket as our
 * argument, we must ensure the reference count is increased to ensure the
 * socket isn't freed before the function is run; the deferred function will
 * then drop the reference.
 *
 * - sk->sk_state uses the TCP state constants because they are widely used by
 * other address families and exposed to userspace tools like ss(8):
 *
 *   TCP_CLOSE - unconnected
 *   TCP_SYN_SENT - connecting
 *   TCP_ESTABLISHED - connected
 *   TCP_CLOSING - disconnecting
 *   TCP_LISTEN - listening
 *
 * - Namespaces in vsock support two different modes: "local" and "global".
 *   Each mode defines how the namespace interacts with CIDs.
 *   Each namespace exposes two sysctl files:
 *
 *   - /proc/sys/net/vsock/ns_mode (read-only) reports the current namespace's
 *     mode, which is set at namespace creation and immutable thereafter.
 *   - /proc/sys/net/vsock/child_ns_mode (writable) controls what mode future
 *     child namespaces will inherit when created. The default is "global".
 *
 *   Changing child_ns_mode only affects newly created namespaces, not the
 *   current namespace or existing children. At namespace creation, ns_mode
 *   is inherited from the parent's child_ns_mode.
 *
 *   The init_net mode is "global" and cannot be modified.
 *
 *   The modes affect the allocation and accessibility of CIDs as follows:
 *
 *   - global - access and allocation are all system-wide
 *      - all CID allocations in global namespaces draw from the same
 *        system-wide pool.
 *      - if one global namespace has already allocated some CID, another
 *        global namespace will not be able to allocate the same CID.
 *      - global mode AF_VSOCK sockets can reach any VM or socket in any global
 *        namespace; they are not confined to their own namespace.
 *      - AF_VSOCK sockets in a global mode namespace cannot reach VMs or
 *        sockets in any local mode namespace.
 *   - local - access and allocation are contained within the namespace
 *     - CID allocation draws from a private pool local to the namespace,
 *       and does not affect the CIDs available for allocation in any other
 *       namespace (global or local).
 *     - CIDs of VMs in a local namespace do not collide with CIDs in any
 *       other local namespace or any global namespace. For example, if a VM
 *       in a local mode namespace is given CID 10, then CID 10 is still
 *       available for allocation in any other namespace, but not in the same
 *       namespace.
 *     - AF_VSOCK sockets in a local mode namespace can connect only to VMs or
 *       other sockets within their own namespace.
 *     - sockets bound to VMADDR_CID_ANY in local namespaces will never resolve
 *       to any transport that is not compatible with local mode. There is no
 *       error that propagates to the user (as there is for connection attempts)
 *       because it is possible for some packet to reach this socket from
 *       a different transport that *does* support local mode. For
 *       example, virtio-vsock may not support local mode, but the socket
 *       may still accept a connection from vhost-vsock which does.
 */
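
/* Illustrative userspace sketch (not part of this file): a minimal "client"
 * socket as described in the notes above, assuming a peer listening on
 * CID VMADDR_CID_HOST and port 1234 (hypothetical values):
 *
 *	#include <sys/socket.h>
 *	#include <linux/vm_sockets.h>
 *
 *	int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 *	struct sockaddr_vm sa = {
 *		.svm_family = AF_VSOCK,
 *		.svm_cid = VMADDR_CID_HOST,
 *		.svm_port = 1234,
 *	};
 *
 *	connect(fd, (struct sockaddr *)&sa, sizeof(sa));
 */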

#include <linux/compat.h>
#include <linux/types.h>
#include <linux/bitops.h>
#include <linux/cred.h>
#include <linux/errqueue.h>
#include <linux/init.h>
#include <linux/io.h>
#include <linux/kernel.h>
#include <linux/sched/signal.h>
#include <linux/kmod.h>
#include <linux/list.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/net.h>
#include <linux/proc_fs.h>
#include <linux/poll.h>
#include <linux/random.h>
#include <linux/skbuff.h>
#include <linux/smp.h>
#include <linux/socket.h>
#include <linux/stddef.h>
#include <linux/sysctl.h>
#include <linux/unistd.h>
#include <linux/wait.h>
#include <linux/workqueue.h>
#include <net/sock.h>
#include <net/af_vsock.h>
#include <net/netns/vsock.h>
#include <uapi/linux/vm_sockets.h>
#include <uapi/asm-generic/ioctls.h>

#define VSOCK_NET_MODE_STR_GLOBAL "global"
#define VSOCK_NET_MODE_STR_LOCAL "local"

/* 6 chars for "global", 1 for null-terminator, and 1 more for '\n'.
 * The newline is added by proc_dostring() for read operations.
 */
#define VSOCK_NET_MODE_STR_MAX 8
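
/* Illustrative userspace sketch (hypothetical code; the paths come from the
 * namespace notes above): querying this namespace's mode and selecting
 * "local" mode for namespaces created afterwards:
 *
 *	char mode[8];
 *	int fd = open("/proc/sys/net/vsock/ns_mode", O_RDONLY);
 *
 *	read(fd, mode, sizeof(mode) - 1);
 *	close(fd);
 *
 *	fd = open("/proc/sys/net/vsock/child_ns_mode", O_WRONLY);
 *	write(fd, "local", strlen("local"));
 *	close(fd);
 */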

static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr);
static void vsock_sk_destruct(struct sock *sk);
static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
static void vsock_close(struct sock *sk, long timeout);

/* Protocol family. */
struct proto vsock_proto = {
	.name = "AF_VSOCK",
	.owner = THIS_MODULE,
	.obj_size = sizeof(struct vsock_sock),
	.close = vsock_close,
#ifdef CONFIG_BPF_SYSCALL
	.psock_update_sk_prot = vsock_bpf_update_proto,
#endif
};

/* The default peer timeout indicates how long we will wait for a peer response
 * to a control message.
 */
#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)

#define VSOCK_DEFAULT_BUFFER_SIZE     (1024 * 256)
#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256)
#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128

/* Transport used for host->guest communication */
static const struct vsock_transport *transport_h2g;
/* Transport used for guest->host communication */
static const struct vsock_transport *transport_g2h;
/* Transport used for DGRAM communication */
static const struct vsock_transport *transport_dgram;
/* Transport used for local communication */
static const struct vsock_transport *transport_local;
static DEFINE_MUTEX(vsock_register_mutex);

/**** UTILS ****/

/* Each bound VSocket is stored in the bind hash table and each connected
 * VSocket is stored in the connected hash table.
 *
 * Unbound sockets are all put on the same list attached to the end of the hash
 * table (vsock_unbound_sockets).  Bound sockets are added to the hash table in
 * the bucket that their local address hashes to (vsock_bound_sockets(addr)
 * represents the list that addr hashes to).
 *
 * Specifically, we initialize the vsock_bind_table array to a size of
 * VSOCK_HASH_SIZE + 1 so that vsock_bind_table[0] through
 * vsock_bind_table[VSOCK_HASH_SIZE - 1] are for bound sockets and
 * vsock_bind_table[VSOCK_HASH_SIZE] is for unbound sockets.  The hash function
 * mods with VSOCK_HASH_SIZE to ensure this.
 */
#define MAX_PORT_RETRIES        24

#define VSOCK_HASH(addr)        ((addr)->svm_port % VSOCK_HASH_SIZE)
#define vsock_bound_sockets(addr) (&vsock_bind_table[VSOCK_HASH(addr)])
#define vsock_unbound_sockets     (&vsock_bind_table[VSOCK_HASH_SIZE])

/* XXX This can probably be implemented in a better way. */
#define VSOCK_CONN_HASH(src, dst)				\
	(((src)->svm_cid ^ (dst)->svm_port) % VSOCK_HASH_SIZE)
#define vsock_connected_sockets(src, dst)		\
	(&vsock_connected_table[VSOCK_CONN_HASH(src, dst)])
#define vsock_connected_sockets_vsk(vsk)				\
	vsock_connected_sockets(&(vsk)->remote_addr, &(vsk)->local_addr)
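
/* For example, with the macros above, a socket bound to local port 1027
 * lives in vsock_bind_table[1027 % VSOCK_HASH_SIZE], and a connection whose
 * peer CID is 3 and whose local port is 1027 is looked up in
 * vsock_connected_table[(3 ^ 1027) % VSOCK_HASH_SIZE] (values hypothetical).
 */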

struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1];
EXPORT_SYMBOL_GPL(vsock_bind_table);
struct list_head vsock_connected_table[VSOCK_HASH_SIZE];
EXPORT_SYMBOL_GPL(vsock_connected_table);
DEFINE_SPINLOCK(vsock_table_lock);
EXPORT_SYMBOL_GPL(vsock_table_lock);

/* Autobind this socket to the local address if necessary. */
static int vsock_auto_bind(struct vsock_sock *vsk)
{
	struct sock *sk = sk_vsock(vsk);
	struct sockaddr_vm local_addr;

	if (vsock_addr_bound(&vsk->local_addr))
		return 0;
	vsock_addr_init(&local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
	return __vsock_bind(sk, &local_addr);
}
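
/* For example (illustrative): a connect() or sendmsg() on a socket that was
 * never bind()-ed lands here first, so the socket gets a local CID and, for
 * connectible sockets, an ephemeral port chosen above LAST_RESERVED_PORT by
 * __vsock_bind_connectible() before any packet goes out.
 */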

static void vsock_init_tables(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(vsock_bind_table); i++)
		INIT_LIST_HEAD(&vsock_bind_table[i]);

	for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++)
		INIT_LIST_HEAD(&vsock_connected_table[i]);
}

static void __vsock_insert_bound(struct list_head *list,
				 struct vsock_sock *vsk)
{
	sock_hold(&vsk->sk);
	list_add(&vsk->bound_table, list);
}

static void __vsock_insert_connected(struct list_head *list,
				     struct vsock_sock *vsk)
{
	sock_hold(&vsk->sk);
	list_add(&vsk->connected_table, list);
}

static void __vsock_remove_bound(struct vsock_sock *vsk)
{
	list_del_init(&vsk->bound_table);
	sock_put(&vsk->sk);
}

static void __vsock_remove_connected(struct vsock_sock *vsk)
{
	list_del_init(&vsk->connected_table);
	sock_put(&vsk->sk);
}

static struct sock *__vsock_find_bound_socket_net(struct sockaddr_vm *addr,
						  struct net *net)
{
	struct vsock_sock *vsk;

	list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
		struct sock *sk = sk_vsock(vsk);

		if (vsock_addr_equals_addr(addr, &vsk->local_addr) &&
		    vsock_net_check_mode(sock_net(sk), net))
			return sk;

		if (addr->svm_port == vsk->local_addr.svm_port &&
		    (vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
		     addr->svm_cid == VMADDR_CID_ANY) &&
		     vsock_net_check_mode(sock_net(sk), net))
			return sk;
	}

	return NULL;
}

static struct sock *
__vsock_find_connected_socket_net(struct sockaddr_vm *src,
				  struct sockaddr_vm *dst, struct net *net)
{
	struct vsock_sock *vsk;

	list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
			    connected_table) {
		struct sock *sk = sk_vsock(vsk);

		if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
		    dst->svm_port == vsk->local_addr.svm_port &&
		    vsock_net_check_mode(sock_net(sk), net)) {
			return sk;
		}
	}

	return NULL;
}

static void vsock_insert_unbound(struct vsock_sock *vsk)
{
	spin_lock_bh(&vsock_table_lock);
	__vsock_insert_bound(vsock_unbound_sockets, vsk);
	spin_unlock_bh(&vsock_table_lock);
}

void vsock_insert_connected(struct vsock_sock *vsk)
{
	struct list_head *list = vsock_connected_sockets(
		&vsk->remote_addr, &vsk->local_addr);

	spin_lock_bh(&vsock_table_lock);
	__vsock_insert_connected(list, vsk);
	spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_insert_connected);

void vsock_remove_bound(struct vsock_sock *vsk)
{
	spin_lock_bh(&vsock_table_lock);
	if (__vsock_in_bound_table(vsk))
		__vsock_remove_bound(vsk);
	spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_remove_bound);

void vsock_remove_connected(struct vsock_sock *vsk)
{
	spin_lock_bh(&vsock_table_lock);
	if (__vsock_in_connected_table(vsk))
		__vsock_remove_connected(vsk);
	spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_remove_connected);

/* Find a bound socket, filtering by namespace and namespace mode.
 *
 * Use this in transports that are namespace-aware and can provide the
 * network namespace context.
 */
struct sock *vsock_find_bound_socket_net(struct sockaddr_vm *addr,
					 struct net *net)
{
	struct sock *sk;

	spin_lock_bh(&vsock_table_lock);
	sk = __vsock_find_bound_socket_net(addr, net);
	if (sk)
		sock_hold(sk);

	spin_unlock_bh(&vsock_table_lock);

	return sk;
}
EXPORT_SYMBOL_GPL(vsock_find_bound_socket_net);

/* Find a bound socket without namespace filtering.
 *
 * Use this in transports that lack namespace context. All sockets are
 * treated as if in global mode.
 */
struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr)
{
	return vsock_find_bound_socket_net(addr, NULL);
}
EXPORT_SYMBOL_GPL(vsock_find_bound_socket);

/* Find a connected socket, filtering by namespace and namespace mode.
 *
 * Use this in transports that are namespace-aware and can provide the
 * network namespace context.
 */
struct sock *vsock_find_connected_socket_net(struct sockaddr_vm *src,
					     struct sockaddr_vm *dst,
					     struct net *net)
{
	struct sock *sk;

	spin_lock_bh(&vsock_table_lock);
	sk = __vsock_find_connected_socket_net(src, dst, net);
	if (sk)
		sock_hold(sk);

	spin_unlock_bh(&vsock_table_lock);

	return sk;
}
EXPORT_SYMBOL_GPL(vsock_find_connected_socket_net);

/* Find a connected socket without namespace filtering.
 *
 * Use this in transports that lack namespace context. All sockets are
 * treated as if in global mode.
 */
struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
					 struct sockaddr_vm *dst)
{
	return vsock_find_connected_socket_net(src, dst, NULL);
}
EXPORT_SYMBOL_GPL(vsock_find_connected_socket);

void vsock_remove_sock(struct vsock_sock *vsk)
{
	/* Transport reassignment must not remove the binding. */
	if (sock_flag(sk_vsock(vsk), SOCK_DEAD))
		vsock_remove_bound(vsk);

	vsock_remove_connected(vsk);
}
EXPORT_SYMBOL_GPL(vsock_remove_sock);

void vsock_for_each_connected_socket(struct vsock_transport *transport,
				     void (*fn)(struct sock *sk))
{
	int i;

	spin_lock_bh(&vsock_table_lock);

	for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) {
		struct vsock_sock *vsk;

		list_for_each_entry(vsk, &vsock_connected_table[i],
				    connected_table) {
			if (vsk->transport != transport)
				continue;

			fn(sk_vsock(vsk));
		}
	}

	spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_for_each_connected_socket);

void vsock_add_pending(struct sock *listener, struct sock *pending)
{
	struct vsock_sock *vlistener;
	struct vsock_sock *vpending;

	vlistener = vsock_sk(listener);
	vpending = vsock_sk(pending);

	sock_hold(pending);
	sock_hold(listener);
	list_add_tail(&vpending->pending_links, &vlistener->pending_links);
}
EXPORT_SYMBOL_GPL(vsock_add_pending);

void vsock_remove_pending(struct sock *listener, struct sock *pending)
{
	struct vsock_sock *vpending = vsock_sk(pending);

	list_del_init(&vpending->pending_links);
	sock_put(listener);
	sock_put(pending);
}
EXPORT_SYMBOL_GPL(vsock_remove_pending);

void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
{
	struct vsock_sock *vlistener;
	struct vsock_sock *vconnected;

	vlistener = vsock_sk(listener);
	vconnected = vsock_sk(connected);

	sock_hold(connected);
	sock_hold(listener);
	list_add_tail(&vconnected->accept_queue, &vlistener->accept_queue);
}
EXPORT_SYMBOL_GPL(vsock_enqueue_accept);

static bool vsock_use_local_transport(unsigned int remote_cid)
{
	lockdep_assert_held(&vsock_register_mutex);

	if (!transport_local)
		return false;

	if (remote_cid == VMADDR_CID_LOCAL)
		return true;

	if (transport_g2h)
		return remote_cid == transport_g2h->get_local_cid();
	else
		return remote_cid == VMADDR_CID_HOST;
}

static void vsock_deassign_transport(struct vsock_sock *vsk)
{
	if (!vsk->transport)
		return;

	vsk->transport->destruct(vsk);
	module_put(vsk->transport->module);
	vsk->transport = NULL;
}

/* Assign a transport to a socket and call the .init transport callback.
 *
 * Note: for connection oriented sockets this must be called when
 * vsk->remote_addr is set (e.g. during connect() or when a connection request
 * on a listener socket is received).
 * The vsk->remote_addr is used to decide which transport to use:
 *  - remote CID == VMADDR_CID_LOCAL, or g2h->local_cid, or VMADDR_CID_HOST if
 *    g2h is not loaded, will use the local transport;
 *  - remote CID <= VMADDR_CID_HOST, or h2g is not loaded, or the remote flags
 *    field includes the VMADDR_FLAG_TO_HOST flag value, will use the
 *    guest->host transport;
 *  - remote CID > VMADDR_CID_HOST will use the host->guest transport;
 */
int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
{
	const struct vsock_transport *new_transport;
	struct sock *sk = sk_vsock(vsk);
	unsigned int remote_cid = vsk->remote_addr.svm_cid;
	__u8 remote_flags;
	int ret;

	/* If a packet arrives with both the source and destination CIDs higher
	 * than VMADDR_CID_HOST, then a vsock channel that forwards all packets
	 * to the host must be established; the host then forwards the packets
	 * on to the destination guest.
	 *
	 * The flag is set on the (listen) receive path (psk is not NULL). On
	 * the connect path the flag can be set by the user space application.
	 */
	if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST &&
	    vsk->remote_addr.svm_cid > VMADDR_CID_HOST)
		vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;

	remote_flags = vsk->remote_addr.svm_flags;

	mutex_lock(&vsock_register_mutex);

	switch (sk->sk_type) {
	case SOCK_DGRAM:
		new_transport = transport_dgram;
		break;
	case SOCK_STREAM:
	case SOCK_SEQPACKET:
		if (vsock_use_local_transport(remote_cid))
			new_transport = transport_local;
		else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g ||
			 (remote_flags & VMADDR_FLAG_TO_HOST))
			new_transport = transport_g2h;
		else
			new_transport = transport_h2g;
		break;
	default:
		ret = -ESOCKTNOSUPPORT;
		goto err;
	}

	if (vsk->transport && vsk->transport == new_transport) {
		ret = 0;
		goto err;
	}

	/* We increase the module refcnt to prevent the transport unloading
	 * while there are open sockets assigned to it.
	 */
	if (!new_transport || !try_module_get(new_transport->module)) {
		ret = -ENODEV;
		goto err;
	}

	/* It's safe to release the mutex after a successful try_module_get().
	 * Whichever transport `new_transport` points at, it won't go away until
	 * the last module_put() below or in vsock_deassign_transport().
	 */
	mutex_unlock(&vsock_register_mutex);

	if (vsk->transport) {
		/* transport->release() must be called with sock lock acquired.
		 * This path can only be taken during vsock_connect(), where we
		 * have already held the sock lock. In the other cases, this
		 * function is called on a new socket which is not assigned to
		 * any transport.
		 */
		vsk->transport->release(vsk);
		vsock_deassign_transport(vsk);

		/* transport's release() and destruct() can touch some socket
		 * state, since we are reassigning the socket to a new transport
		 * during vsock_connect(), let's reset these fields to have a
		 * clean state.
		 */
		sock_reset_flag(sk, SOCK_DONE);
		sk->sk_state = TCP_CLOSE;
		vsk->peer_shutdown = 0;
	}

	if (sk->sk_type == SOCK_SEQPACKET) {
		if (!new_transport->seqpacket_allow ||
		    !new_transport->seqpacket_allow(vsk, remote_cid)) {
			module_put(new_transport->module);
			return -ESOCKTNOSUPPORT;
		}
	}

	ret = new_transport->init(vsk, psk);
	if (ret) {
		module_put(new_transport->module);
		return ret;
	}

	vsk->transport = new_transport;

	return 0;
err:
	mutex_unlock(&vsock_register_mutex);
	return ret;
}
EXPORT_SYMBOL_GPL(vsock_assign_transport);
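
/* Illustrative userspace sketch of the selection rules above (CID and port
 * values are hypothetical): a connect() to a guest CID normally picks the
 * host->guest transport, while setting VMADDR_FLAG_TO_HOST forces the
 * guest->host transport so the host forwards the packets:
 *
 *	struct sockaddr_vm sa = {
 *		.svm_family = AF_VSOCK,
 *		.svm_cid = 42,
 *		.svm_port = 1234,
 *		.svm_flags = VMADDR_FLAG_TO_HOST,
 *	};
 *
 *	connect(fd, (struct sockaddr *)&sa, sizeof(sa));
 */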

/*
 * Provide safe access to static transport_{h2g,g2h,dgram,local} callbacks.
 * Otherwise we may race with module removal. Do not use on `vsk->transport`.
 */
static u32 vsock_registered_transport_cid(const struct vsock_transport **transport)
{
	u32 cid = VMADDR_CID_ANY;

	mutex_lock(&vsock_register_mutex);
	if (*transport)
		cid = (*transport)->get_local_cid();
	mutex_unlock(&vsock_register_mutex);

	return cid;
}

bool vsock_find_cid(unsigned int cid)
{
	if (cid == vsock_registered_transport_cid(&transport_g2h))
		return true;

	if (transport_h2g && cid == VMADDR_CID_HOST)
		return true;

	if (transport_local && cid == VMADDR_CID_LOCAL)
		return true;

	return false;
}
EXPORT_SYMBOL_GPL(vsock_find_cid);

static struct sock *vsock_dequeue_accept(struct sock *listener)
{
	struct vsock_sock *vlistener;
	struct vsock_sock *vconnected;

	vlistener = vsock_sk(listener);

	if (list_empty(&vlistener->accept_queue))
		return NULL;

	vconnected = list_entry(vlistener->accept_queue.next,
				struct vsock_sock, accept_queue);

	list_del_init(&vconnected->accept_queue);
	sock_put(listener);
	/* The caller will need a reference on the connected socket so we let
	 * it call sock_put().
	 */

	return sk_vsock(vconnected);
}

static bool vsock_is_accept_queue_empty(struct sock *sk)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	return list_empty(&vsk->accept_queue);
}

static bool vsock_is_pending(struct sock *sk)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	return !list_empty(&vsk->pending_links);
}

static int vsock_send_shutdown(struct sock *sk, int mode)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (!vsk->transport)
		return -ENODEV;

	return vsk->transport->shutdown(vsk, mode);
}

static void vsock_pending_work(struct work_struct *work)
{
	struct sock *sk;
	struct sock *listener;
	struct vsock_sock *vsk;
	bool cleanup;

	vsk = container_of(work, struct vsock_sock, pending_work.work);
	sk = sk_vsock(vsk);
	listener = vsk->listener;
	cleanup = true;

	lock_sock(listener);
	lock_sock_nested(sk, SINGLE_DEPTH_NESTING);

	if (vsock_is_pending(sk)) {
		vsock_remove_pending(listener, sk);

		sk_acceptq_removed(listener);
	} else if (!vsk->rejected) {
		/* We are not on the pending list and accept() did not reject
		 * us, so we must have been accepted by our user process.  We
		 * just need to drop our references to the sockets and be on
		 * our way.
		 */
		cleanup = false;
		goto out;
	}

	/* We need to remove ourselves from the global connected sockets list so
	 * incoming packets can't find this socket, and to reduce the reference
	 * count.
	 */
	vsock_remove_connected(vsk);

	sk->sk_state = TCP_CLOSE;

out:
	release_sock(sk);
	release_sock(listener);
	if (cleanup)
		sock_put(sk);

	sock_put(sk);
	sock_put(listener);
}

/**** SOCKET OPERATIONS ****/

static int __vsock_bind_connectible(struct vsock_sock *vsk,
				    struct sockaddr_vm *addr)
{
	struct net *net = sock_net(sk_vsock(vsk));
	struct sockaddr_vm new_addr;

	if (!net->vsock.port)
		net->vsock.port = get_random_u32_above(LAST_RESERVED_PORT);

	vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port);

	if (addr->svm_port == VMADDR_PORT_ANY) {
		bool found = false;
		unsigned int i;

		for (i = 0; i < MAX_PORT_RETRIES; i++) {
			if (net->vsock.port == VMADDR_PORT_ANY ||
			    net->vsock.port <= LAST_RESERVED_PORT)
				net->vsock.port = LAST_RESERVED_PORT + 1;

			new_addr.svm_port = net->vsock.port++;

			if (!__vsock_find_bound_socket_net(&new_addr, net)) {
				found = true;
				break;
			}
		}

		if (!found)
			return -EADDRNOTAVAIL;
	} else {
		/* If port is in reserved range, ensure caller
		 * has necessary privileges.
		 */
		if (addr->svm_port <= LAST_RESERVED_PORT &&
		    !capable(CAP_NET_BIND_SERVICE)) {
			return -EACCES;
		}

		if (__vsock_find_bound_socket_net(&new_addr, net))
			return -EADDRINUSE;
	}

	vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port);

	/* Remove connection oriented sockets from the unbound list and add them
	 * to the hash table for easy lookup by their address.  The unbound list
	 * is simply an extra entry at the end of the hash table, a trick used
	 * by AF_UNIX.
	 */
	__vsock_remove_bound(vsk);
	__vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk);

	return 0;
}

static int __vsock_bind_dgram(struct vsock_sock *vsk,
			      struct sockaddr_vm *addr)
{
	return vsk->transport->dgram_bind(vsk, addr);
}

static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int retval;

	/* First ensure this socket isn't already bound. */
	if (vsock_addr_bound(&vsk->local_addr))
		return -EINVAL;

	/* Now bind to the provided address or select appropriate values if
	 * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY).  Note that
	 * like AF_INET prevents binding to a non-local IP address (in most
	 * cases), we only allow binding to a local CID.
	 */
	if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid))
		return -EADDRNOTAVAIL;

	switch (sk->sk_socket->type) {
	case SOCK_STREAM:
	case SOCK_SEQPACKET:
		spin_lock_bh(&vsock_table_lock);
		retval = __vsock_bind_connectible(vsk, addr);
		spin_unlock_bh(&vsock_table_lock);
		break;

	case SOCK_DGRAM:
		retval = __vsock_bind_dgram(vsk, addr);
		break;

	default:
		retval = -EINVAL;
		break;
	}

	return retval;
}

static void vsock_connect_timeout(struct work_struct *work);

static struct sock *__vsock_create(struct net *net,
				   struct socket *sock,
				   struct sock *parent,
				   gfp_t priority,
				   unsigned short type,
				   int kern)
{
	struct sock *sk;
	struct vsock_sock *psk;
	struct vsock_sock *vsk;

	sk = sk_alloc(net, AF_VSOCK, priority, &vsock_proto, kern);
	if (!sk)
		return NULL;

	sock_init_data(sock, sk);

	/* sk->sk_type is normally set in sock_init_data, but only if sock is
	 * non-NULL. We make sure that our sockets always have a type by
	 * setting it here if needed.
	 */
	if (!sock)
		sk->sk_type = type;

	vsk = vsock_sk(sk);
	vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
	vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);

	sk->sk_destruct = vsock_sk_destruct;
	sk->sk_backlog_rcv = vsock_queue_rcv_skb;
	sock_reset_flag(sk, SOCK_DONE);

	INIT_LIST_HEAD(&vsk->bound_table);
	INIT_LIST_HEAD(&vsk->connected_table);
	vsk->listener = NULL;
	INIT_LIST_HEAD(&vsk->pending_links);
	INIT_LIST_HEAD(&vsk->accept_queue);
	vsk->rejected = false;
	vsk->sent_request = false;
	vsk->ignore_connecting_rst = false;
	vsk->peer_shutdown = 0;
	INIT_DELAYED_WORK(&vsk->connect_work, vsock_connect_timeout);
	INIT_DELAYED_WORK(&vsk->pending_work, vsock_pending_work);

	psk = parent ? vsock_sk(parent) : NULL;
	if (parent) {
		vsk->trusted = psk->trusted;
		vsk->owner = get_cred(psk->owner);
		vsk->connect_timeout = psk->connect_timeout;
		vsk->buffer_size = psk->buffer_size;
		vsk->buffer_min_size = psk->buffer_min_size;
		vsk->buffer_max_size = psk->buffer_max_size;
		security_sk_clone(parent, sk);
	} else {
		vsk->trusted = ns_capable_noaudit(&init_user_ns, CAP_NET_ADMIN);
		vsk->owner = get_current_cred();
		vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT;
		vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE;
		vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE;
		vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE;
	}

	return sk;
}

static bool sock_type_connectible(u16 type)
{
	return (type == SOCK_STREAM) || (type == SOCK_SEQPACKET);
}

static void __vsock_release(struct sock *sk, int level)
{
	struct vsock_sock *vsk;
	struct sock *pending;

	vsk = vsock_sk(sk);
	pending = NULL;	/* Compiler warning. */

	/* When "level" is SINGLE_DEPTH_NESTING, use the nested
	 * version to avoid the warning "possible recursive locking
	 * detected". When "level" is 0, lock_sock_nested(sk, level)
	 * is the same as lock_sock(sk).
	 */
	lock_sock_nested(sk, level);

	/* Indicate to vsock_remove_sock() that the socket is being released and
	 * can be removed from the bound_table. This is unlike the transport
	 * reassignment case, where the socket must remain bound despite
	 * vsock_remove_sock() being called from the transport release()
	 * callback.
	 */
	sock_set_flag(sk, SOCK_DEAD);

	if (vsk->transport)
		vsk->transport->release(vsk);
	else if (sock_type_connectible(sk->sk_type))
		vsock_remove_sock(vsk);

	sock_orphan(sk);
	sk->sk_shutdown = SHUTDOWN_MASK;

	skb_queue_purge(&sk->sk_receive_queue);

	/* Clean up any sockets that never were accepted. */
	while ((pending = vsock_dequeue_accept(sk)) != NULL) {
		__vsock_release(pending, SINGLE_DEPTH_NESTING);
		sock_put(pending);
	}

	release_sock(sk);
	sock_put(sk);
}

static void vsock_sk_destruct(struct sock *sk)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	/* Flush MSG_ZEROCOPY leftovers. */
	__skb_queue_purge(&sk->sk_error_queue);

	vsock_deassign_transport(vsk);

	/* When clearing these addresses, there's no need to set the family and
	 * possibly register the address family with the kernel.
	 */
	vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
	vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);

	put_cred(vsk->owner);
}

static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
{
	int err;

	err = sock_queue_rcv_skb(sk, skb);
	if (err)
		kfree_skb(skb);

	return err;
}

struct sock *vsock_create_connected(struct sock *parent)
{
	return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL,
			      parent->sk_type, 0);
}
EXPORT_SYMBOL_GPL(vsock_create_connected);

s64 vsock_stream_has_data(struct vsock_sock *vsk)
{
	if (WARN_ON(!vsk->transport))
		return 0;

	return vsk->transport->stream_has_data(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_data);

s64 vsock_connectible_has_data(struct vsock_sock *vsk)
{
	struct sock *sk = sk_vsock(vsk);

	if (WARN_ON(!vsk->transport))
		return 0;

	if (sk->sk_type == SOCK_SEQPACKET)
		return vsk->transport->seqpacket_has_data(vsk);
	else
		return vsock_stream_has_data(vsk);
}
EXPORT_SYMBOL_GPL(vsock_connectible_has_data);

s64 vsock_stream_has_space(struct vsock_sock *vsk)
{
	if (WARN_ON(!vsk->transport))
		return 0;

	return vsk->transport->stream_has_space(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_space);

void vsock_data_ready(struct sock *sk)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (vsock_stream_has_data(vsk) >= sk->sk_rcvlowat ||
	    sock_flag(sk, SOCK_DONE))
		sk->sk_data_ready(sk);
}
EXPORT_SYMBOL_GPL(vsock_data_ready);

/* Dummy callback required by sockmap.
 * See unconditional call of saved_close() in sock_map_close().
 */
static void vsock_close(struct sock *sk, long timeout)
{
}

static int vsock_release(struct socket *sock)
{
	struct sock *sk = sock->sk;

	if (!sk)
		return 0;

	sk->sk_prot->close(sk, 0);
	__vsock_release(sk, 0);
	sock->sk = NULL;
	sock->state = SS_FREE;

	return 0;
}

static int
vsock_bind(struct socket *sock, struct sockaddr_unsized *addr, int addr_len)
{
	int err;
	struct sock *sk;
	struct sockaddr_vm *vm_addr;

	sk = sock->sk;

	if (vsock_addr_cast(addr, addr_len, &vm_addr) != 0)
		return -EINVAL;

	lock_sock(sk);
	err = __vsock_bind(sk, vm_addr);
	release_sock(sk);

	return err;
}

static int vsock_getname(struct socket *sock,
			 struct sockaddr *addr, int peer)
{
	int err;
	struct sock *sk;
	struct vsock_sock *vsk;
	struct sockaddr_vm *vm_addr;

	sk = sock->sk;
	vsk = vsock_sk(sk);
	err = 0;

	lock_sock(sk);

	if (peer) {
		if (sock->state != SS_CONNECTED) {
			err = -ENOTCONN;
			goto out;
		}
		vm_addr = &vsk->remote_addr;
	} else {
		vm_addr = &vsk->local_addr;
	}

	BUILD_BUG_ON(sizeof(*vm_addr) > sizeof(struct sockaddr_storage));
	memcpy(addr, vm_addr, sizeof(*vm_addr));
	err = sizeof(*vm_addr);

out:
	release_sock(sk);
	return err;
}

void vsock_linger(struct sock *sk)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	ssize_t (*unsent)(struct vsock_sock *vsk);
	struct vsock_sock *vsk = vsock_sk(sk);
	long timeout;

	if (!sock_flag(sk, SOCK_LINGER))
		return;

	timeout = sk->sk_lingertime;
	if (!timeout)
		return;

	/* Transports must implement `unsent_bytes` if they want to support
	 * SOCK_LINGER through `vsock_linger()` since we use it to check when
	 * the socket can be closed.
	 */
	unsent = vsk->transport->unsent_bytes;
	if (!unsent)
		return;

	add_wait_queue(sk_sleep(sk), &wait);

	do {
		if (sk_wait_event(sk, &timeout, unsent(vsk) == 0, &wait))
			break;
	} while (!signal_pending(current) && timeout);

	remove_wait_queue(sk_sleep(sk), &wait);
}
EXPORT_SYMBOL_GPL(vsock_linger);
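
/* Illustrative userspace sketch: enable a 5 second linger (value
 * hypothetical) so that, on a transport implementing unsent_bytes(),
 * close() waits as implemented above until the peer has consumed the data:
 *
 *	struct linger lg = {
 *		.l_onoff = 1,
 *		.l_linger = 5,
 *	};
 *
 *	setsockopt(fd, SOL_SOCKET, SO_LINGER, &lg, sizeof(lg));
 */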

static int vsock_shutdown(struct socket *sock, int mode)
{
	int err;
	struct sock *sk;

	/* User level uses SHUT_RD (0) and SHUT_WR (1), but the kernel uses
	 * RCV_SHUTDOWN (1) and SEND_SHUTDOWN (2), so we must increment mode
	 * here like the other address families do.  Note also that the
	 * increment makes SHUT_RDWR (2) into RCV_SHUTDOWN | SEND_SHUTDOWN (3),
	 * which is what we want.
	 */
	mode++;

	if ((mode & ~SHUTDOWN_MASK) || !mode)
		return -EINVAL;

	/* If this is a connection oriented socket and it is not connected then
	 * bail out immediately.  If it is a DGRAM socket then we must first
	 * kick the socket so that it wakes up from any sleeping calls, for
	 * example recv(), and then afterwards return the error.
	 */

	sk = sock->sk;

	lock_sock(sk);
	if (sock->state == SS_UNCONNECTED) {
		err = -ENOTCONN;
		if (sock_type_connectible(sk->sk_type))
			goto out;
	} else {
		sock->state = SS_DISCONNECTING;
		err = 0;
	}

	/* Receive and send shutdowns are treated alike. */
	mode = mode & (RCV_SHUTDOWN | SEND_SHUTDOWN);
	if (mode) {
		sk->sk_shutdown |= mode;
		sk->sk_state_change(sk);

		if (sock_type_connectible(sk->sk_type)) {
			sock_reset_flag(sk, SOCK_DONE);
			vsock_send_shutdown(sk, mode);
		}
	}

out:
	release_sock(sk);
	return err;
}

static __poll_t vsock_poll(struct file *file, struct socket *sock,
			   poll_table *wait)
{
	struct sock *sk;
	__poll_t mask;
	struct vsock_sock *vsk;

	sk = sock->sk;
	vsk = vsock_sk(sk);

	poll_wait(file, sk_sleep(sk), wait);
	mask = 0;

	if (sk->sk_err || !skb_queue_empty_lockless(&sk->sk_error_queue))
		/* Signify that there has been an error on this socket. */
		mask |= EPOLLERR;

	/* INET sockets treat local write shutdown and peer write shutdown as a
	 * case where EPOLLHUP is set.
	 */
	if ((sk->sk_shutdown == SHUTDOWN_MASK) ||
	    ((sk->sk_shutdown & SEND_SHUTDOWN) &&
	     (vsk->peer_shutdown & SEND_SHUTDOWN))) {
		mask |= EPOLLHUP;
	}

	if (sk->sk_shutdown & RCV_SHUTDOWN ||
	    vsk->peer_shutdown & SEND_SHUTDOWN) {
		mask |= EPOLLRDHUP;
	}

	if (sk_is_readable(sk))
		mask |= EPOLLIN | EPOLLRDNORM;

	if (sock->type == SOCK_DGRAM) {
		/* For datagram sockets we can read if there is something in
		 * the queue and write as long as the socket isn't shutdown for
		 * sending.
		 */
		if (!skb_queue_empty_lockless(&sk->sk_receive_queue) ||
		    (sk->sk_shutdown & RCV_SHUTDOWN)) {
			mask |= EPOLLIN | EPOLLRDNORM;
		}

		if (!(sk->sk_shutdown & SEND_SHUTDOWN))
			mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;

	} else if (sock_type_connectible(sk->sk_type)) {
		const struct vsock_transport *transport;

		lock_sock(sk);

		transport = vsk->transport;

		/* Listening sockets that have connections in their accept
		 * queue can be read.
		 */
		if (sk->sk_state == TCP_LISTEN
		    && !vsock_is_accept_queue_empty(sk))
			mask |= EPOLLIN | EPOLLRDNORM;

		/* If there is something in the queue then we can read. */
		if (transport && transport->stream_is_active(vsk) &&
		    !(sk->sk_shutdown & RCV_SHUTDOWN)) {
			bool data_ready_now = false;
			int target = sock_rcvlowat(sk, 0, INT_MAX);
			int ret = transport->notify_poll_in(
					vsk, target, &data_ready_now);
			if (ret < 0) {
				mask |= EPOLLERR;
			} else {
				if (data_ready_now)
					mask |= EPOLLIN | EPOLLRDNORM;
			}
		}

		/* Sockets whose connections have been closed, reset, or
		 * terminated should also be considered read, and we check the
		 * shutdown flag for that.
		 */
		if (sk->sk_shutdown & RCV_SHUTDOWN ||
		    vsk->peer_shutdown & SEND_SHUTDOWN) {
			mask |= EPOLLIN | EPOLLRDNORM;
		}

		/* Connected sockets that can produce data can be written. */
		if (transport && sk->sk_state == TCP_ESTABLISHED) {
			if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {
				bool space_avail_now = false;
				int ret = transport->notify_poll_out(
						vsk, 1, &space_avail_now);
				if (ret < 0) {
					mask |= EPOLLERR;
				} else {
					if (space_avail_now)
						/* Remove EPOLLWRBAND since INET
						 * sockets are not setting it.
						 */
						mask |= EPOLLOUT | EPOLLWRNORM;
				}
			}
		}

		/* Simulate INET socket poll behavior, which sets
		 * EPOLLOUT|EPOLLWRNORM when the peer is closed and there is
		 * nothing to read, but local send is not shut down.
		 */
		if (sk->sk_state == TCP_CLOSE || sk->sk_state == TCP_CLOSING) {
			if (!(sk->sk_shutdown & SEND_SHUTDOWN))
				mask |= EPOLLOUT | EPOLLWRNORM;
		}

		release_sock(sk);
	}

	return mask;
}

static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (WARN_ON_ONCE(!vsk->transport))
		return -ENODEV;

	return vsk->transport->read_skb(vsk, read_actor);
}

static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
			       size_t len)
{
	int err;
	struct sock *sk;
	struct vsock_sock *vsk;
	struct sockaddr_vm *remote_addr;
	const struct vsock_transport *transport;

	if (msg->msg_flags & MSG_OOB)
		return -EOPNOTSUPP;

	/* For now, MSG_DONTWAIT is always assumed... */
	err = 0;
	sk = sock->sk;
	vsk = vsock_sk(sk);

	lock_sock(sk);

	transport = vsk->transport;

	err = vsock_auto_bind(vsk);
	if (err)
		goto out;

	/* If the provided message contains an address, use that.  Otherwise
	 * fall back on the socket's remote handle (if it has been connected).
	 */
	if (msg->msg_name &&
	    vsock_addr_cast(msg->msg_name, msg->msg_namelen,
			    &remote_addr) == 0) {
		/* Ensure this address is of the right type and is a valid
		 * destination.
		 */

		if (remote_addr->svm_cid == VMADDR_CID_ANY)
			remote_addr->svm_cid = transport->get_local_cid();

		if (!vsock_addr_bound(remote_addr)) {
			err = -EINVAL;
			goto out;
		}
	} else if (sock->state == SS_CONNECTED) {
		remote_addr = &vsk->remote_addr;

		if (remote_addr->svm_cid == VMADDR_CID_ANY)
			remote_addr->svm_cid = transport->get_local_cid();

		/* XXX Should connect() or this function ensure remote_addr is
		 * bound?
		 */
		if (!vsock_addr_bound(&vsk->remote_addr)) {
			err = -EINVAL;
			goto out;
		}
	} else {
		err = -EINVAL;
		goto out;
	}

	if (!transport->dgram_allow(vsk, remote_addr->svm_cid,
				    remote_addr->svm_port)) {
		err = -EINVAL;
		goto out;
	}

	err = transport->dgram_enqueue(vsk, remote_addr, msg, len);

out:
	release_sock(sk);
	return err;
}

static int vsock_dgram_connect(struct socket *sock,
			       struct sockaddr_unsized *addr, int addr_len, int flags)
{
	int err;
	struct sock *sk;
	struct vsock_sock *vsk;
	struct sockaddr_vm *remote_addr;

	sk = sock->sk;
	vsk = vsock_sk(sk);

	err = vsock_addr_cast(addr, addr_len, &remote_addr);
	if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) {
		lock_sock(sk);
		vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY,
				VMADDR_PORT_ANY);
		sock->state = SS_UNCONNECTED;
		release_sock(sk);
		return 0;
	} else if (err != 0)
		return -EINVAL;

	lock_sock(sk);

	err = vsock_auto_bind(vsk);
	if (err)
		goto out;

	if (!vsk->transport->dgram_allow(vsk, remote_addr->svm_cid,
					 remote_addr->svm_port)) {
		err = -EINVAL;
		goto out;
	}

	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
	sock->state = SS_CONNECTED;

	/* sock map disallows redirection of non-TCP sockets with sk_state !=
	 * TCP_ESTABLISHED (see sock_map_redirect_allowed()), so we set
	 * TCP_ESTABLISHED here to allow redirection of connected vsock dgrams.
	 *
	 * This doesn't seem to be an abnormal state for datagram sockets, as
	 * the same approach can be seen in other datagram socket types as well
	 * (such as unix sockets).
	 */
	sk->sk_state = TCP_ESTABLISHED;

out:
	release_sock(sk);
	return err;
}

int __vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
			  size_t len, int flags)
{
	struct sock *sk = sock->sk;
	struct vsock_sock *vsk = vsock_sk(sk);

	return vsk->transport->dgram_dequeue(vsk, msg, len, flags);
}

int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
			size_t len, int flags)
{
#ifdef CONFIG_BPF_SYSCALL
	struct sock *sk = sock->sk;
	const struct proto *prot;

	prot = READ_ONCE(sk->sk_prot);
	if (prot != &vsock_proto)
		return prot->recvmsg(sk, msg, len, flags, NULL);
#endif

	return __vsock_dgram_recvmsg(sock, msg, len, flags);
}
EXPORT_SYMBOL_GPL(vsock_dgram_recvmsg);

static int vsock_do_ioctl(struct socket *sock, unsigned int cmd,
			  int __user *arg)
{
	struct sock *sk = sock->sk;
	struct vsock_sock *vsk;
	int ret;

	vsk = vsock_sk(sk);

	switch (cmd) {
	case SIOCINQ: {
		ssize_t n_bytes;

		if (!vsk->transport) {
			ret = -EOPNOTSUPP;
			break;
		}

		if (sock_type_connectible(sk->sk_type) &&
		    sk->sk_state == TCP_LISTEN) {
			ret = -EINVAL;
			break;
		}

		n_bytes = vsock_stream_has_data(vsk);
		if (n_bytes < 0) {
			ret = n_bytes;
			break;
		}
		ret = put_user(n_bytes, arg);
		break;
	}
	case SIOCOUTQ: {
		ssize_t n_bytes;

		if (!vsk->transport || !vsk->transport->unsent_bytes) {
			ret = -EOPNOTSUPP;
			break;
		}

		if (sock_type_connectible(sk->sk_type) && sk->sk_state == TCP_LISTEN) {
			ret = -EINVAL;
			break;
		}

		n_bytes = vsk->transport->unsent_bytes(vsk);
		if (n_bytes < 0) {
			ret = n_bytes;
			break;
		}

		ret = put_user(n_bytes, arg);
		break;
	}
	default:
		ret = -ENOIOCTLCMD;
	}

	return ret;
}
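
/* Illustrative userspace sketch: querying the unread and unsent byte counts
 * handled above (fd is an assumed connected vsock socket):
 *
 *	int inq, outq;
 *
 *	ioctl(fd, SIOCINQ, &inq);
 *	ioctl(fd, SIOCOUTQ, &outq);
 */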
1566 
1567 static int vsock_ioctl(struct socket *sock, unsigned int cmd,
1568 		       unsigned long arg)
1569 {
1570 	int ret;
1571 
1572 	lock_sock(sock->sk);
1573 	ret = vsock_do_ioctl(sock, cmd, (int __user *)arg);
1574 	release_sock(sock->sk);
1575 
1576 	return ret;
1577 }
1578 
1579 static const struct proto_ops vsock_dgram_ops = {
1580 	.family = PF_VSOCK,
1581 	.owner = THIS_MODULE,
1582 	.release = vsock_release,
1583 	.bind = vsock_bind,
1584 	.connect = vsock_dgram_connect,
1585 	.socketpair = sock_no_socketpair,
1586 	.accept = sock_no_accept,
1587 	.getname = vsock_getname,
1588 	.poll = vsock_poll,
1589 	.ioctl = vsock_ioctl,
1590 	.listen = sock_no_listen,
1591 	.shutdown = vsock_shutdown,
1592 	.sendmsg = vsock_dgram_sendmsg,
1593 	.recvmsg = vsock_dgram_recvmsg,
1594 	.mmap = sock_no_mmap,
1595 	.read_skb = vsock_read_skb,
1596 };
1597 
1598 static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
1599 {
1600 	const struct vsock_transport *transport = vsk->transport;
1601 
1602 	if (!transport || !transport->cancel_pkt)
1603 		return -EOPNOTSUPP;
1604 
1605 	return transport->cancel_pkt(vsk);
1606 }
1607 
1608 static void vsock_connect_timeout(struct work_struct *work)
1609 {
1610 	struct sock *sk;
1611 	struct vsock_sock *vsk;
1612 
1613 	vsk = container_of(work, struct vsock_sock, connect_work.work);
1614 	sk = sk_vsock(vsk);
1615 
1616 	lock_sock(sk);
1617 	if (sk->sk_state == TCP_SYN_SENT &&
1618 	    (sk->sk_shutdown != SHUTDOWN_MASK)) {
1619 		sk->sk_state = TCP_CLOSE;
1620 		sk->sk_socket->state = SS_UNCONNECTED;
1621 		sk->sk_err = ETIMEDOUT;
1622 		sk_error_report(sk);
1623 		vsock_transport_cancel_pkt(vsk);
1624 	}
1625 	release_sock(sk);
1626 
1627 	sock_put(sk);
1628 }
1629 
1630 static int vsock_connect(struct socket *sock, struct sockaddr_unsized *addr,
1631 			 int addr_len, int flags)
1632 {
1633 	int err;
1634 	struct sock *sk;
1635 	struct vsock_sock *vsk;
1636 	const struct vsock_transport *transport;
1637 	struct sockaddr_vm *remote_addr;
1638 	long timeout;
1639 	DEFINE_WAIT(wait);
1640 
1641 	err = 0;
1642 	sk = sock->sk;
1643 	vsk = vsock_sk(sk);
1644 
1645 	lock_sock(sk);
1646 
1647 	/* XXX AF_UNSPEC should make us disconnect like AF_INET. */
1648 	switch (sock->state) {
1649 	case SS_CONNECTED:
1650 		err = -EISCONN;
1651 		goto out;
1652 	case SS_DISCONNECTING:
1653 		err = -EINVAL;
1654 		goto out;
1655 	case SS_CONNECTING:
1656 		/* This continues on so we can move sock into the SS_CONNECTED
1657 		 * state once the connection has completed (at which point err
1658 		 * will be set to zero also).  Otherwise, we will either wait
1659 		 * for the connection or return -EALREADY should this be a
1660 		 * non-blocking call.
1661 		 */
1662 		err = -EALREADY;
1663 		if (flags & O_NONBLOCK)
1664 			goto out;
1665 		break;
1666 	default:
1667 		if ((sk->sk_state == TCP_LISTEN) ||
1668 		    vsock_addr_cast(addr, addr_len, &remote_addr) != 0) {
1669 			err = -EINVAL;
1670 			goto out;
1671 		}
1672 
1673 		/* Set the remote address that we are connecting to. */
1674 		memcpy(&vsk->remote_addr, remote_addr,
1675 		       sizeof(vsk->remote_addr));
1676 
1677 		err = vsock_assign_transport(vsk, NULL);
1678 		if (err)
1679 			goto out;
1680 
1681 		transport = vsk->transport;
1682 
1683 		/* The hypervisor and well-known contexts do not have socket
1684 		 * endpoints.
1685 		 */
1686 		if (!transport ||
1687 		    !transport->stream_allow(vsk, remote_addr->svm_cid,
1688 					     remote_addr->svm_port)) {
1689 			err = -ENETUNREACH;
1690 			goto out;
1691 		}
1692 
1693 		if (vsock_msgzerocopy_allow(transport)) {
1694 			set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags);
1695 		} else if (sock_flag(sk, SOCK_ZEROCOPY)) {
1696 			/* If this option was set before 'connect()',
1697 			 * when transport was unknown, check that this
1698 			 * feature is supported here.
1699 			 */
1700 			err = -EOPNOTSUPP;
1701 			goto out;
1702 		}
1703 
1704 		err = vsock_auto_bind(vsk);
1705 		if (err)
1706 			goto out;
1707 
1708 		sk->sk_state = TCP_SYN_SENT;
1709 
1710 		err = transport->connect(vsk);
1711 		if (err < 0)
1712 			goto out;
1713 
1714 		/* sk_err might have been set as a result of an earlier
1715 		 * (failed) connect attempt.
1716 		 */
1717 		sk->sk_err = 0;
1718 
1719 		/* Mark sock as connecting and set the error code to in
1720 		 * progress in case this is a non-blocking connect.
1721 		 */
1722 		sock->state = SS_CONNECTING;
1723 		err = -EINPROGRESS;
1724 	}
1725 
1726 	/* The receive path will handle all communication until we are able to
1727 	 * enter the connected state.  Here we wait for the connection to be
1728 	 * completed or a notification of an error.
1729 	 */
1730 	timeout = vsk->connect_timeout;
1731 	prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
1732 
1733 	/* If the socket is already closing or it is in an error state, there
1734 	 * is no point in waiting.
1735 	 */
1736 	while (sk->sk_state != TCP_ESTABLISHED &&
1737 	       sk->sk_state != TCP_CLOSING && sk->sk_err == 0) {
1738 		if (flags & O_NONBLOCK) {
1739 			/* If we're not going to block, we schedule a timeout
1740 			 * function to generate a timeout on the connection
1741 			 * attempt, in case the peer doesn't respond in a
1742 			 * timely manner. We hold on to the socket until the
1743 			 * timeout fires.
1744 			 */
1745 			sock_hold(sk);
1746 
1747 			/* If the timeout function is already scheduled,
1748 			 * reschedule it, then ungrab the socket refcount to
1749 			 * keep it balanced.
1750 			 */
1751 			if (mod_delayed_work(system_percpu_wq, &vsk->connect_work,
1752 					     timeout))
1753 				sock_put(sk);
1754 
1755 			/* Skip ahead to preserve error code set above. */
1756 			goto out_wait;
1757 		}
1758 
1759 		release_sock(sk);
1760 		timeout = schedule_timeout(timeout);
1761 		lock_sock(sk);
1762 
1763 		/* Connection established. Whatever happens to socket once we
1764 		 * release it, that's not connect()'s concern. No need to go
1765 		 * into signal and timeout handling. Call it a day.
1766 		 *
1767 		 * Note that allowing to "reset" an already established socket
1768 		 * here is racy and insecure.
1769 		 */
1770 		if (sk->sk_state == TCP_ESTABLISHED)
1771 			break;
1772 
1773 		/* If connection was _not_ established and a signal/timeout came
1774 		 * to be, we want the socket's state reset. User space may want
1775 		 * to retry.
1776 		 *
1777 		 * sk_state != TCP_ESTABLISHED implies that socket is not on
1778 		 * vsock_connected_table. We keep the binding and the transport
1779 		 * assigned.
1780 		 */
1781 		if (signal_pending(current) || timeout == 0) {
1782 			err = timeout == 0 ? -ETIMEDOUT : sock_intr_errno(timeout);
1783 
1784 			/* Listener might have already responded with
1785 			 * VIRTIO_VSOCK_OP_RESPONSE. Its handling expects our
1786 			 * sk_state == TCP_SYN_SENT, which we break here.
1787 			 * In that case a VIRTIO_VSOCK_OP_RST will follow.
1788 			 */
1789 			sk->sk_state = TCP_CLOSE;
1790 			sock->state = SS_UNCONNECTED;
1791 
1792 			/* Try to cancel VIRTIO_VSOCK_OP_REQUEST skb sent out by
1793 			 * transport->connect().
1794 			 */
1795 			vsock_transport_cancel_pkt(vsk);
1796 
1797 			goto out_wait;
1798 		}
1799 
1800 		prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
1801 	}
1802 
1803 	if (sk->sk_err) {
1804 		err = -sk->sk_err;
1805 		sk->sk_state = TCP_CLOSE;
1806 		sock->state = SS_UNCONNECTED;
1807 	} else {
1808 		err = 0;
1809 	}
1810 
1811 out_wait:
1812 	finish_wait(sk_sleep(sk), &wait);
1813 out:
1814 	release_sock(sk);
1815 	return err;
1816 }
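
/* Usage sketch (user space; illustrative only, not part of this file): a
 * non-blocking connect(2) that exercises the EINPROGRESS path above.  The
 * peer CID and port below are assumptions.
 *
 *	#include <errno.h>
 *	#include <poll.h>
 *	#include <sys/socket.h>
 *	#include <linux/vm_sockets.h>
 *
 *	struct sockaddr_vm addr = {
 *		.svm_family = AF_VSOCK,
 *		.svm_cid    = VMADDR_CID_HOST,	// assumed peer
 *		.svm_port   = 1234,		// assumed port
 *	};
 *	int fd = socket(AF_VSOCK, SOCK_STREAM | SOCK_NONBLOCK, 0);
 *
 *	if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0 &&
 *	    errno == EINPROGRESS) {
 *		struct pollfd pfd = { .fd = fd, .events = POLLOUT };
 *		int err;
 *		socklen_t errlen = sizeof(err);
 *
 *		poll(&pfd, 1, -1);
 *		getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &errlen);
 *		// err == 0 on success; otherwise connect_work above has
 *		// timed the attempt out or the peer reset it.
 *	}
 */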
1817 
1818 static int vsock_accept(struct socket *sock, struct socket *newsock,
1819 			struct proto_accept_arg *arg)
1820 {
1821 	struct sock *listener;
1822 	int err;
1823 	struct sock *connected;
1824 	struct vsock_sock *vconnected;
1825 	long timeout;
1826 	DEFINE_WAIT(wait);
1827 
1828 	err = 0;
1829 	listener = sock->sk;
1830 
1831 	lock_sock(listener);
1832 
1833 	if (!sock_type_connectible(sock->type)) {
1834 		err = -EOPNOTSUPP;
1835 		goto out;
1836 	}
1837 
1838 	if (listener->sk_state != TCP_LISTEN) {
1839 		err = -EINVAL;
1840 		goto out;
1841 	}
1842 
1843 	/* Wait for child sockets to appear; these are the new sockets
1844 	 * created upon connection establishment.
1845 	 */
1846 	timeout = sock_rcvtimeo(listener, arg->flags & O_NONBLOCK);
1847 	prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
1848 
1849 	while ((connected = vsock_dequeue_accept(listener)) == NULL &&
1850 	       listener->sk_err == 0) {
1851 		release_sock(listener);
1852 		timeout = schedule_timeout(timeout);
1853 		finish_wait(sk_sleep(listener), &wait);
1854 		lock_sock(listener);
1855 
1856 		if (signal_pending(current)) {
1857 			err = sock_intr_errno(timeout);
1858 			goto out;
1859 		} else if (timeout == 0) {
1860 			err = -EAGAIN;
1861 			goto out;
1862 		}
1863 
1864 		prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
1865 	}
1866 	finish_wait(sk_sleep(listener), &wait);
1867 
1868 	if (listener->sk_err)
1869 		err = -listener->sk_err;
1870 
1871 	if (connected) {
1872 		sk_acceptq_removed(listener);
1873 
1874 		lock_sock_nested(connected, SINGLE_DEPTH_NESTING);
1875 		vconnected = vsock_sk(connected);
1876 
1877 		/* If the listener socket has received an error, then we should
1878 		 * reject this socket and return.  Note that we simply mark the
1879 		 * socket rejected, drop our reference, and let the cleanup
1880 		 * function handle the cleanup; the fact that we found it in
1881 		 * the listener's accept queue guarantees that the cleanup
1882 		 * function hasn't run yet.
1883 		 */
1884 		if (err) {
1885 			vconnected->rejected = true;
1886 		} else {
1887 			newsock->state = SS_CONNECTED;
1888 			sock_graft(connected, newsock);
1889 
1890 			set_bit(SOCK_CUSTOM_SOCKOPT,
1891 				&connected->sk_socket->flags);
1892 
1893 			if (vsock_msgzerocopy_allow(vconnected->transport))
1894 				set_bit(SOCK_SUPPORT_ZC,
1895 					&connected->sk_socket->flags);
1896 		}
1897 
1898 		release_sock(connected);
1899 		sock_put(connected);
1900 	}
1901 
1902 out:
1903 	release_sock(listener);
1904 	return err;
1905 }
1906 
1907 static int vsock_listen(struct socket *sock, int backlog)
1908 {
1909 	int err;
1910 	struct sock *sk;
1911 	struct vsock_sock *vsk;
1912 
1913 	sk = sock->sk;
1914 
1915 	lock_sock(sk);
1916 
1917 	if (!sock_type_connectible(sk->sk_type)) {
1918 		err = -EOPNOTSUPP;
1919 		goto out;
1920 	}
1921 
1922 	if (sock->state != SS_UNCONNECTED) {
1923 		err = -EINVAL;
1924 		goto out;
1925 	}
1926 
1927 	vsk = vsock_sk(sk);
1928 
1929 	if (!vsock_addr_bound(&vsk->local_addr)) {
1930 		err = -EINVAL;
1931 		goto out;
1932 	}
1933 
1934 	sk->sk_max_ack_backlog = backlog;
1935 	sk->sk_state = TCP_LISTEN;
1936 
1937 	err = 0;
1938 
1939 out:
1940 	release_sock(sk);
1941 	return err;
1942 }
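
/* Usage sketch (user space; illustrative only): the listener side expected
 * by vsock_listen() and vsock_accept() above.  The port is an assumption.
 *
 *	struct sockaddr_vm addr = {
 *		.svm_family = AF_VSOCK,
 *		.svm_cid    = VMADDR_CID_ANY,
 *		.svm_port   = 1234,		// assumed port
 *	};
 *	int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 *
 *	// bind() must come first: vsock_listen() fails with -EINVAL
 *	// on an unbound socket.
 *	bind(fd, (struct sockaddr *)&addr, sizeof(addr));
 *	listen(fd, 8);
 *	int client = accept(fd, NULL, NULL);
 */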
1943 
1944 static void vsock_update_buffer_size(struct vsock_sock *vsk,
1945 				     const struct vsock_transport *transport,
1946 				     u64 val)
1947 {
1948 	if (val > vsk->buffer_max_size)
1949 		val = vsk->buffer_max_size;
1950 
1951 	if (val < vsk->buffer_min_size)
1952 		val = vsk->buffer_min_size;
1953 
1954 	if (val != vsk->buffer_size &&
1955 	    transport && transport->notify_buffer_size)
1956 		transport->notify_buffer_size(vsk, &val);
1957 
1958 	vsk->buffer_size = val;
1959 }
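
/* For example (values are illustrative): with buffer_min_size == 4096 and
 * buffer_max_size == 65536, a request of 1 MiB is stored as 65536, and the
 * transport is only notified when the clamped value actually differs from
 * the current buffer_size.
 */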
1960 
1961 static int vsock_connectible_setsockopt(struct socket *sock,
1962 					int level,
1963 					int optname,
1964 					sockptr_t optval,
1965 					unsigned int optlen)
1966 {
1967 	int err;
1968 	struct sock *sk;
1969 	struct vsock_sock *vsk;
1970 	const struct vsock_transport *transport;
1971 	u64 val;
1972 
1973 	if (level != AF_VSOCK && level != SOL_SOCKET)
1974 		return -ENOPROTOOPT;
1975 
1976 #define COPY_IN(_v)                                       \
1977 	do {						  \
1978 		if (optlen < sizeof(_v)) {		  \
1979 			err = -EINVAL;			  \
1980 			goto exit;			  \
1981 		}					  \
1982 		if (copy_from_sockptr(&_v, optval, sizeof(_v)) != 0) {	\
1983 			err = -EFAULT;					\
1984 			goto exit;					\
1985 		}							\
1986 	} while (0)
1987 
1988 	err = 0;
1989 	sk = sock->sk;
1990 	vsk = vsock_sk(sk);
1991 
1992 	lock_sock(sk);
1993 
1994 	transport = vsk->transport;
1995 
1996 	if (level == SOL_SOCKET) {
1997 		int zerocopy;
1998 
1999 		if (optname != SO_ZEROCOPY) {
2000 			release_sock(sk);
2001 			return sock_setsockopt(sock, level, optname, optval, optlen);
2002 		}
2003 
2004 		/* Use the 'int' type here, because the variable used to
2005 		 * set this option usually has this type.
2006 		 */
2007 		COPY_IN(zerocopy);
2008 
2009 		if (zerocopy < 0 || zerocopy > 1) {
2010 			err = -EINVAL;
2011 			goto exit;
2012 		}
2013 
2014 		if (transport && !vsock_msgzerocopy_allow(transport)) {
2015 			err = -EOPNOTSUPP;
2016 			goto exit;
2017 		}
2018 
2019 		sock_valbool_flag(sk, SOCK_ZEROCOPY, zerocopy);
2020 		goto exit;
2021 	}
2022 
2023 	switch (optname) {
2024 	case SO_VM_SOCKETS_BUFFER_SIZE:
2025 		COPY_IN(val);
2026 		vsock_update_buffer_size(vsk, transport, val);
2027 		break;
2028 
2029 	case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
2030 		COPY_IN(val);
2031 		vsk->buffer_max_size = val;
2032 		vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
2033 		break;
2034 
2035 	case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
2036 		COPY_IN(val);
2037 		vsk->buffer_min_size = val;
2038 		vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
2039 		break;
2040 
2041 	case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW:
2042 	case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD: {
2043 		struct __kernel_sock_timeval tv;
2044 
2045 		err = sock_copy_user_timeval(&tv, optval, optlen,
2046 					     optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
2047 		if (err)
2048 			break;
2049 		if (tv.tv_sec >= 0 && tv.tv_usec < USEC_PER_SEC &&
2050 		    tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) {
2051 			vsk->connect_timeout = tv.tv_sec * HZ +
2052 				DIV_ROUND_UP((unsigned long)tv.tv_usec, (USEC_PER_SEC / HZ));
2053 			if (vsk->connect_timeout == 0)
2054 				vsk->connect_timeout =
2055 				    VSOCK_DEFAULT_CONNECT_TIMEOUT;
2056 
2057 		} else {
2058 			err = -ERANGE;
2059 		}
2060 		break;
2061 	}
2062 
2063 	default:
2064 		err = -ENOPROTOOPT;
2065 		break;
2066 	}
2067 
2068 #undef COPY_IN
2069 
2070 exit:
2071 	release_sock(sk);
2072 	return err;
2073 }
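
/* Usage sketch (user space; illustrative only): the AF_VSOCK-level buffer
 * options above take a u64 and are clamped against the current min/max by
 * vsock_update_buffer_size().
 *
 *	unsigned long long min = 64 * 1024;
 *	unsigned long long max = 1024 * 1024;
 *	unsigned long long size = 256 * 1024;
 *
 *	setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MIN_SIZE, &min, sizeof(min));
 *	setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE, &max, sizeof(max));
 *	setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE, &size, sizeof(size));
 */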
2074 
2075 static int vsock_connectible_getsockopt(struct socket *sock,
2076 					int level, int optname,
2077 					char __user *optval,
2078 					int __user *optlen)
2079 {
2080 	struct sock *sk = sock->sk;
2081 	struct vsock_sock *vsk = vsock_sk(sk);
2082 
2083 	union {
2084 		u64 val64;
2085 		struct old_timeval32 tm32;
2086 		struct __kernel_old_timeval tm;
2087 		struct __kernel_sock_timeval stm;
2088 	} v;
2089 
2090 	int lv = sizeof(v.val64);
2091 	int len;
2092 
2093 	if (level != AF_VSOCK)
2094 		return -ENOPROTOOPT;
2095 
2096 	if (get_user(len, optlen))
2097 		return -EFAULT;
2098 
2099 	memset(&v, 0, sizeof(v));
2100 
2101 	switch (optname) {
2102 	case SO_VM_SOCKETS_BUFFER_SIZE:
2103 		v.val64 = vsk->buffer_size;
2104 		break;
2105 
2106 	case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
2107 		v.val64 = vsk->buffer_max_size;
2108 		break;
2109 
2110 	case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
2111 		v.val64 = vsk->buffer_min_size;
2112 		break;
2113 
2114 	case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW:
2115 	case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD:
2116 		lv = sock_get_timeout(vsk->connect_timeout, &v,
2117 				      optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
2118 		break;
2119 
2120 	default:
2121 		return -ENOPROTOOPT;
2122 	}
2123 
2124 	if (len < lv)
2125 		return -EINVAL;
2126 	if (len > lv)
2127 		len = lv;
2128 	if (copy_to_user(optval, &v, len))
2129 		return -EFAULT;
2130 
2131 	if (put_user(len, optlen))
2132 		return -EFAULT;
2133 
2134 	return 0;
2135 }
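
/* Usage sketch (user space; illustrative only): reading one of the u64
 * options back.  Note that optlen is capped to the option's size.
 *
 *	unsigned long long size;
 *	socklen_t len = sizeof(size);
 *
 *	if (!getsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE, &size, &len))
 *		printf("buffer_size: %llu\n", size);
 */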
2136 
2137 static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
2138 				     size_t len)
2139 {
2140 	struct sock *sk;
2141 	struct vsock_sock *vsk;
2142 	const struct vsock_transport *transport;
2143 	ssize_t total_written;
2144 	long timeout;
2145 	int err;
2146 	struct vsock_transport_send_notify_data send_data;
2147 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
2148 
2149 	sk = sock->sk;
2150 	vsk = vsock_sk(sk);
2151 	total_written = 0;
2152 	err = 0;
2153 
2154 	if (msg->msg_flags & MSG_OOB)
2155 		return -EOPNOTSUPP;
2156 
2157 	lock_sock(sk);
2158 
2159 	transport = vsk->transport;
2160 
2161 	/* Callers should not provide a destination with connection-oriented
2162 	 * sockets.
2163 	 */
2164 	if (msg->msg_namelen) {
2165 		err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
2166 		goto out;
2167 	}
2168 
2169 	/* Send data only if neither side has shut down in this direction. */
2170 	if (sk->sk_shutdown & SEND_SHUTDOWN ||
2171 	    vsk->peer_shutdown & RCV_SHUTDOWN) {
2172 		err = -EPIPE;
2173 		goto out;
2174 	}
2175 
2176 	if (!transport || sk->sk_state != TCP_ESTABLISHED ||
2177 	    !vsock_addr_bound(&vsk->local_addr)) {
2178 		err = -ENOTCONN;
2179 		goto out;
2180 	}
2181 
2182 	if (!vsock_addr_bound(&vsk->remote_addr)) {
2183 		err = -EDESTADDRREQ;
2184 		goto out;
2185 	}
2186 
2187 	if (msg->msg_flags & MSG_ZEROCOPY &&
2188 	    !vsock_msgzerocopy_allow(transport)) {
2189 		err = -EOPNOTSUPP;
2190 		goto out;
2191 	}
2192 
2193 	/* Wait for room in the produce queue to enqueue our user's data. */
2194 	timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
2195 
2196 	err = transport->notify_send_init(vsk, &send_data);
2197 	if (err < 0)
2198 		goto out;
2199 
2200 	while (total_written < len) {
2201 		ssize_t written;
2202 
2203 		add_wait_queue(sk_sleep(sk), &wait);
2204 		while (vsock_stream_has_space(vsk) == 0 &&
2205 		       sk->sk_err == 0 &&
2206 		       !(sk->sk_shutdown & SEND_SHUTDOWN) &&
2207 		       !(vsk->peer_shutdown & RCV_SHUTDOWN)) {
2208 
2209 			/* Don't wait for non-blocking sockets. */
2210 			if (timeout == 0) {
2211 				err = -EAGAIN;
2212 				remove_wait_queue(sk_sleep(sk), &wait);
2213 				goto out_err;
2214 			}
2215 
2216 			err = transport->notify_send_pre_block(vsk, &send_data);
2217 			if (err < 0) {
2218 				remove_wait_queue(sk_sleep(sk), &wait);
2219 				goto out_err;
2220 			}
2221 
2222 			release_sock(sk);
2223 			timeout = wait_woken(&wait, TASK_INTERRUPTIBLE, timeout);
2224 			lock_sock(sk);
2225 			if (signal_pending(current)) {
2226 				err = sock_intr_errno(timeout);
2227 				remove_wait_queue(sk_sleep(sk), &wait);
2228 				goto out_err;
2229 			} else if (timeout == 0) {
2230 				err = -EAGAIN;
2231 				remove_wait_queue(sk_sleep(sk), &wait);
2232 				goto out_err;
2233 			}
2234 		}
2235 		remove_wait_queue(sk_sleep(sk), &wait);
2236 
2237 		/* These checks occur both as part of and after the loop
2238 		 * conditional since we need to check before and after
2239 		 * sleeping.
2240 		 */
2241 		if (sk->sk_err) {
2242 			err = -sk->sk_err;
2243 			goto out_err;
2244 		} else if ((sk->sk_shutdown & SEND_SHUTDOWN) ||
2245 			   (vsk->peer_shutdown & RCV_SHUTDOWN)) {
2246 			err = -EPIPE;
2247 			goto out_err;
2248 		}
2249 
2250 		err = transport->notify_send_pre_enqueue(vsk, &send_data);
2251 		if (err < 0)
2252 			goto out_err;
2253 
2254 		/* Note that enqueue will only write as many bytes as are free
2255 		 * in the produce queue, so we don't need to ensure len is
2256 		 * smaller than the queue size.  It is the caller's
2257 		 * responsibility to check how many bytes we were able to send.
2258 		 */
2259 
2260 		if (sk->sk_type == SOCK_SEQPACKET) {
2261 			written = transport->seqpacket_enqueue(vsk,
2262 						msg, len - total_written);
2263 		} else {
2264 			written = transport->stream_enqueue(vsk,
2265 					msg, len - total_written);
2266 		}
2267 
2268 		if (written < 0) {
2269 			err = written;
2270 			goto out_err;
2271 		}
2272 
2273 		total_written += written;
2274 
2275 		err = transport->notify_send_post_enqueue(
2276 				vsk, written, &send_data);
2277 		if (err < 0)
2278 			goto out_err;
2279 
2280 	}
2281 
2282 out_err:
2283 	if (total_written > 0) {
2284 		/* Return the number of bytes written only for:
2285 		 * 1) a SOCK_STREAM socket.
2286 		 * 2) a SOCK_SEQPACKET socket, when the whole buffer was sent.
2287 		 */
2288 		if (sk->sk_type == SOCK_STREAM || total_written == len)
2289 			err = total_written;
2290 	}
2291 out:
2292 	if (sk->sk_type == SOCK_STREAM)
2293 		err = sk_stream_error(sk, msg->msg_flags, err);
2294 
2295 	release_sock(sk);
2296 	return err;
2297 }
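
/* Usage sketch (user space; illustrative only): MSG_ZEROCOPY on a connected
 * socket.  SO_ZEROCOPY must be enabled first (possibly before connect(), see
 * vsock_connect()); 'buf' and 'buf_len' are assumptions.
 *
 *	int one = 1;
 *
 *	setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &one, sizeof(one));
 *
 *	if (send(fd, buf, buf_len, MSG_ZEROCOPY) < 0 && errno == EOPNOTSUPP)
 *		send(fd, buf, buf_len, 0);	// transport lacks zerocopy
 *
 *	// Completions are read from the error queue via
 *	// recvmsg(fd, &msg, MSG_ERRQUEUE); see the MSG_ERRQUEUE
 *	// handling (SOL_VSOCK/VSOCK_RECVERR) below.
 */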
2298 
2299 static int vsock_connectible_wait_data(struct sock *sk,
2300 				       struct wait_queue_entry *wait,
2301 				       long timeout,
2302 				       struct vsock_transport_recv_notify_data *recv_data,
2303 				       size_t target)
2304 {
2305 	const struct vsock_transport *transport;
2306 	struct vsock_sock *vsk;
2307 	s64 data;
2308 	int err;
2309 
2310 	vsk = vsock_sk(sk);
2311 	err = 0;
2312 	transport = vsk->transport;
2313 
2314 	while (1) {
2315 		prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE);
2316 		data = vsock_connectible_has_data(vsk);
2317 		if (data != 0)
2318 			break;
2319 
2320 		if (sk->sk_err != 0 ||
2321 		    (sk->sk_shutdown & RCV_SHUTDOWN) ||
2322 		    (vsk->peer_shutdown & SEND_SHUTDOWN)) {
2323 			break;
2324 		}
2325 
2326 		/* Don't wait for non-blocking sockets. */
2327 		if (timeout == 0) {
2328 			err = -EAGAIN;
2329 			break;
2330 		}
2331 
2332 		if (recv_data) {
2333 			err = transport->notify_recv_pre_block(vsk, target, recv_data);
2334 			if (err < 0)
2335 				break;
2336 		}
2337 
2338 		release_sock(sk);
2339 		timeout = schedule_timeout(timeout);
2340 		lock_sock(sk);
2341 
2342 		if (signal_pending(current)) {
2343 			err = sock_intr_errno(timeout);
2344 			break;
2345 		} else if (timeout == 0) {
2346 			err = -EAGAIN;
2347 			break;
2348 		}
2349 	}
2350 
2351 	finish_wait(sk_sleep(sk), wait);
2352 
2353 	if (err)
2354 		return err;
2355 
2356 	/* Internal transport error when checking for available
2357 	 * data. XXX This should be changed to a connection
2358 	 * reset in a later change.
2359 	 */
2360 	if (data < 0)
2361 		return -ENOMEM;
2362 
2363 	return data;
2364 }
2365 
2366 static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg,
2367 				  size_t len, int flags)
2368 {
2369 	struct vsock_transport_recv_notify_data recv_data;
2370 	const struct vsock_transport *transport;
2371 	struct vsock_sock *vsk;
2372 	ssize_t copied;
2373 	size_t target;
2374 	long timeout;
2375 	int err;
2376 
2377 	DEFINE_WAIT(wait);
2378 
2379 	vsk = vsock_sk(sk);
2380 	transport = vsk->transport;
2381 
2382 	/* We must not copy less than target bytes into the user's buffer
2383 	 * before returning successfully, so we wait for the consume queue to
2384 	 * have that much data to consume before dequeueing.  Note that this
2385 	 * makes it impossible to handle cases where target is greater than the
2386 	 * queue size.
2387 	 */
2388 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
2389 	if (target >= transport->stream_rcvhiwat(vsk)) {
2390 		err = -ENOMEM;
2391 		goto out;
2392 	}
2393 	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
2394 	copied = 0;
2395 
2396 	err = transport->notify_recv_init(vsk, target, &recv_data);
2397 	if (err < 0)
2398 		goto out;
2399 
2401 	while (1) {
2402 		ssize_t read;
2403 
2404 		err = vsock_connectible_wait_data(sk, &wait, timeout,
2405 						  &recv_data, target);
2406 		if (err <= 0)
2407 			break;
2408 
2409 		err = transport->notify_recv_pre_dequeue(vsk, target,
2410 							 &recv_data);
2411 		if (err < 0)
2412 			break;
2413 
2414 		read = transport->stream_dequeue(vsk, msg, len - copied, flags);
2415 		if (read < 0) {
2416 			err = read;
2417 			break;
2418 		}
2419 
2420 		copied += read;
2421 
2422 		err = transport->notify_recv_post_dequeue(vsk, target, read,
2423 						!(flags & MSG_PEEK), &recv_data);
2424 		if (err < 0)
2425 			goto out;
2426 
2427 		if (read >= target || flags & MSG_PEEK)
2428 			break;
2429 
2430 		target -= read;
2431 	}
2432 
2433 	if (sk->sk_err)
2434 		err = -sk->sk_err;
2435 	else if (sk->sk_shutdown & RCV_SHUTDOWN)
2436 		err = 0;
2437 
2438 	if (copied > 0)
2439 		err = copied;
2440 
2441 out:
2442 	return err;
2443 }
2444 
2445 static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
2446 				     size_t len, int flags)
2447 {
2448 	const struct vsock_transport *transport;
2449 	struct vsock_sock *vsk;
2450 	ssize_t msg_len;
2451 	long timeout;
2452 	int err = 0;
2453 	DEFINE_WAIT(wait);
2454 
2455 	vsk = vsock_sk(sk);
2456 	transport = vsk->transport;
2457 
2458 	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
2459 
2460 	err = vsock_connectible_wait_data(sk, &wait, timeout, NULL, 0);
2461 	if (err <= 0)
2462 		goto out;
2463 
2464 	msg_len = transport->seqpacket_dequeue(vsk, msg, flags);
2465 
2466 	if (msg_len < 0) {
2467 		err = msg_len;
2468 		goto out;
2469 	}
2470 
2471 	if (sk->sk_err) {
2472 		err = -sk->sk_err;
2473 	} else if (sk->sk_shutdown & RCV_SHUTDOWN) {
2474 		err = 0;
2475 	} else {
2476 		/* The user set MSG_TRUNC, so return the real length of
2477 		 * the packet.
2478 		 */
2479 		if (flags & MSG_TRUNC)
2480 			err = msg_len;
2481 		else
2482 			err = len - msg_data_left(msg);
2483 
2484 		/* Always set MSG_TRUNC if the real length of the packet
2485 		 * is bigger than the user's buffer.
2486 		 */
2487 		if (msg_len > len)
2488 			msg->msg_flags |= MSG_TRUNC;
2489 	}
2490 
2491 out:
2492 	return err;
2493 }
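
/* Usage sketch (user space; illustrative only): learning the real record
 * length on a SOCK_SEQPACKET socket via MSG_TRUNC, as implemented above.
 *
 *	char buf[64];
 *	struct iovec iov = { .iov_base = buf, .iov_len = sizeof(buf) };
 *	struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1 };
 *	ssize_t n = recvmsg(fd, &msg, MSG_TRUNC);
 *
 *	// n is the full record length; if the record did not fit,
 *	// msg.msg_flags has MSG_TRUNC set as well.
 */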
2494 
2495 int
2496 __vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
2497 			    int flags)
2498 {
2499 	struct sock *sk;
2500 	struct vsock_sock *vsk;
2501 	const struct vsock_transport *transport;
2502 	int err;
2503 
2504 	sk = sock->sk;
2505 
2506 	if (unlikely(flags & MSG_ERRQUEUE))
2507 		return sock_recv_errqueue(sk, msg, len, SOL_VSOCK, VSOCK_RECVERR);
2508 
2509 	vsk = vsock_sk(sk);
2510 	err = 0;
2511 
2512 	lock_sock(sk);
2513 
2514 	transport = vsk->transport;
2515 
2516 	if (!transport || sk->sk_state != TCP_ESTABLISHED) {
2517 		/* Recvmsg is supposed to return 0 if a peer performs an
2518 		 * orderly shutdown. Use the SOCK_DONE flag to differentiate
2519 		 * between that case and the cases where the peer has not
2520 		 * connected or a local shutdown occurred.
2521 		 */
2522 		if (sock_flag(sk, SOCK_DONE))
2523 			err = 0;
2524 		else
2525 			err = -ENOTCONN;
2526 
2527 		goto out;
2528 	}
2529 
2530 	if (flags & MSG_OOB) {
2531 		err = -EOPNOTSUPP;
2532 		goto out;
2533 	}
2534 
2535 	/* We don't check the peer_shutdown flag here, since the peer may have
2536 	 * shut down while data is still queued for the local socket to
2537 	 * receive.
2538 	 */
2539 	if (sk->sk_shutdown & RCV_SHUTDOWN) {
2540 		err = 0;
2541 		goto out;
2542 	}
2543 
2544 	/* It is valid on Linux to pass in a zero-length receive buffer.  This
2545 	 * is not an error.  We may as well bail out now.
2546 	 */
2547 	if (!len) {
2548 		err = 0;
2549 		goto out;
2550 	}
2551 
2552 	if (sk->sk_type == SOCK_STREAM)
2553 		err = __vsock_stream_recvmsg(sk, msg, len, flags);
2554 	else
2555 		err = __vsock_seqpacket_recvmsg(sk, msg, len, flags);
2556 
2557 out:
2558 	release_sock(sk);
2559 	return err;
2560 }
2561 
2562 int
2563 vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
2564 			  int flags)
2565 {
2566 #ifdef CONFIG_BPF_SYSCALL
2567 	struct sock *sk = sock->sk;
2568 	const struct proto *prot;
2569 
2570 	prot = READ_ONCE(sk->sk_prot);
2571 	if (prot != &vsock_proto)
2572 		return prot->recvmsg(sk, msg, len, flags, NULL);
2573 #endif
2574 
2575 	return __vsock_connectible_recvmsg(sock, msg, len, flags);
2576 }
2577 EXPORT_SYMBOL_GPL(vsock_connectible_recvmsg);
2578 
2579 static int vsock_set_rcvlowat(struct sock *sk, int val)
2580 {
2581 	const struct vsock_transport *transport;
2582 	struct vsock_sock *vsk;
2583 
2584 	vsk = vsock_sk(sk);
2585 
2586 	if (val > vsk->buffer_size)
2587 		return -EINVAL;
2588 
2589 	transport = vsk->transport;
2590 
2591 	if (transport && transport->notify_set_rcvlowat) {
2592 		int err;
2593 
2594 		err = transport->notify_set_rcvlowat(vsk, val);
2595 		if (err)
2596 			return err;
2597 	}
2598 
2599 	WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
2600 	return 0;
2601 }
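
/* Usage sketch (user space; illustrative only): SO_RCVLOWAT is rejected
 * with -EINVAL above when it exceeds the socket's buffer_size.
 *
 *	int lowat = 4096;
 *
 *	if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT, &lowat, sizeof(lowat)))
 *		perror("SO_RCVLOWAT");
 */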
2602 
2603 static const struct proto_ops vsock_stream_ops = {
2604 	.family = PF_VSOCK,
2605 	.owner = THIS_MODULE,
2606 	.release = vsock_release,
2607 	.bind = vsock_bind,
2608 	.connect = vsock_connect,
2609 	.socketpair = sock_no_socketpair,
2610 	.accept = vsock_accept,
2611 	.getname = vsock_getname,
2612 	.poll = vsock_poll,
2613 	.ioctl = vsock_ioctl,
2614 	.listen = vsock_listen,
2615 	.shutdown = vsock_shutdown,
2616 	.setsockopt = vsock_connectible_setsockopt,
2617 	.getsockopt = vsock_connectible_getsockopt,
2618 	.sendmsg = vsock_connectible_sendmsg,
2619 	.recvmsg = vsock_connectible_recvmsg,
2620 	.mmap = sock_no_mmap,
2621 	.set_rcvlowat = vsock_set_rcvlowat,
2622 	.read_skb = vsock_read_skb,
2623 };
2624 
2625 static const struct proto_ops vsock_seqpacket_ops = {
2626 	.family = PF_VSOCK,
2627 	.owner = THIS_MODULE,
2628 	.release = vsock_release,
2629 	.bind = vsock_bind,
2630 	.connect = vsock_connect,
2631 	.socketpair = sock_no_socketpair,
2632 	.accept = vsock_accept,
2633 	.getname = vsock_getname,
2634 	.poll = vsock_poll,
2635 	.ioctl = vsock_ioctl,
2636 	.listen = vsock_listen,
2637 	.shutdown = vsock_shutdown,
2638 	.setsockopt = vsock_connectible_setsockopt,
2639 	.getsockopt = vsock_connectible_getsockopt,
2640 	.sendmsg = vsock_connectible_sendmsg,
2641 	.recvmsg = vsock_connectible_recvmsg,
2642 	.mmap = sock_no_mmap,
2643 	.read_skb = vsock_read_skb,
2644 };
2645 
2646 static int vsock_create(struct net *net, struct socket *sock,
2647 			int protocol, int kern)
2648 {
2649 	struct vsock_sock *vsk;
2650 	struct sock *sk;
2651 	int ret;
2652 
2653 	if (!sock)
2654 		return -EINVAL;
2655 
2656 	if (protocol && protocol != PF_VSOCK)
2657 		return -EPROTONOSUPPORT;
2658 
2659 	switch (sock->type) {
2660 	case SOCK_DGRAM:
2661 		sock->ops = &vsock_dgram_ops;
2662 		break;
2663 	case SOCK_STREAM:
2664 		sock->ops = &vsock_stream_ops;
2665 		break;
2666 	case SOCK_SEQPACKET:
2667 		sock->ops = &vsock_seqpacket_ops;
2668 		break;
2669 	default:
2670 		return -ESOCKTNOSUPPORT;
2671 	}
2672 
2673 	sock->state = SS_UNCONNECTED;
2674 
2675 	sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern);
2676 	if (!sk)
2677 		return -ENOMEM;
2678 
2679 	vsk = vsock_sk(sk);
2680 
2681 	if (sock->type == SOCK_DGRAM) {
2682 		ret = vsock_assign_transport(vsk, NULL);
2683 		if (ret < 0) {
2684 			sock->sk = NULL;
2685 			sock_put(sk);
2686 			return ret;
2687 		}
2688 	}
2689 
2690 	/* SOCK_DGRAM doesn't have a 'setsockopt' callback set in its
2691 	 * proto_ops, so there is no handler for custom logic.
2692 	 */
2693 	if (sock_type_connectible(sock->type))
2694 		set_bit(SOCK_CUSTOM_SOCKOPT, &sk->sk_socket->flags);
2695 
2696 	vsock_insert_unbound(vsk);
2697 
2698 	return 0;
2699 }
2700 
2701 static const struct net_proto_family vsock_family_ops = {
2702 	.family = AF_VSOCK,
2703 	.create = vsock_create,
2704 	.owner = THIS_MODULE,
2705 };
2706 
2707 static long vsock_dev_do_ioctl(struct file *filp,
2708 			       unsigned int cmd, void __user *ptr)
2709 {
2710 	u32 __user *p = ptr;
2711 	int retval = 0;
2712 	u32 cid;
2713 
2714 	switch (cmd) {
2715 	case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
2716 		/* To be compatible with the VMCI behavior, we prioritize the
2717 		 * guest CID instead of the well-known host CID (VMADDR_CID_HOST).
2718 		 */
2719 		cid = vsock_registered_transport_cid(&transport_g2h);
2720 		if (cid == VMADDR_CID_ANY)
2721 			cid = vsock_registered_transport_cid(&transport_h2g);
2722 		if (cid == VMADDR_CID_ANY)
2723 			cid = vsock_registered_transport_cid(&transport_local);
2724 
2725 		if (put_user(cid, p) != 0)
2726 			retval = -EFAULT;
2727 		break;
2728 
2729 	default:
2730 		retval = -ENOIOCTLCMD;
2731 	}
2732 
2733 	return retval;
2734 }
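
/* Usage sketch (user space; illustrative only): querying the local CID
 * through the misc device registered below.
 *
 *	#include <fcntl.h>
 *	#include <sys/ioctl.h>
 *	#include <linux/vm_sockets.h>
 *
 *	unsigned int cid;
 *	int dev = open("/dev/vsock", O_RDONLY);
 *
 *	if (dev >= 0 && !ioctl(dev, IOCTL_VM_SOCKETS_GET_LOCAL_CID, &cid))
 *		printf("local CID: %u\n", cid);
 */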
2735 
2736 static long vsock_dev_ioctl(struct file *filp,
2737 			    unsigned int cmd, unsigned long arg)
2738 {
2739 	return vsock_dev_do_ioctl(filp, cmd, (void __user *)arg);
2740 }
2741 
2742 #ifdef CONFIG_COMPAT
2743 static long vsock_dev_compat_ioctl(struct file *filp,
2744 				   unsigned int cmd, unsigned long arg)
2745 {
2746 	return vsock_dev_do_ioctl(filp, cmd, compat_ptr(arg));
2747 }
2748 #endif
2749 
2750 static const struct file_operations vsock_device_ops = {
2751 	.owner		= THIS_MODULE,
2752 	.unlocked_ioctl	= vsock_dev_ioctl,
2753 #ifdef CONFIG_COMPAT
2754 	.compat_ioctl	= vsock_dev_compat_ioctl,
2755 #endif
2756 	.open		= nonseekable_open,
2757 };
2758 
2759 static struct miscdevice vsock_device = {
2760 	.name		= "vsock",
2761 	.fops		= &vsock_device_ops,
2762 };
2763 
2764 static int __vsock_net_mode_string(const struct ctl_table *table, int write,
2765 				   void *buffer, size_t *lenp, loff_t *ppos,
2766 				   enum vsock_net_mode mode,
2767 				   enum vsock_net_mode *new_mode)
2768 {
2769 	char data[VSOCK_NET_MODE_STR_MAX] = {0};
2770 	struct ctl_table tmp;
2771 	int ret;
2772 
2773 	if (!table->data || !table->maxlen || !*lenp) {
2774 		*lenp = 0;
2775 		return 0;
2776 	}
2777 
2778 	tmp = *table;
2779 	tmp.data = data;
2780 
2781 	if (!write) {
2782 		const char *p;
2783 
2784 		switch (mode) {
2785 		case VSOCK_NET_MODE_GLOBAL:
2786 			p = VSOCK_NET_MODE_STR_GLOBAL;
2787 			break;
2788 		case VSOCK_NET_MODE_LOCAL:
2789 			p = VSOCK_NET_MODE_STR_LOCAL;
2790 			break;
2791 		default:
2792 			WARN_ONCE(true, "netns has invalid vsock mode");
2793 			*lenp = 0;
2794 			return 0;
2795 		}
2796 
2797 		strscpy(data, p, sizeof(data));
2798 		tmp.maxlen = strlen(p);
2799 	}
2800 
2801 	ret = proc_dostring(&tmp, write, buffer, lenp, ppos);
2802 	if (ret || !write)
2803 		return ret;
2804 
2805 	if (*lenp >= sizeof(data))
2806 		return -EINVAL;
2807 
2808 	if (!strncmp(data, VSOCK_NET_MODE_STR_GLOBAL, sizeof(data)))
2809 		*new_mode = VSOCK_NET_MODE_GLOBAL;
2810 	else if (!strncmp(data, VSOCK_NET_MODE_STR_LOCAL, sizeof(data)))
2811 		*new_mode = VSOCK_NET_MODE_LOCAL;
2812 	else
2813 		return -EINVAL;
2814 
2815 	return 0;
2816 }
2817 
2818 static int vsock_net_mode_string(const struct ctl_table *table, int write,
2819 				 void *buffer, size_t *lenp, loff_t *ppos)
2820 {
2821 	struct net *net;
2822 
2823 	if (write)
2824 		return -EPERM;
2825 
2826 	net = current->nsproxy->net_ns;
2827 
2828 	return __vsock_net_mode_string(table, write, buffer, lenp, ppos,
2829 				       vsock_net_mode(net), NULL);
2830 }
2831 
2832 static int vsock_net_child_mode_string(const struct ctl_table *table, int write,
2833 				       void *buffer, size_t *lenp, loff_t *ppos)
2834 {
2835 	enum vsock_net_mode new_mode;
2836 	struct net *net;
2837 	int ret;
2838 
2839 	net = current->nsproxy->net_ns;
2840 
2841 	ret = __vsock_net_mode_string(table, write, buffer, lenp, ppos,
2842 				      vsock_net_child_mode(net), &new_mode);
2843 	if (ret)
2844 		return ret;
2845 
2846 	if (write)
2847 		vsock_net_set_child_mode(net, new_mode);
2848 
2849 	return 0;
2850 }
2851 
2852 static struct ctl_table vsock_table[] = {
2853 	{
2854 		.procname	= "ns_mode",
2855 		.data		= &init_net.vsock.mode,
2856 		.maxlen		= VSOCK_NET_MODE_STR_MAX,
2857 		.mode		= 0444,
2858 		.proc_handler	= vsock_net_mode_string
2859 	},
2860 	{
2861 		.procname	= "child_ns_mode",
2862 		.data		= &init_net.vsock.child_ns_mode,
2863 		.maxlen		= VSOCK_NET_MODE_STR_MAX,
2864 		.mode		= 0644,
2865 		.proc_handler	= vsock_net_child_mode_string
2866 	},
2867 };
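
/* Usage sketch (illustrative only): the table above surfaces as
 * /proc/sys/net/vsock/{ns_mode,child_ns_mode}.  ns_mode is read-only;
 * child_ns_mode accepts a mode string (assumed here to be "global" or
 * "local", per VSOCK_NET_MODE_STR_*) that future child namespaces inherit.
 *
 *	FILE *f = fopen("/proc/sys/net/vsock/child_ns_mode", "w");
 *
 *	if (f) {
 *		fputs("local", f);	// assumed mode string
 *		fclose(f);
 *	}
 */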
2868 
2869 static int __net_init vsock_sysctl_register(struct net *net)
2870 {
2871 	struct ctl_table *table;
2872 
2873 	if (net_eq(net, &init_net)) {
2874 		table = vsock_table;
2875 	} else {
2876 		table = kmemdup(vsock_table, sizeof(vsock_table), GFP_KERNEL);
2877 		if (!table)
2878 			goto err_alloc;
2879 
2880 		table[0].data = &net->vsock.mode;
2881 		table[1].data = &net->vsock.child_ns_mode;
2882 	}
2883 
2884 	net->vsock.sysctl_hdr = register_net_sysctl_sz(net, "net/vsock", table,
2885 						       ARRAY_SIZE(vsock_table));
2886 	if (!net->vsock.sysctl_hdr)
2887 		goto err_reg;
2888 
2889 	return 0;
2890 
2891 err_reg:
2892 	if (!net_eq(net, &init_net))
2893 		kfree(table);
2894 err_alloc:
2895 	return -ENOMEM;
2896 }
2897 
2898 static void vsock_sysctl_unregister(struct net *net)
2899 {
2900 	const struct ctl_table *table;
2901 
2902 	table = net->vsock.sysctl_hdr->ctl_table_arg;
2903 	unregister_net_sysctl_table(net->vsock.sysctl_hdr);
2904 	if (!net_eq(net, &init_net))
2905 		kfree(table);
2906 }
2907 
2908 static void vsock_net_init(struct net *net)
2909 {
2910 	if (net_eq(net, &init_net))
2911 		net->vsock.mode = VSOCK_NET_MODE_GLOBAL;
2912 	else
2913 		net->vsock.mode = vsock_net_child_mode(current->nsproxy->net_ns);
2914 
2915 	net->vsock.child_ns_mode = VSOCK_NET_MODE_GLOBAL;
2916 }
2917 
2918 static __net_init int vsock_sysctl_init_net(struct net *net)
2919 {
2920 	vsock_net_init(net);
2921 
2922 	if (vsock_sysctl_register(net))
2923 		return -ENOMEM;
2924 
2925 	return 0;
2926 }
2927 
2928 static __net_exit void vsock_sysctl_exit_net(struct net *net)
2929 {
2930 	vsock_sysctl_unregister(net);
2931 }
2932 
2933 static struct pernet_operations vsock_sysctl_ops = {
2934 	.init = vsock_sysctl_init_net,
2935 	.exit = vsock_sysctl_exit_net,
2936 };
2937 
2938 static int __init vsock_init(void)
2939 {
2940 	int err = 0;
2941 
2942 	vsock_init_tables();
2943 
2944 	vsock_proto.owner = THIS_MODULE;
2945 	vsock_device.minor = MISC_DYNAMIC_MINOR;
2946 	err = misc_register(&vsock_device);
2947 	if (err) {
2948 		pr_err("Failed to register misc device\n");
2949 		goto err_reset_transport;
2950 	}
2951 
2952 	err = proto_register(&vsock_proto, 1);	/* we want our slab */
2953 	if (err) {
2954 		pr_err("Cannot register vsock protocol\n");
2955 		goto err_deregister_misc;
2956 	}
2957 
2958 	err = sock_register(&vsock_family_ops);
2959 	if (err) {
2960 		pr_err("could not register af_vsock (%d) address family: %d\n",
2961 		       AF_VSOCK, err);
2962 		goto err_unregister_proto;
2963 	}
2964 
2965 	if (register_pernet_subsys(&vsock_sysctl_ops)) {
2966 		err = -ENOMEM;
2967 		goto err_unregister_sock;
2968 	}
2969 
2970 	vsock_bpf_build_proto();
2971 
2972 	return 0;
2973 
2974 err_unregister_sock:
2975 	sock_unregister(AF_VSOCK);
2976 err_unregister_proto:
2977 	proto_unregister(&vsock_proto);
2978 err_deregister_misc:
2979 	misc_deregister(&vsock_device);
2980 err_reset_transport:
2981 	return err;
2982 }
2983 
2984 static void __exit vsock_exit(void)
2985 {
2986 	misc_deregister(&vsock_device);
2987 	sock_unregister(AF_VSOCK);
2988 	proto_unregister(&vsock_proto);
2989 	unregister_pernet_subsys(&vsock_sysctl_ops);
2990 }
2991 
2992 const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
2993 {
2994 	return vsk->transport;
2995 }
2996 EXPORT_SYMBOL_GPL(vsock_core_get_transport);
2997 
2998 int vsock_core_register(const struct vsock_transport *t, int features)
2999 {
3000 	const struct vsock_transport *t_h2g, *t_g2h, *t_dgram, *t_local;
3001 	int err = mutex_lock_interruptible(&vsock_register_mutex);
3002 
3003 	if (err)
3004 		return err;
3005 
3006 	t_h2g = transport_h2g;
3007 	t_g2h = transport_g2h;
3008 	t_dgram = transport_dgram;
3009 	t_local = transport_local;
3010 
3011 	if (features & VSOCK_TRANSPORT_F_H2G) {
3012 		if (t_h2g) {
3013 			err = -EBUSY;
3014 			goto err_busy;
3015 		}
3016 		t_h2g = t;
3017 	}
3018 
3019 	if (features & VSOCK_TRANSPORT_F_G2H) {
3020 		if (t_g2h) {
3021 			err = -EBUSY;
3022 			goto err_busy;
3023 		}
3024 		t_g2h = t;
3025 	}
3026 
3027 	if (features & VSOCK_TRANSPORT_F_DGRAM) {
3028 		if (t_dgram) {
3029 			err = -EBUSY;
3030 			goto err_busy;
3031 		}
3032 		t_dgram = t;
3033 	}
3034 
3035 	if (features & VSOCK_TRANSPORT_F_LOCAL) {
3036 		if (t_local) {
3037 			err = -EBUSY;
3038 			goto err_busy;
3039 		}
3040 		t_local = t;
3041 	}
3042 
3043 	transport_h2g = t_h2g;
3044 	transport_g2h = t_g2h;
3045 	transport_dgram = t_dgram;
3046 	transport_local = t_local;
3047 
3048 err_busy:
3049 	mutex_unlock(&vsock_register_mutex);
3050 	return err;
3051 }
3052 EXPORT_SYMBOL_GPL(vsock_core_register);
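
/* Registration sketch (illustrative only): a hypothetical transport module
 * claiming the host-to-guest role.  'my_transport' and its callbacks are
 * assumptions; vsock_core_register() returns -EBUSY if the role is taken.
 *
 *	static struct vsock_transport my_transport = {
 *		// .connect, .stream_enqueue, ... transport callbacks
 *	};
 *
 *	static int __init my_transport_init(void)
 *	{
 *		return vsock_core_register(&my_transport, VSOCK_TRANSPORT_F_H2G);
 *	}
 *
 *	static void __exit my_transport_exit(void)
 *	{
 *		vsock_core_unregister(&my_transport);
 *	}
 */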
3053 
3054 void vsock_core_unregister(const struct vsock_transport *t)
3055 {
3056 	mutex_lock(&vsock_register_mutex);
3057 
3058 	if (transport_h2g == t)
3059 		transport_h2g = NULL;
3060 
3061 	if (transport_g2h == t)
3062 		transport_g2h = NULL;
3063 
3064 	if (transport_dgram == t)
3065 		transport_dgram = NULL;
3066 
3067 	if (transport_local == t)
3068 		transport_local = NULL;
3069 
3070 	mutex_unlock(&vsock_register_mutex);
3071 }
3072 EXPORT_SYMBOL_GPL(vsock_core_unregister);
3073 
3074 module_init(vsock_init);
3075 module_exit(vsock_exit);
3076 
3077 MODULE_AUTHOR("VMware, Inc.");
3078 MODULE_DESCRIPTION("VMware Virtual Socket Family");
3079 MODULE_VERSION("1.0.2.0-k");
3080 MODULE_LICENSE("GPL v2");
3081