xref: /linux/net/vmw_vsock/virtio_transport_common.c (revision c4dde411bc366f568dbe33366253bbfea049e8ea)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * common code for virtio vsock
4  *
5  * Copyright (C) 2013-2015 Red Hat, Inc.
6  * Author: Asias He <asias@redhat.com>
7  *         Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 #include <linux/spinlock.h>
10 #include <linux/module.h>
11 #include <linux/sched/signal.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio_vsock.h>
15 #include <uapi/linux/vsockmon.h>
16 
17 #include <net/sock.h>
18 #include <net/af_vsock.h>
19 
20 #define CREATE_TRACE_POINTS
21 #include <trace/events/vsock_virtio_transport_common.h>
22 
23 /* How long to wait for graceful shutdown of a connection */
24 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
25 
26 /* Threshold for detecting small packets to copy */
27 #define GOOD_COPY_LEN  128
28 
29 static void virtio_transport_cancel_close_work(struct vsock_sock *vsk,
30 					       bool cancel_timeout);
31 static s64 virtio_transport_has_space(struct virtio_vsock_sock *vvs);
32 
33 static const struct virtio_transport *
34 virtio_transport_get_ops(struct vsock_sock *vsk)
35 {
36 	const struct vsock_transport *t = vsock_core_get_transport(vsk);
37 
38 	if (WARN_ON(!t))
39 		return NULL;
40 
41 	return container_of(t, struct virtio_transport, transport);
42 }
43 
44 static bool virtio_transport_can_zcopy(const struct virtio_transport *t_ops,
45 				       struct virtio_vsock_pkt_info *info,
46 				       size_t pkt_len)
47 {
48 	struct iov_iter *iov_iter;
49 
50 	if (!info->msg)
51 		return false;
52 
53 	iov_iter = &info->msg->msg_iter;
54 
55 	if (iov_iter->iov_offset)
56 		return false;
57 
58 	/* We can't send whole iov. */
59 	if (iov_iter->count > pkt_len)
60 		return false;
61 
62 	/* Check that transport can send data in zerocopy mode. */
63 	if (t_ops->can_msgzerocopy) {
64 		int pages_to_send = iov_iter_npages(iov_iter, MAX_SKB_FRAGS);
65 
66 		/* +1 is for packet header. */
67 		return t_ops->can_msgzerocopy(pages_to_send + 1);
68 	}
69 
70 	return true;
71 }
72 
73 static int virtio_transport_init_zcopy_skb(struct vsock_sock *vsk,
74 					   struct sk_buff *skb,
75 					   struct msghdr *msg,
76 					   bool zerocopy)
77 {
78 	struct ubuf_info *uarg;
79 
80 	if (msg->msg_ubuf) {
81 		uarg = msg->msg_ubuf;
82 		net_zcopy_get(uarg);
83 	} else {
84 		struct iov_iter *iter = &msg->msg_iter;
85 		struct ubuf_info_msgzc *uarg_zc;
86 
87 		uarg = msg_zerocopy_realloc(sk_vsock(vsk),
88 					    iter->count,
89 					    NULL, false);
90 		if (!uarg)
91 			return -1;
92 
93 		uarg_zc = uarg_to_msgzc(uarg);
94 		uarg_zc->zerocopy = zerocopy ? 1 : 0;
95 	}
96 
97 	skb_zcopy_init(skb, uarg);
98 
99 	return 0;
100 }
101 
102 static int virtio_transport_fill_skb(struct sk_buff *skb,
103 				     struct virtio_vsock_pkt_info *info,
104 				     size_t len,
105 				     bool zcopy)
106 {
107 	struct msghdr *msg = info->msg;
108 
109 	if (zcopy)
110 		return __zerocopy_sg_from_iter(msg, NULL, skb,
111 					       &msg->msg_iter, len, NULL);
112 
113 	virtio_vsock_skb_put(skb, len);
114 	return skb_copy_datagram_from_iter_full(skb, 0, &msg->msg_iter, len);
115 }
116 
117 static void virtio_transport_init_hdr(struct sk_buff *skb,
118 				      struct virtio_vsock_pkt_info *info,
119 				      size_t payload_len,
120 				      u32 src_cid,
121 				      u32 src_port,
122 				      u32 dst_cid,
123 				      u32 dst_port)
124 {
125 	struct virtio_vsock_hdr *hdr;
126 
127 	hdr = virtio_vsock_hdr(skb);
128 	hdr->type	= cpu_to_le16(info->type);
129 	hdr->op		= cpu_to_le16(info->op);
130 	hdr->src_cid	= cpu_to_le64(src_cid);
131 	hdr->dst_cid	= cpu_to_le64(dst_cid);
132 	hdr->src_port	= cpu_to_le32(src_port);
133 	hdr->dst_port	= cpu_to_le32(dst_port);
134 	hdr->flags	= cpu_to_le32(info->flags);
135 	hdr->len	= cpu_to_le32(payload_len);
136 	hdr->buf_alloc	= cpu_to_le32(0);
137 	hdr->fwd_cnt	= cpu_to_le32(0);
138 }
139 
140 static void virtio_transport_copy_nonlinear_skb(const struct sk_buff *skb,
141 						void *dst,
142 						size_t len)
143 {
144 	struct iov_iter iov_iter = { 0 };
145 	struct kvec kvec;
146 	size_t to_copy;
147 
148 	kvec.iov_base = dst;
149 	kvec.iov_len = len;
150 
151 	iov_iter.iter_type = ITER_KVEC;
152 	iov_iter.kvec = &kvec;
153 	iov_iter.nr_segs = 1;
154 
155 	to_copy = min_t(size_t, len, skb->len);
156 
157 	skb_copy_datagram_iter(skb, VIRTIO_VSOCK_SKB_CB(skb)->offset,
158 			       &iov_iter, to_copy);
159 }
160 
161 /* Packet capture */
162 static struct sk_buff *virtio_transport_build_skb(void *opaque)
163 {
164 	struct virtio_vsock_hdr *pkt_hdr;
165 	struct sk_buff *pkt = opaque;
166 	struct af_vsockmon_hdr *hdr;
167 	struct sk_buff *skb;
168 	size_t payload_len;
169 
170 	/* A packet could be split to fit the RX buffer, so we can retrieve
171 	 * the payload length from the header and the buffer pointer taking
172 	 * care of the offset in the original packet.
173 	 */
174 	pkt_hdr = virtio_vsock_hdr(pkt);
175 	payload_len = pkt->len;
176 
177 	skb = alloc_skb(sizeof(*hdr) + sizeof(*pkt_hdr) + payload_len,
178 			GFP_ATOMIC);
179 	if (!skb)
180 		return NULL;
181 
182 	hdr = skb_put(skb, sizeof(*hdr));
183 
184 	/* pkt->hdr is little-endian so no need to byteswap here */
185 	hdr->src_cid = pkt_hdr->src_cid;
186 	hdr->src_port = pkt_hdr->src_port;
187 	hdr->dst_cid = pkt_hdr->dst_cid;
188 	hdr->dst_port = pkt_hdr->dst_port;
189 
190 	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
191 	hdr->len = cpu_to_le16(sizeof(*pkt_hdr));
192 	memset(hdr->reserved, 0, sizeof(hdr->reserved));
193 
194 	switch (le16_to_cpu(pkt_hdr->op)) {
195 	case VIRTIO_VSOCK_OP_REQUEST:
196 	case VIRTIO_VSOCK_OP_RESPONSE:
197 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
198 		break;
199 	case VIRTIO_VSOCK_OP_RST:
200 	case VIRTIO_VSOCK_OP_SHUTDOWN:
201 		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
202 		break;
203 	case VIRTIO_VSOCK_OP_RW:
204 		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
205 		break;
206 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
207 	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
208 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
209 		break;
210 	default:
211 		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
212 		break;
213 	}
214 
215 	skb_put_data(skb, pkt_hdr, sizeof(*pkt_hdr));
216 
217 	if (payload_len) {
218 		if (skb_is_nonlinear(pkt)) {
219 			void *data = skb_put(skb, payload_len);
220 
221 			virtio_transport_copy_nonlinear_skb(pkt, data, payload_len);
222 		} else {
223 			skb_put_data(skb, pkt->data, payload_len);
224 		}
225 	}
226 
227 	return skb;
228 }
229 
230 void virtio_transport_deliver_tap_pkt(struct sk_buff *skb)
231 {
232 	if (virtio_vsock_skb_tap_delivered(skb))
233 		return;
234 
235 	vsock_deliver_tap(virtio_transport_build_skb, skb);
236 	virtio_vsock_skb_set_tap_delivered(skb);
237 }
238 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
239 
240 static u16 virtio_transport_get_type(struct sock *sk)
241 {
242 	if (sk->sk_type == SOCK_STREAM)
243 		return VIRTIO_VSOCK_TYPE_STREAM;
244 	else
245 		return VIRTIO_VSOCK_TYPE_SEQPACKET;
246 }
247 
248 /* Returns new sk_buff on success, otherwise returns NULL. */
249 static struct sk_buff *virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
250 						  size_t payload_len,
251 						  bool zcopy,
252 						  u32 src_cid,
253 						  u32 src_port,
254 						  u32 dst_cid,
255 						  u32 dst_port)
256 {
257 	struct vsock_sock *vsk;
258 	struct sk_buff *skb;
259 	size_t skb_len;
260 
261 	skb_len = VIRTIO_VSOCK_SKB_HEADROOM;
262 
263 	if (!zcopy)
264 		skb_len += payload_len;
265 
266 	skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL);
267 	if (!skb)
268 		return NULL;
269 
270 	virtio_transport_init_hdr(skb, info, payload_len, src_cid, src_port,
271 				  dst_cid, dst_port);
272 
273 	vsk = info->vsk;
274 
275 	/* If 'vsk' != NULL then payload is always present, so we
276 	 * will never call '__zerocopy_sg_from_iter()' below without
277 	 * setting skb owner in 'skb_set_owner_w()'. The only case
278 	 * when 'vsk' == NULL is VIRTIO_VSOCK_OP_RST control message
279 	 * without payload.
280 	 */
281 	WARN_ON_ONCE(!(vsk && (info->msg && payload_len)) && zcopy);
282 
283 	/* Set owner here, because '__zerocopy_sg_from_iter()' uses
284 	 * owner of skb without check to update 'sk_wmem_alloc'.
285 	 */
286 	if (vsk)
287 		skb_set_owner_w(skb, sk_vsock(vsk));
288 
289 	if (info->msg && payload_len > 0) {
290 		int err;
291 
292 		err = virtio_transport_fill_skb(skb, info, payload_len, zcopy);
293 		if (err)
294 			goto out;
295 
296 		if (msg_data_left(info->msg) == 0 &&
297 		    info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
298 			struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
299 
300 			hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
301 
302 			if (info->msg->msg_flags & MSG_EOR)
303 				hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
304 		}
305 	}
306 
307 	if (info->reply)
308 		virtio_vsock_skb_set_reply(skb);
309 
310 	trace_virtio_transport_alloc_pkt(src_cid, src_port,
311 					 dst_cid, dst_port,
312 					 payload_len,
313 					 info->type,
314 					 info->op,
315 					 info->flags,
316 					 zcopy);
317 
318 	return skb;
319 out:
320 	kfree_skb(skb);
321 	return NULL;
322 }
323 
/* This function can only be used on connecting/connected sockets,
 * since a socket assigned to a transport is required.
 *
 * Do not use on listener sockets!
 *
 * Reserves TX credit for up to info->pkt_len bytes, splits the data into
 * skbs of at most VIRTIO_VSOCK_MAX_PKT_BUF_SIZE bytes (additionally limited
 * to MAX_SKB_FRAGS pages in zerocopy mode) and passes each skb to the
 * transport's send_pkt() callback. Credit reserved for bytes that were not
 * sent is given back before returning.
 *
 * Returns the number of payload bytes sent if any were sent. Otherwise
 * returns 0 when an OP_RW packet got no credit, -EFAULT when the socket has
 * no transport, -ENOMEM when an skb or zerocopy completion could not be
 * allocated, or the value returned by the last send_pkt() call.
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 max_skb_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
	u32 src_cid, src_port, dst_cid, dst_port;
	const struct virtio_transport *t_ops;
	struct virtio_vsock_sock *vvs;
	u32 pkt_len = info->pkt_len;
	bool can_zcopy = false;
	u32 rest_len;
	int ret;

	info->type = virtio_transport_get_type(sk_vsock(vsk));

	t_ops = virtio_transport_get_ops(vsk);
	if (unlikely(!t_ops))
		return -EFAULT;

	/* The destination defaults to the socket's remote address unless
	 * the caller provided an explicit remote_cid in 'info'.
	 */
	src_cid = t_ops->transport.get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid	= vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	if (info->msg) {
		/* If zerocopy is not enabled by 'setsockopt()', we behave as
		 * there is no MSG_ZEROCOPY flag set.
		 */
		if (!sock_flag(sk_vsock(vsk), SOCK_ZEROCOPY))
			info->msg->msg_flags &= ~MSG_ZEROCOPY;

		if (info->msg->msg_flags & MSG_ZEROCOPY)
			can_zcopy = virtio_transport_can_zcopy(t_ops, info, pkt_len);

		/* A zerocopy skb cannot reference more than MAX_SKB_FRAGS
		 * pages, so cap the per-skb length accordingly.
		 */
		if (can_zcopy)
			max_skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE,
					    (MAX_SKB_FRAGS * PAGE_SIZE));
	}

	rest_len = pkt_len;

	/* The loop body runs at least once, so packets without payload
	 * (pkt_len == 0) are sent as a single skb and 'ret' is always set.
	 */
	do {
		struct sk_buff *skb;
		size_t skb_len;

		skb_len = min(max_skb_len, rest_len);

		skb = virtio_transport_alloc_skb(info, skb_len, can_zcopy,
						 src_cid, src_port,
						 dst_cid, dst_port);
		if (!skb) {
			ret = -ENOMEM;
			break;
		}

		/* We process buffer part by part, allocating skb on
		 * each iteration. If this is last skb for this buffer
		 * and MSG_ZEROCOPY mode is in use - we must allocate
		 * completion for the current syscall.
		 */
		if (info->msg && info->msg->msg_flags & MSG_ZEROCOPY &&
		    skb_len == rest_len && info->op == VIRTIO_VSOCK_OP_RW) {
			if (virtio_transport_init_zcopy_skb(vsk, skb,
							    info->msg,
							    can_zcopy)) {
				kfree_skb(skb);
				ret = -ENOMEM;
				break;
			}
		}

		/* Stamp the current fwd_cnt/buf_alloc into the header. */
		virtio_transport_inc_tx_pkt(vvs, skb);

		ret = t_ops->send_pkt(skb, info->net);
		if (ret < 0)
			break;

		/* Both virtio and vhost 'send_pkt()' returns 'skb_len',
		 * but for reliability use 'ret' instead of 'skb_len'.
		 * Also if partial send happens (e.g. 'ret' != 'skb_len')
		 * somehow, we break this loop, but account such returned
		 * value in 'virtio_transport_put_credit()'.
		 */
		rest_len -= ret;

		if (WARN_ONCE(ret != skb_len,
			      "'send_pkt()' returns %i, but %zu expected\n",
			      ret, skb_len))
			break;
	} while (rest_len);

	/* Give back the credit reserved for bytes that were not sent. */
	virtio_transport_put_credit(vvs, rest_len);

	/* Return number of bytes, if any data has been sent. */
	if (rest_len != pkt_len)
		ret = pkt_len - rest_len;

	return ret;
}
441 
442 static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
443 					u32 len)
444 {
445 	if (vvs->buf_used + len > vvs->buf_alloc)
446 		return false;
447 
448 	vvs->rx_bytes += len;
449 	vvs->buf_used += len;
450 	return true;
451 }
452 
453 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
454 					u32 bytes_read, u32 bytes_dequeued)
455 {
456 	vvs->rx_bytes -= bytes_read;
457 	vvs->buf_used -= bytes_dequeued;
458 	vvs->fwd_cnt += bytes_dequeued;
459 }
460 
461 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb)
462 {
463 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
464 
465 	spin_lock_bh(&vvs->rx_lock);
466 	vvs->last_fwd_cnt = vvs->fwd_cnt;
467 	hdr->fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
468 	hdr->buf_alloc = cpu_to_le32(vvs->buf_alloc);
469 	spin_unlock_bh(&vvs->rx_lock);
470 }
471 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
472 
473 void virtio_transport_consume_skb_sent(struct sk_buff *skb, bool consume)
474 {
475 	struct sock *s = skb->sk;
476 
477 	if (s && skb->len) {
478 		struct vsock_sock *vs = vsock_sk(s);
479 		struct virtio_vsock_sock *vvs;
480 
481 		vvs = vs->trans;
482 
483 		spin_lock_bh(&vvs->tx_lock);
484 		vvs->bytes_unsent -= skb->len;
485 		spin_unlock_bh(&vvs->tx_lock);
486 	}
487 
488 	if (consume)
489 		consume_skb(skb);
490 }
491 EXPORT_SYMBOL_GPL(virtio_transport_consume_skb_sent);
492 
493 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
494 {
495 	u32 ret;
496 
497 	if (!credit)
498 		return 0;
499 
500 	spin_lock_bh(&vvs->tx_lock);
501 	ret = min_t(u32, credit, virtio_transport_has_space(vvs));
502 	vvs->tx_cnt += ret;
503 	vvs->bytes_unsent += ret;
504 	spin_unlock_bh(&vvs->tx_lock);
505 
506 	return ret;
507 }
508 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
509 
510 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
511 {
512 	if (!credit)
513 		return;
514 
515 	spin_lock_bh(&vvs->tx_lock);
516 	vvs->tx_cnt -= credit;
517 	vvs->bytes_unsent -= credit;
518 	spin_unlock_bh(&vvs->tx_lock);
519 }
520 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
521 
522 static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
523 {
524 	struct virtio_vsock_pkt_info info = {
525 		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
526 		.vsk = vsk,
527 		.net = sock_net(sk_vsock(vsk)),
528 	};
529 
530 	return virtio_transport_send_pkt_info(vsk, &info);
531 }
532 
/* Copy up to @len bytes of queued stream data into @msg without consuming
 * it: no skb is unlinked and neither the per-skb offsets nor the receive
 * counters are modified.
 *
 * Returns the number of bytes copied. If a copy fails, the bytes copied so
 * far are returned, or the error code if nothing was copied.
 */
static ssize_t
virtio_transport_stream_do_peek(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *skb;
	size_t total = 0;
	int err;

	spin_lock_bh(&vvs->rx_lock);

	skb_queue_walk(&vvs->rx_queue, skb) {
		size_t bytes;

		/* Only the part of the skb not yet read is available. */
		bytes = min_t(size_t, len - total,
			      skb->len - VIRTIO_VSOCK_SKB_CB(skb)->offset);

		spin_unlock_bh(&vvs->rx_lock);

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since skb_copy_datagram_iter() may sleep.
		 */
		err = skb_copy_datagram_iter(skb, VIRTIO_VSOCK_SKB_CB(skb)->offset,
					     &msg->msg_iter, bytes);
		if (err)
			goto out;

		total += bytes;

		/* Re-take the lock before touching the queue again. */
		spin_lock_bh(&vvs->rx_lock);

		if (total == len)
			break;
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;

out:
	/* rx_lock is not held here: the copy runs with the lock dropped. */
	if (total)
		err = total;
	return err;
}
578 
/* Copy up to @len bytes of queued stream data into @msg and consume it.
 *
 * The per-skb read offset is advanced; an skb that has been read
 * completely is unlinked and released and its header length is credited
 * back through the receive counters. Afterwards a credit update may be
 * sent to the peer, see the comment near the end of the function.
 *
 * Returns the number of bytes copied. If a copy fails, the bytes copied so
 * far are returned, or the error code if nothing was copied. -EFAULT is
 * returned when the queue is empty although rx_bytes is non-zero.
 */
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *skb;
	u32 fwd_cnt_delta;
	bool low_rx_bytes;
	int err = -EFAULT;
	size_t total = 0;
	u32 free_space;

	spin_lock_bh(&vvs->rx_lock);

	/* Sanity check: the byte counter and the queue must agree. */
	if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes,
		      "rx_queue is empty, but rx_bytes is non-zero\n")) {
		spin_unlock_bh(&vvs->rx_lock);
		return err;
	}

	while (total < len && !skb_queue_empty(&vvs->rx_queue)) {
		size_t bytes, dequeued = 0;

		/* Peek only: the skb stays queued until fully read. */
		skb = skb_peek(&vvs->rx_queue);

		bytes = min_t(size_t, len - total,
			      skb->len - VIRTIO_VSOCK_SKB_CB(skb)->offset);

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since skb_copy_datagram_iter() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = skb_copy_datagram_iter(skb,
					     VIRTIO_VSOCK_SKB_CB(skb)->offset,
					     &msg->msg_iter, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;

		VIRTIO_VSOCK_SKB_CB(skb)->offset += bytes;

		/* The skb has been read completely: drop it and account
		 * the length from its header as dequeued.
		 */
		if (skb->len == VIRTIO_VSOCK_SKB_CB(skb)->offset) {
			dequeued = le32_to_cpu(virtio_vsock_hdr(skb)->len);
			__skb_unlink(skb, &vvs->rx_queue);
			consume_skb(skb);
		}

		virtio_transport_dec_rx_pkt(vvs, bytes, dequeued);
	}

	/* Snapshot the state needed for the credit update decision while
	 * rx_lock is still held.
	 */
	fwd_cnt_delta = vvs->fwd_cnt - vvs->last_fwd_cnt;
	free_space = vvs->buf_alloc - fwd_cnt_delta;
	low_rx_bytes = (vvs->rx_bytes <
			sock_rcvlowat(sk_vsock(vsk), 0, INT_MAX));

	spin_unlock_bh(&vvs->rx_lock);

	/* To reduce the number of credit update messages,
	 * don't update credits as long as lots of space is available.
	 * Note: the limit chosen here is arbitrary. Setting the limit
	 * too high causes extra messages. Too low causes transmitter
	 * stalls. As stalls are in theory more expensive than extra
	 * messages, we set the limit to a high value. TODO: experiment
	 * with different values. Also send credit update message when
	 * number of bytes in rx queue is not enough to wake up reader.
	 */
	if (fwd_cnt_delta &&
	    (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE || low_rx_bytes))
		virtio_transport_send_credit_update(vsk);

	return total;

out:
	/* rx_lock is not held here; no credit update is sent on this path. */
	if (total)
		err = total;
	return err;
}
661 
/* Copy the first queued SEQPACKET message into @msg without consuming it.
 *
 * The queue is walked up to and including the skb carrying
 * VIRTIO_VSOCK_SEQ_EOM. Data is copied only while there is room left in
 * @msg; MSG_EOR is set in msg->msg_flags when the last skb carries
 * VIRTIO_VSOCK_SEQ_EOR.
 *
 * Returns the sum of the lengths of the walked skbs, which may be larger
 * than the number of bytes copied; 0 when no complete message is queued;
 * or the error returned by a failed copy.
 */
static ssize_t
virtio_transport_seqpacket_do_peek(struct vsock_sock *vsk,
				   struct msghdr *msg)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *skb;
	size_t total, len;

	spin_lock_bh(&vvs->rx_lock);

	if (!vvs->msg_count) {
		spin_unlock_bh(&vvs->rx_lock);
		return 0;
	}

	total = 0;
	len = msg_data_left(msg);

	skb_queue_walk(&vvs->rx_queue, skb) {
		struct virtio_vsock_hdr *hdr;

		/* Copy only while the user buffer still has room. */
		if (total < len) {
			size_t bytes;
			int err;

			bytes = len - total;
			if (bytes > skb->len)
				bytes = skb->len;

			spin_unlock_bh(&vvs->rx_lock);

			/* sk_lock is held by caller so no one else can dequeue.
			 * Unlock rx_lock since skb_copy_datagram_iter() may sleep.
			 */
			err = skb_copy_datagram_iter(skb, VIRTIO_VSOCK_SKB_CB(skb)->offset,
						     &msg->msg_iter, bytes);
			/* rx_lock is not held at this point. */
			if (err)
				return err;

			spin_lock_bh(&vvs->rx_lock);
		}

		/* The whole skb counts, even if only a part was copied. */
		total += skb->len;
		hdr = virtio_vsock_hdr(skb);

		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
				msg->msg_flags |= MSG_EOR;

			break;
		}
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;
}
719 
/* Dequeue one SEQPACKET message into @msg.
 *
 * skbs are removed from the rx queue up to and including the one carrying
 * VIRTIO_VSOCK_SEQ_EOM (or until the queue runs empty). Data is copied as
 * long as the user buffer has room; whatever does not fit is dropped.
 * Every dequeued skb is accounted in the receive counters and freed, and
 * a credit update is sent afterwards. MSG_EOR is set in msg->msg_flags
 * when the last skb carries VIRTIO_VSOCK_SEQ_EOR. @flags is not used.
 *
 * Returns the sum of the header lengths of the dequeued skbs, which may be
 * larger than the number of bytes copied; 0 when no complete message is
 * queued; or the error of a failed copy.
 */
static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
						 struct msghdr *msg,
						 int flags)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	int dequeued_len = 0;
	size_t user_buf_len = msg_data_left(msg);
	bool msg_ready = false;
	struct sk_buff *skb;

	spin_lock_bh(&vvs->rx_lock);

	if (vvs->msg_count == 0) {
		spin_unlock_bh(&vvs->rx_lock);
		return 0;
	}

	while (!msg_ready) {
		struct virtio_vsock_hdr *hdr;
		size_t pkt_len;

		skb = __skb_dequeue(&vvs->rx_queue);
		if (!skb)
			break;
		hdr = virtio_vsock_hdr(skb);
		pkt_len = (size_t)le32_to_cpu(hdr->len);

		/* A negative dequeued_len records an earlier copy error;
		 * from then on fragments are only dequeued and freed.
		 */
		if (dequeued_len >= 0) {
			size_t bytes_to_copy;

			bytes_to_copy = min(user_buf_len, pkt_len);

			if (bytes_to_copy) {
				int err;

				/* sk_lock is held by caller so no one else can dequeue.
				 * Unlock rx_lock since skb_copy_datagram_iter() may sleep.
				 */
				spin_unlock_bh(&vvs->rx_lock);

				err = skb_copy_datagram_iter(skb, 0,
							     &msg->msg_iter,
							     bytes_to_copy);
				if (err) {
					/* Copy of message failed. Rest of
					 * fragments will be freed without copy.
					 */
					dequeued_len = err;
				} else {
					user_buf_len -= bytes_to_copy;
				}

				spin_lock_bh(&vvs->rx_lock);
			}

			/* Count the full fragment, copied or not. */
			if (dequeued_len >= 0)
				dequeued_len += pkt_len;
		}

		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
			msg_ready = true;
			vvs->msg_count--;

			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
				msg->msg_flags |= MSG_EOR;
		}

		virtio_transport_dec_rx_pkt(vvs, pkt_len, pkt_len);
		kfree_skb(skb);
	}

	spin_unlock_bh(&vvs->rx_lock);

	virtio_transport_send_credit_update(vsk);

	return dequeued_len;
}
797 
798 ssize_t
799 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
800 				struct msghdr *msg,
801 				size_t len, int flags)
802 {
803 	if (flags & MSG_PEEK)
804 		return virtio_transport_stream_do_peek(vsk, msg, len);
805 	else
806 		return virtio_transport_stream_do_dequeue(vsk, msg, len);
807 }
808 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
809 
810 ssize_t
811 virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
812 				   struct msghdr *msg,
813 				   int flags)
814 {
815 	if (flags & MSG_PEEK)
816 		return virtio_transport_seqpacket_do_peek(vsk, msg);
817 	else
818 		return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
819 }
820 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
821 
822 static u32 virtio_transport_tx_buf_size(struct virtio_vsock_sock *vvs)
823 {
824 	/* The peer advertises its receive buffer via peer_buf_alloc, but we
825 	 * cap it to our local buf_alloc so a remote peer cannot force us to
826 	 * queue more data than our own buffer configuration allows.
827 	 */
828 	return min(vvs->peer_buf_alloc, vvs->buf_alloc);
829 }
830 
831 int
832 virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
833 				   struct msghdr *msg,
834 				   size_t len)
835 {
836 	struct virtio_vsock_sock *vvs = vsk->trans;
837 
838 	spin_lock_bh(&vvs->tx_lock);
839 
840 	if (len > virtio_transport_tx_buf_size(vvs)) {
841 		spin_unlock_bh(&vvs->tx_lock);
842 		return -EMSGSIZE;
843 	}
844 
845 	spin_unlock_bh(&vvs->tx_lock);
846 
847 	return virtio_transport_stream_enqueue(vsk, msg, len);
848 }
849 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);
850 
851 int
852 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
853 			       struct msghdr *msg,
854 			       size_t len, int flags)
855 {
856 	return -EOPNOTSUPP;
857 }
858 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
859 
860 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
861 {
862 	struct virtio_vsock_sock *vvs = vsk->trans;
863 	s64 bytes;
864 
865 	spin_lock_bh(&vvs->rx_lock);
866 	bytes = vvs->rx_bytes;
867 	spin_unlock_bh(&vvs->rx_lock);
868 
869 	return bytes;
870 }
871 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
872 
873 u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
874 {
875 	struct virtio_vsock_sock *vvs = vsk->trans;
876 	u32 msg_count;
877 
878 	spin_lock_bh(&vvs->rx_lock);
879 	msg_count = vvs->msg_count;
880 	spin_unlock_bh(&vvs->rx_lock);
881 
882 	return msg_count;
883 }
884 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);
885 
886 static s64 virtio_transport_has_space(struct virtio_vsock_sock *vvs)
887 {
888 	s64 bytes;
889 
890 	/* Use s64 arithmetic so if the peer shrinks peer_buf_alloc while
891 	 * we have bytes in flight (tx_cnt - peer_fwd_cnt), the subtraction
892 	 * does not underflow.
893 	 */
894 	bytes = (s64)virtio_transport_tx_buf_size(vvs) -
895 		(vvs->tx_cnt - vvs->peer_fwd_cnt);
896 	if (bytes < 0)
897 		bytes = 0;
898 
899 	return bytes;
900 }
901 
902 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
903 {
904 	struct virtio_vsock_sock *vvs = vsk->trans;
905 	s64 bytes;
906 
907 	spin_lock_bh(&vvs->tx_lock);
908 	bytes = virtio_transport_has_space(vvs);
909 	spin_unlock_bh(&vvs->tx_lock);
910 
911 	return bytes;
912 }
913 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
914 
915 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
916 				    struct vsock_sock *psk)
917 {
918 	struct virtio_vsock_sock *vvs;
919 
920 	vvs = kzalloc_obj(*vvs);
921 	if (!vvs)
922 		return -ENOMEM;
923 
924 	vsk->trans = vvs;
925 	vvs->vsk = vsk;
926 	if (psk && psk->trans) {
927 		struct virtio_vsock_sock *ptrans = psk->trans;
928 
929 		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
930 	}
931 
932 	if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
933 		vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;
934 
935 	vvs->buf_alloc = vsk->buffer_size;
936 
937 	spin_lock_init(&vvs->rx_lock);
938 	spin_lock_init(&vvs->tx_lock);
939 	skb_queue_head_init(&vvs->rx_queue);
940 
941 	return 0;
942 }
943 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
944 
945 /* sk_lock held by the caller */
946 void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
947 {
948 	struct virtio_vsock_sock *vvs = vsk->trans;
949 
950 	if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
951 		*val = VIRTIO_VSOCK_MAX_BUF_SIZE;
952 
953 	vvs->buf_alloc = *val;
954 
955 	virtio_transport_send_credit_update(vsk);
956 }
957 EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
958 
959 int
960 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
961 				size_t target,
962 				bool *data_ready_now)
963 {
964 	*data_ready_now = vsock_stream_has_data(vsk) >= target;
965 
966 	return 0;
967 }
968 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
969 
970 int
971 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
972 				 size_t target,
973 				 bool *space_avail_now)
974 {
975 	s64 free_space;
976 
977 	free_space = vsock_stream_has_space(vsk);
978 	if (free_space > 0)
979 		*space_avail_now = true;
980 	else if (free_space == 0)
981 		*space_avail_now = false;
982 
983 	return 0;
984 }
985 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
986 
987 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
988 	size_t target, struct vsock_transport_recv_notify_data *data)
989 {
990 	return 0;
991 }
992 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
993 
994 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
995 	size_t target, struct vsock_transport_recv_notify_data *data)
996 {
997 	return 0;
998 }
999 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
1000 
1001 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
1002 	size_t target, struct vsock_transport_recv_notify_data *data)
1003 {
1004 	return 0;
1005 }
1006 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
1007 
1008 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
1009 	size_t target, ssize_t copied, bool data_read,
1010 	struct vsock_transport_recv_notify_data *data)
1011 {
1012 	return 0;
1013 }
1014 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
1015 
/* Send notification hook: nothing to do here, always returns 0. */
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
				      struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
1022 
/* Send notification hook: nothing to do here, always returns 0. */
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
					   struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
1029 
/* Send notification hook: nothing to do here, always returns 0. */
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
					     struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
1036 
1037 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
1038 	ssize_t written, struct vsock_transport_send_notify_data *data)
1039 {
1040 	return 0;
1041 }
1042 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
1043 
1044 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
1045 {
1046 	return vsk->buffer_size;
1047 }
1048 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
1049 
1050 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
1051 {
1052 	return true;
1053 }
1054 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
1055 
1056 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
1057 				struct sockaddr_vm *addr)
1058 {
1059 	return -EOPNOTSUPP;
1060 }
1061 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
1062 
1063 bool virtio_transport_dgram_allow(struct vsock_sock *vsk, u32 cid, u32 port)
1064 {
1065 	return false;
1066 }
1067 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
1068 
1069 int virtio_transport_connect(struct vsock_sock *vsk)
1070 {
1071 	struct virtio_vsock_pkt_info info = {
1072 		.op = VIRTIO_VSOCK_OP_REQUEST,
1073 		.vsk = vsk,
1074 		.net = sock_net(sk_vsock(vsk)),
1075 	};
1076 
1077 	return virtio_transport_send_pkt_info(vsk, &info);
1078 }
1079 EXPORT_SYMBOL_GPL(virtio_transport_connect);
1080 
1081 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
1082 {
1083 	struct virtio_vsock_pkt_info info = {
1084 		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
1085 		.flags = (mode & RCV_SHUTDOWN ?
1086 			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
1087 			 (mode & SEND_SHUTDOWN ?
1088 			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
1089 		.vsk = vsk,
1090 		.net = sock_net(sk_vsock(vsk)),
1091 	};
1092 
1093 	return virtio_transport_send_pkt_info(vsk, &info);
1094 }
1095 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
1096 
1097 int
1098 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
1099 			       struct sockaddr_vm *remote_addr,
1100 			       struct msghdr *msg,
1101 			       size_t dgram_len)
1102 {
1103 	return -EOPNOTSUPP;
1104 }
1105 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
1106 
1107 ssize_t
1108 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
1109 				struct msghdr *msg,
1110 				size_t len)
1111 {
1112 	struct virtio_vsock_pkt_info info = {
1113 		.op = VIRTIO_VSOCK_OP_RW,
1114 		.msg = msg,
1115 		.pkt_len = len,
1116 		.vsk = vsk,
1117 		.net = sock_net(sk_vsock(vsk)),
1118 	};
1119 
1120 	return virtio_transport_send_pkt_info(vsk, &info);
1121 }
1122 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
1123 
1124 void virtio_transport_destruct(struct vsock_sock *vsk)
1125 {
1126 	struct virtio_vsock_sock *vvs = vsk->trans;
1127 
1128 	virtio_transport_cancel_close_work(vsk, true);
1129 
1130 	kfree(vvs);
1131 	vsk->trans = NULL;
1132 }
1133 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
1134 
1135 ssize_t virtio_transport_unsent_bytes(struct vsock_sock *vsk)
1136 {
1137 	struct virtio_vsock_sock *vvs = vsk->trans;
1138 	size_t ret;
1139 
1140 	spin_lock_bh(&vvs->tx_lock);
1141 	ret = vvs->bytes_unsent;
1142 	spin_unlock_bh(&vvs->tx_lock);
1143 
1144 	return ret;
1145 }
1146 EXPORT_SYMBOL_GPL(virtio_transport_unsent_bytes);
1147 
1148 static int virtio_transport_reset(struct vsock_sock *vsk,
1149 				  struct sk_buff *skb)
1150 {
1151 	struct virtio_vsock_pkt_info info = {
1152 		.op = VIRTIO_VSOCK_OP_RST,
1153 		.reply = !!skb,
1154 		.vsk = vsk,
1155 		.net = sock_net(sk_vsock(vsk)),
1156 	};
1157 
1158 	/* Send RST only if the original pkt is not a RST pkt */
1159 	if (skb && le16_to_cpu(virtio_vsock_hdr(skb)->op) == VIRTIO_VSOCK_OP_RST)
1160 		return 0;
1161 
1162 	return virtio_transport_send_pkt_info(vsk, &info);
1163 }
1164 
1165 /* Normally packets are associated with a socket.  There may be no socket if an
1166  * attempt was made to connect to a socket that does not exist.
1167  *
1168  * net refers to the namespace of whoever sent the invalid message. For
1169  * loopback, this is the namespace of the socket. For vhost, this is the
1170  * namespace of the VM (i.e., vhost_vsock).
1171  */
1172 static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
1173 					  struct sk_buff *skb, struct net *net)
1174 {
1175 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1176 	struct virtio_vsock_pkt_info info = {
1177 		.op = VIRTIO_VSOCK_OP_RST,
1178 		.type = le16_to_cpu(hdr->type),
1179 		.reply = true,
1180 
1181 		/* Set sk owner to socket we are replying to (may be NULL for
1182 		 * non-loopback). This keeps a reference to the sock and
1183 		 * sock_net(sk) until the reply skb is freed.
1184 		 */
1185 		.vsk = vsock_sk(skb->sk),
1186 
1187 		/* net is not defined here because we pass it directly to
1188 		 * t->send_pkt(), instead of relying on
1189 		 * virtio_transport_send_pkt_info() to pass it. It is not needed
1190 		 * by virtio_transport_alloc_skb().
1191 		 */
1192 	};
1193 	struct sk_buff *reply;
1194 
1195 	/* Send RST only if the original pkt is not a RST pkt */
1196 	if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
1197 		return 0;
1198 
1199 	if (!t)
1200 		return -ENOTCONN;
1201 
1202 	reply = virtio_transport_alloc_skb(&info, 0, false,
1203 					   le64_to_cpu(hdr->dst_cid),
1204 					   le32_to_cpu(hdr->dst_port),
1205 					   le64_to_cpu(hdr->src_cid),
1206 					   le32_to_cpu(hdr->src_port));
1207 	if (!reply)
1208 		return -ENOMEM;
1209 
1210 	return t->send_pkt(reply, net);
1211 }
1212 
1213 /* This function should be called with sk_lock held and SOCK_DONE set */
1214 static void virtio_transport_remove_sock(struct vsock_sock *vsk)
1215 {
1216 	struct virtio_vsock_sock *vvs = vsk->trans;
1217 
1218 	/* We don't need to take rx_lock, as the socket is closing and we are
1219 	 * removing it.
1220 	 */
1221 	__skb_queue_purge(&vvs->rx_queue);
1222 	vsock_remove_sock(vsk);
1223 }
1224 
1225 static void virtio_transport_cancel_close_work(struct vsock_sock *vsk,
1226 					       bool cancel_timeout)
1227 {
1228 	struct sock *sk = sk_vsock(vsk);
1229 
1230 	if (vsk->close_work_scheduled &&
1231 	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
1232 		vsk->close_work_scheduled = false;
1233 
1234 		virtio_transport_remove_sock(vsk);
1235 
1236 		/* Release refcnt obtained when we scheduled the timeout */
1237 		sock_put(sk);
1238 	}
1239 }
1240 
1241 static void virtio_transport_do_close(struct vsock_sock *vsk,
1242 				      bool cancel_timeout)
1243 {
1244 	struct sock *sk = sk_vsock(vsk);
1245 
1246 	sock_set_flag(sk, SOCK_DONE);
1247 	vsk->peer_shutdown = SHUTDOWN_MASK;
1248 	if (vsock_stream_has_data(vsk) <= 0)
1249 		sk->sk_state = TCP_CLOSING;
1250 	sk->sk_state_change(sk);
1251 
1252 	virtio_transport_cancel_close_work(vsk, cancel_timeout);
1253 }
1254 
1255 static void virtio_transport_close_timeout(struct work_struct *work)
1256 {
1257 	struct vsock_sock *vsk =
1258 		container_of(work, struct vsock_sock, close_work.work);
1259 	struct sock *sk = sk_vsock(vsk);
1260 
1261 	sock_hold(sk);
1262 	lock_sock(sk);
1263 
1264 	if (!sock_flag(sk, SOCK_DONE)) {
1265 		(void)virtio_transport_reset(vsk, NULL);
1266 
1267 		virtio_transport_do_close(vsk, false);
1268 	}
1269 
1270 	vsk->close_work_scheduled = false;
1271 
1272 	release_sock(sk);
1273 	sock_put(sk);
1274 }
1275 
1276 /* User context, vsk->sk is locked */
1277 static bool virtio_transport_close(struct vsock_sock *vsk)
1278 {
1279 	struct sock *sk = &vsk->sk;
1280 
1281 	if (!(sk->sk_state == TCP_ESTABLISHED ||
1282 	      sk->sk_state == TCP_CLOSING))
1283 		return true;
1284 
1285 	/* Already received SHUTDOWN from peer, reply with RST */
1286 	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
1287 		(void)virtio_transport_reset(vsk, NULL);
1288 		return true;
1289 	}
1290 
1291 	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
1292 		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
1293 
1294 	if (!(current->flags & PF_EXITING))
1295 		vsock_linger(sk);
1296 
1297 	if (sock_flag(sk, SOCK_DONE)) {
1298 		return true;
1299 	}
1300 
1301 	sock_hold(sk);
1302 	INIT_DELAYED_WORK(&vsk->close_work,
1303 			  virtio_transport_close_timeout);
1304 	vsk->close_work_scheduled = true;
1305 	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
1306 	return false;
1307 }
1308 
1309 void virtio_transport_release(struct vsock_sock *vsk)
1310 {
1311 	struct sock *sk = &vsk->sk;
1312 	bool remove_sock = true;
1313 
1314 	if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
1315 		remove_sock = virtio_transport_close(vsk);
1316 
1317 	if (remove_sock) {
1318 		sock_set_flag(sk, SOCK_DONE);
1319 		virtio_transport_remove_sock(vsk);
1320 	}
1321 }
1322 EXPORT_SYMBOL_GPL(virtio_transport_release);
1323 
1324 static int
1325 virtio_transport_recv_connecting(struct sock *sk,
1326 				 struct sk_buff *skb)
1327 {
1328 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1329 	struct vsock_sock *vsk = vsock_sk(sk);
1330 	int skerr;
1331 	int err;
1332 
1333 	switch (le16_to_cpu(hdr->op)) {
1334 	case VIRTIO_VSOCK_OP_RESPONSE:
1335 		sk->sk_state = TCP_ESTABLISHED;
1336 		sk->sk_socket->state = SS_CONNECTED;
1337 		vsock_insert_connected(vsk);
1338 		sk->sk_state_change(sk);
1339 		break;
1340 	case VIRTIO_VSOCK_OP_INVALID:
1341 		break;
1342 	case VIRTIO_VSOCK_OP_RST:
1343 		skerr = ECONNRESET;
1344 		err = 0;
1345 		goto destroy;
1346 	default:
1347 		skerr = EPROTO;
1348 		err = -EINVAL;
1349 		goto destroy;
1350 	}
1351 	return 0;
1352 
1353 destroy:
1354 	virtio_transport_reset(vsk, skb);
1355 	sk->sk_state = TCP_CLOSE;
1356 	sk->sk_err = skerr;
1357 	sk_error_report(sk);
1358 	return err;
1359 }
1360 
1361 static void
1362 virtio_transport_recv_enqueue(struct vsock_sock *vsk,
1363 			      struct sk_buff *skb)
1364 {
1365 	struct virtio_vsock_sock *vvs = vsk->trans;
1366 	bool can_enqueue, free_pkt = false;
1367 	struct virtio_vsock_hdr *hdr;
1368 	u32 len;
1369 
1370 	hdr = virtio_vsock_hdr(skb);
1371 	len = le32_to_cpu(hdr->len);
1372 
1373 	spin_lock_bh(&vvs->rx_lock);
1374 
1375 	can_enqueue = virtio_transport_inc_rx_pkt(vvs, len);
1376 	if (!can_enqueue) {
1377 		free_pkt = true;
1378 		goto out;
1379 	}
1380 
1381 	if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)
1382 		vvs->msg_count++;
1383 
1384 	/* Try to copy small packets into the buffer of last packet queued,
1385 	 * to avoid wasting memory queueing the entire buffer with a small
1386 	 * payload. Skip non-linear (e.g. zerocopy) skbs; these carry payload
1387 	 * in skb_shinfo.
1388 	 */
1389 	if (len <= GOOD_COPY_LEN && !skb_queue_empty(&vvs->rx_queue) &&
1390 	    !skb_is_nonlinear(skb)) {
1391 		struct virtio_vsock_hdr *last_hdr;
1392 		struct sk_buff *last_skb;
1393 
1394 		last_skb = skb_peek_tail(&vvs->rx_queue);
1395 		last_hdr = virtio_vsock_hdr(last_skb);
1396 
1397 		/* If there is space in the last packet queued, we copy the
1398 		 * new packet in its buffer. We avoid this if the last packet
1399 		 * queued has VIRTIO_VSOCK_SEQ_EOM set, because this is
1400 		 * delimiter of SEQPACKET message, so 'pkt' is the first packet
1401 		 * of a new message.
1402 		 */
1403 		if (skb->len < skb_tailroom(last_skb) &&
1404 		    !(le32_to_cpu(last_hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)) {
1405 			memcpy(skb_put(last_skb, skb->len), skb->data, skb->len);
1406 			free_pkt = true;
1407 			last_hdr->flags |= hdr->flags;
1408 			le32_add_cpu(&last_hdr->len, len);
1409 			goto out;
1410 		}
1411 	}
1412 
1413 	__skb_queue_tail(&vvs->rx_queue, skb);
1414 
1415 out:
1416 	spin_unlock_bh(&vvs->rx_lock);
1417 	if (free_pkt)
1418 		kfree_skb(skb);
1419 }
1420 
1421 static int
1422 virtio_transport_recv_connected(struct sock *sk,
1423 				struct sk_buff *skb)
1424 {
1425 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1426 	struct vsock_sock *vsk = vsock_sk(sk);
1427 	int err = 0;
1428 
1429 	switch (le16_to_cpu(hdr->op)) {
1430 	case VIRTIO_VSOCK_OP_RW:
1431 		virtio_transport_recv_enqueue(vsk, skb);
1432 		vsock_data_ready(sk);
1433 		return err;
1434 	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
1435 		virtio_transport_send_credit_update(vsk);
1436 		break;
1437 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
1438 		sk->sk_write_space(sk);
1439 		break;
1440 	case VIRTIO_VSOCK_OP_SHUTDOWN:
1441 		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
1442 			vsk->peer_shutdown |= RCV_SHUTDOWN;
1443 		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
1444 			vsk->peer_shutdown |= SEND_SHUTDOWN;
1445 		if (vsk->peer_shutdown == SHUTDOWN_MASK) {
1446 			if (vsock_stream_has_data(vsk) <= 0 && !sock_flag(sk, SOCK_DONE)) {
1447 				(void)virtio_transport_reset(vsk, NULL);
1448 				virtio_transport_do_close(vsk, true);
1449 			}
1450 			/* Remove this socket anyway because the remote peer sent
1451 			 * the shutdown. This way a new connection will succeed
1452 			 * if the remote peer uses the same source port,
1453 			 * even if the old socket is still unreleased, but now disconnected.
1454 			 */
1455 			vsock_remove_sock(vsk);
1456 		}
1457 		if (le32_to_cpu(virtio_vsock_hdr(skb)->flags))
1458 			sk->sk_state_change(sk);
1459 		break;
1460 	case VIRTIO_VSOCK_OP_RST:
1461 		virtio_transport_do_close(vsk, true);
1462 		break;
1463 	default:
1464 		err = -EINVAL;
1465 		break;
1466 	}
1467 
1468 	kfree_skb(skb);
1469 	return err;
1470 }
1471 
1472 static void
1473 virtio_transport_recv_disconnecting(struct sock *sk,
1474 				    struct sk_buff *skb)
1475 {
1476 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1477 	struct vsock_sock *vsk = vsock_sk(sk);
1478 
1479 	if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
1480 		virtio_transport_do_close(vsk, true);
1481 }
1482 
1483 static int
1484 virtio_transport_send_response(struct vsock_sock *vsk,
1485 			       struct sk_buff *skb)
1486 {
1487 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1488 	struct virtio_vsock_pkt_info info = {
1489 		.op = VIRTIO_VSOCK_OP_RESPONSE,
1490 		.remote_cid = le64_to_cpu(hdr->src_cid),
1491 		.remote_port = le32_to_cpu(hdr->src_port),
1492 		.reply = true,
1493 		.vsk = vsk,
1494 		.net = sock_net(sk_vsock(vsk)),
1495 	};
1496 
1497 	return virtio_transport_send_pkt_info(vsk, &info);
1498 }
1499 
1500 static bool virtio_transport_space_update(struct sock *sk,
1501 					  struct sk_buff *skb)
1502 {
1503 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1504 	struct vsock_sock *vsk = vsock_sk(sk);
1505 	struct virtio_vsock_sock *vvs = vsk->trans;
1506 	bool space_available;
1507 
1508 	/* Listener sockets are not associated with any transport, so we are
1509 	 * not able to take the state to see if there is space available in the
1510 	 * remote peer, but since they are only used to receive requests, we
1511 	 * can assume that there is always space available in the other peer.
1512 	 */
1513 	if (!vvs)
1514 		return true;
1515 
1516 	/* buf_alloc and fwd_cnt is always included in the hdr */
1517 	spin_lock_bh(&vvs->tx_lock);
1518 	vvs->peer_buf_alloc = le32_to_cpu(hdr->buf_alloc);
1519 	vvs->peer_fwd_cnt = le32_to_cpu(hdr->fwd_cnt);
1520 	space_available = virtio_transport_has_space(vvs);
1521 	spin_unlock_bh(&vvs->tx_lock);
1522 	return space_available;
1523 }
1524 
1525 /* Handle server socket */
1526 static int
1527 virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
1528 			     struct virtio_transport *t)
1529 {
1530 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1531 	struct vsock_sock *vsk = vsock_sk(sk);
1532 	struct vsock_sock *vchild;
1533 	struct sock *child;
1534 	int ret;
1535 
1536 	if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) {
1537 		virtio_transport_reset_no_sock(t, skb, sock_net(sk));
1538 		return -EINVAL;
1539 	}
1540 
1541 	if (sk_acceptq_is_full(sk)) {
1542 		virtio_transport_reset_no_sock(t, skb, sock_net(sk));
1543 		return -ENOMEM;
1544 	}
1545 
1546 	/* __vsock_release() might have already flushed accept_queue.
1547 	 * Subsequent enqueues would lead to a memory leak.
1548 	 */
1549 	if (sk->sk_shutdown == SHUTDOWN_MASK) {
1550 		virtio_transport_reset_no_sock(t, skb, sock_net(sk));
1551 		return -ESHUTDOWN;
1552 	}
1553 
1554 	child = vsock_create_connected(sk);
1555 	if (!child) {
1556 		virtio_transport_reset_no_sock(t, skb, sock_net(sk));
1557 		return -ENOMEM;
1558 	}
1559 
1560 	lock_sock_nested(child, SINGLE_DEPTH_NESTING);
1561 
1562 	child->sk_state = TCP_ESTABLISHED;
1563 
1564 	vchild = vsock_sk(child);
1565 	vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
1566 			le32_to_cpu(hdr->dst_port));
1567 	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
1568 			le32_to_cpu(hdr->src_port));
1569 
1570 	ret = vsock_assign_transport(vchild, vsk);
1571 	/* Transport assigned (looking at remote_addr) must be the same
1572 	 * where we received the request.
1573 	 */
1574 	if (ret || vchild->transport != &t->transport) {
1575 		release_sock(child);
1576 		virtio_transport_reset_no_sock(t, skb, sock_net(sk));
1577 		sock_put(child);
1578 		return ret;
1579 	}
1580 
1581 	sk_acceptq_added(sk);
1582 	if (virtio_transport_space_update(child, skb))
1583 		child->sk_write_space(child);
1584 
1585 	vsock_insert_connected(vchild);
1586 	vsock_enqueue_accept(sk, child);
1587 	virtio_transport_send_response(vchild, skb);
1588 
1589 	release_sock(child);
1590 
1591 	sk->sk_data_ready(sk);
1592 	return 0;
1593 }
1594 
1595 static bool virtio_transport_valid_type(u16 type)
1596 {
1597 	return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
1598 	       (type == VIRTIO_VSOCK_TYPE_SEQPACKET);
1599 }
1600 
1601 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
1602  * lock.
1603  */
1604 void virtio_transport_recv_pkt(struct virtio_transport *t,
1605 			       struct sk_buff *skb, struct net *net)
1606 {
1607 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1608 	struct sockaddr_vm src, dst;
1609 	struct vsock_sock *vsk;
1610 	struct sock *sk;
1611 	bool space_available;
1612 
1613 	vsock_addr_init(&src, le64_to_cpu(hdr->src_cid),
1614 			le32_to_cpu(hdr->src_port));
1615 	vsock_addr_init(&dst, le64_to_cpu(hdr->dst_cid),
1616 			le32_to_cpu(hdr->dst_port));
1617 
1618 	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
1619 					dst.svm_cid, dst.svm_port,
1620 					le32_to_cpu(hdr->len),
1621 					le16_to_cpu(hdr->type),
1622 					le16_to_cpu(hdr->op),
1623 					le32_to_cpu(hdr->flags),
1624 					le32_to_cpu(hdr->buf_alloc),
1625 					le32_to_cpu(hdr->fwd_cnt));
1626 
1627 	if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) {
1628 		(void)virtio_transport_reset_no_sock(t, skb, net);
1629 		goto free_pkt;
1630 	}
1631 
1632 	/* The socket must be in connected or bound table
1633 	 * otherwise send reset back
1634 	 */
1635 	sk = vsock_find_connected_socket_net(&src, &dst, net);
1636 	if (!sk) {
1637 		sk = vsock_find_bound_socket_net(&dst, net);
1638 		if (!sk) {
1639 			(void)virtio_transport_reset_no_sock(t, skb, net);
1640 			goto free_pkt;
1641 		}
1642 	}
1643 
1644 	if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) {
1645 		(void)virtio_transport_reset_no_sock(t, skb, net);
1646 		sock_put(sk);
1647 		goto free_pkt;
1648 	}
1649 
1650 	if (!skb_set_owner_sk_safe(skb, sk)) {
1651 		WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n");
1652 		goto free_pkt;
1653 	}
1654 
1655 	vsk = vsock_sk(sk);
1656 
1657 	lock_sock(sk);
1658 
1659 	/* Check if sk has been closed or assigned to another transport before
1660 	 * lock_sock (note: listener sockets are not assigned to any transport)
1661 	 */
1662 	if (sock_flag(sk, SOCK_DONE) ||
1663 	    (sk->sk_state != TCP_LISTEN && vsk->transport != &t->transport)) {
1664 		(void)virtio_transport_reset_no_sock(t, skb, net);
1665 		release_sock(sk);
1666 		sock_put(sk);
1667 		goto free_pkt;
1668 	}
1669 
1670 	space_available = virtio_transport_space_update(sk, skb);
1671 
1672 	/* Update CID in case it has changed after a transport reset event */
1673 	if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
1674 		vsk->local_addr.svm_cid = dst.svm_cid;
1675 
1676 	if (space_available)
1677 		sk->sk_write_space(sk);
1678 
1679 	switch (sk->sk_state) {
1680 	case TCP_LISTEN:
1681 		virtio_transport_recv_listen(sk, skb, t);
1682 		kfree_skb(skb);
1683 		break;
1684 	case TCP_SYN_SENT:
1685 		virtio_transport_recv_connecting(sk, skb);
1686 		kfree_skb(skb);
1687 		break;
1688 	case TCP_ESTABLISHED:
1689 		virtio_transport_recv_connected(sk, skb);
1690 		break;
1691 	case TCP_CLOSING:
1692 		virtio_transport_recv_disconnecting(sk, skb);
1693 		kfree_skb(skb);
1694 		break;
1695 	default:
1696 		(void)virtio_transport_reset_no_sock(t, skb, net);
1697 		kfree_skb(skb);
1698 		break;
1699 	}
1700 
1701 	release_sock(sk);
1702 
1703 	/* Release refcnt obtained when we fetched this socket out of the
1704 	 * bound or connected list.
1705 	 */
1706 	sock_put(sk);
1707 	return;
1708 
1709 free_pkt:
1710 	kfree_skb(skb);
1711 }
1712 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1713 
1714 /* Remove skbs found in a queue that have a vsk that matches.
1715  *
1716  * Each skb is freed.
1717  *
1718  * Returns the count of skbs that were reply packets.
1719  */
1720 int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue)
1721 {
1722 	struct sk_buff_head freeme;
1723 	struct sk_buff *skb, *tmp;
1724 	int cnt = 0;
1725 
1726 	skb_queue_head_init(&freeme);
1727 
1728 	spin_lock_bh(&queue->lock);
1729 	skb_queue_walk_safe(queue, skb, tmp) {
1730 		if (vsock_sk(skb->sk) != vsk)
1731 			continue;
1732 
1733 		__skb_unlink(skb, queue);
1734 		__skb_queue_tail(&freeme, skb);
1735 
1736 		if (virtio_vsock_skb_reply(skb))
1737 			cnt++;
1738 	}
1739 	spin_unlock_bh(&queue->lock);
1740 
1741 	__skb_queue_purge(&freeme);
1742 
1743 	return cnt;
1744 }
1745 EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs);
1746 
1747 int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor)
1748 {
1749 	struct virtio_vsock_sock *vvs = vsk->trans;
1750 	struct sock *sk = sk_vsock(vsk);
1751 	struct virtio_vsock_hdr *hdr;
1752 	struct sk_buff *skb;
1753 	u32 pkt_len;
1754 	int off = 0;
1755 	int err;
1756 
1757 	spin_lock_bh(&vvs->rx_lock);
1758 	/* Use __skb_recv_datagram() for race-free handling of the receive. It
1759 	 * works for types other than dgrams.
1760 	 */
1761 	skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err);
1762 	if (!skb) {
1763 		spin_unlock_bh(&vvs->rx_lock);
1764 		return err;
1765 	}
1766 
1767 	hdr = virtio_vsock_hdr(skb);
1768 	if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)
1769 		vvs->msg_count--;
1770 
1771 	pkt_len = le32_to_cpu(hdr->len);
1772 	virtio_transport_dec_rx_pkt(vvs, pkt_len, pkt_len);
1773 	spin_unlock_bh(&vvs->rx_lock);
1774 
1775 	virtio_transport_send_credit_update(vsk);
1776 
1777 	return recv_actor(sk, skb);
1778 }
1779 EXPORT_SYMBOL_GPL(virtio_transport_read_skb);
1780 
1781 int virtio_transport_notify_set_rcvlowat(struct vsock_sock *vsk, int val)
1782 {
1783 	struct virtio_vsock_sock *vvs = vsk->trans;
1784 	bool send_update;
1785 
1786 	spin_lock_bh(&vvs->rx_lock);
1787 
1788 	/* If number of available bytes is less than new SO_RCVLOWAT value,
1789 	 * kick sender to send more data, because sender may sleep in its
1790 	 * 'send()' syscall waiting for enough space at our side. Also
1791 	 * don't send credit update when peer already knows actual value -
1792 	 * such transmission will be useless.
1793 	 */
1794 	send_update = (vvs->rx_bytes < val) &&
1795 		      (vvs->fwd_cnt != vvs->last_fwd_cnt);
1796 
1797 	spin_unlock_bh(&vvs->rx_lock);
1798 
1799 	if (send_update) {
1800 		int err;
1801 
1802 		err = virtio_transport_send_credit_update(vsk);
1803 		if (err < 0)
1804 			return err;
1805 	}
1806 
1807 	return 0;
1808 }
1809 EXPORT_SYMBOL_GPL(virtio_transport_notify_set_rcvlowat);
1810 
1811 MODULE_LICENSE("GPL v2");
1812 MODULE_AUTHOR("Asias He");
1813 MODULE_DESCRIPTION("common code for virtio vsock");
1814