xref: /linux/io_uring/zcrx.c (revision a0285236ab93fdfdd1008afaa04561d142d6c276)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/dma-map-ops.h>
5 #include <linux/mm.h>
6 #include <linux/nospec.h>
7 #include <linux/io_uring.h>
8 #include <linux/netdevice.h>
9 #include <linux/rtnetlink.h>
10 #include <linux/skbuff_ref.h>
11 
12 #include <net/page_pool/helpers.h>
13 #include <net/page_pool/memory_provider.h>
14 #include <net/netlink.h>
15 #include <net/netdev_rx_queue.h>
16 #include <net/tcp.h>
17 #include <net/rps.h>
18 
19 #include <trace/events/page_pool.h>
20 
21 #include <uapi/linux/io_uring.h>
22 
23 #include "io_uring.h"
24 #include "kbuf.h"
25 #include "memmap.h"
26 #include "zcrx.h"
27 #include "rsrc.h"
28 
29 #define IO_DMA_ATTR (DMA_ATTR_SKIP_CPU_SYNC | DMA_ATTR_WEAK_ORDERING)
30 
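/* Unmap the first @nr_mapped niovs of an area and clear their DMA addresses. */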
31 static void __io_zcrx_unmap_area(struct io_zcrx_ifq *ifq,
32 				 struct io_zcrx_area *area, int nr_mapped)
33 {
34 	int i;
35 
36 	for (i = 0; i < nr_mapped; i++) {
37 		struct net_iov *niov = &area->nia.niovs[i];
38 		dma_addr_t dma;
39 
40 		dma = page_pool_get_dma_addr_netmem(net_iov_to_netmem(niov));
41 		dma_unmap_page_attrs(ifq->dev, dma, PAGE_SIZE,
42 				     DMA_FROM_DEVICE, IO_DMA_ATTR);
43 		net_mp_niov_set_dma_addr(niov, 0);
44 	}
45 }
46 
47 static void io_zcrx_unmap_area(struct io_zcrx_ifq *ifq, struct io_zcrx_area *area)
48 {
49 	if (area->is_mapped)
50 		__io_zcrx_unmap_area(ifq, area, area->nia.num_niovs);
51 }
52 
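/*
 * DMA map each page backing the area and store the address in its niov.
 * If any page fails to map, everything mapped so far is unmapped again.
 */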
53 static int io_zcrx_map_area(struct io_zcrx_ifq *ifq, struct io_zcrx_area *area)
54 {
55 	int i;
56 
57 	for (i = 0; i < area->nia.num_niovs; i++) {
58 		struct net_iov *niov = &area->nia.niovs[i];
59 		dma_addr_t dma;
60 
61 		dma = dma_map_page_attrs(ifq->dev, area->pages[i], 0, PAGE_SIZE,
62 					 DMA_FROM_DEVICE, IO_DMA_ATTR);
63 		if (dma_mapping_error(ifq->dev, dma))
64 			break;
65 		if (net_mp_niov_set_dma_addr(niov, dma)) {
66 			dma_unmap_page_attrs(ifq->dev, dma, PAGE_SIZE,
67 					     DMA_FROM_DEVICE, IO_DMA_ATTR);
68 			break;
69 		}
70 	}
71 
72 	if (i != area->nia.num_niovs) {
73 		__io_zcrx_unmap_area(ifq, area, i);
74 		return -EINVAL;
75 	}
76 
77 	area->is_mapped = true;
78 	return 0;
79 }
80 
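/* Sync the buffer for device access before it goes back into the page pool. */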
81 static void io_zcrx_sync_for_device(const struct page_pool *pool,
82 				    struct net_iov *niov)
83 {
84 #if defined(CONFIG_HAS_DMA) && defined(CONFIG_DMA_NEED_SYNC)
85 	dma_addr_t dma_addr;
86 
87 	if (!dma_dev_need_sync(pool->p.dev))
88 		return;
89 
90 	dma_addr = page_pool_get_dma_addr_netmem(net_iov_to_netmem(niov));
91 	__dma_sync_single_for_device(pool->p.dev, dma_addr + pool->p.offset,
92 				     PAGE_SIZE, pool->p.dma_dir);
93 #endif
94 }
95 
96 #define IO_RQ_MAX_ENTRIES		32768
97 
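/* Cap on the number of skbs processed in a single receive call */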
98 #define IO_SKBS_PER_CALL_LIMIT	20
99 
100 struct io_zcrx_args {
101 	struct io_kiocb		*req;
102 	struct io_zcrx_ifq	*ifq;
103 	struct socket		*sock;
104 	unsigned		nr_skbs;
105 };
106 
107 static const struct memory_provider_ops io_uring_pp_zc_ops;
108 
109 static inline struct io_zcrx_area *io_zcrx_iov_to_area(const struct net_iov *niov)
110 {
111 	struct net_iov_area *owner = net_iov_owner(niov);
112 
113 	return container_of(owner, struct io_zcrx_area, nia);
114 }
115 
116 static inline atomic_t *io_get_user_counter(struct net_iov *niov)
117 {
118 	struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);
119 
120 	return &area->user_refs[net_iov_idx(niov)];
121 }
122 
123 static bool io_zcrx_put_niov_uref(struct net_iov *niov)
124 {
125 	atomic_t *uref = io_get_user_counter(niov);
126 
127 	if (unlikely(!atomic_read(uref)))
128 		return false;
129 	atomic_dec(uref);
130 	return true;
131 }
132 
133 static void io_zcrx_get_niov_uref(struct net_iov *niov)
134 {
135 	atomic_inc(io_get_user_counter(niov));
136 }
137 
138 static inline struct page *io_zcrx_iov_page(const struct net_iov *niov)
139 {
140 	struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);
141 
142 	return area->pages[net_iov_idx(niov)];
143 }
144 
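/*
 * Set up the user mappable refill ring region: a struct io_uring header
 * followed by an array of rq_entries io_uring_zcrx_rqe entries.
 */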
145 static int io_allocate_rbuf_ring(struct io_zcrx_ifq *ifq,
146 				 struct io_uring_zcrx_ifq_reg *reg,
147 				 struct io_uring_region_desc *rd)
148 {
149 	size_t off, size;
150 	void *ptr;
151 	int ret;
152 
153 	off = sizeof(struct io_uring);
154 	size = off + sizeof(struct io_uring_zcrx_rqe) * reg->rq_entries;
155 	if (size > rd->size)
156 		return -EINVAL;
157 
158 	ret = io_create_region_mmap_safe(ifq->ctx, &ifq->ctx->zcrx_region, rd,
159 					 IORING_MAP_OFF_ZCRX_REGION);
160 	if (ret < 0)
161 		return ret;
162 
163 	ptr = io_region_get_ptr(&ifq->ctx->zcrx_region);
164 	ifq->rq_ring = (struct io_uring *)ptr;
165 	ifq->rqes = (struct io_uring_zcrx_rqe *)(ptr + off);
166 	return 0;
167 }
168 
169 static void io_free_rbuf_ring(struct io_zcrx_ifq *ifq)
170 {
171 	io_free_region(ifq->ctx, &ifq->ctx->zcrx_region);
172 	ifq->rq_ring = NULL;
173 	ifq->rqes = NULL;
174 }
175 
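/* Tear down an area: unmap it, free the tracking arrays and unpin the user pages. */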
176 static void io_zcrx_free_area(struct io_zcrx_area *area)
177 {
178 	io_zcrx_unmap_area(area->ifq, area);
179 
180 	kvfree(area->freelist);
181 	kvfree(area->nia.niovs);
182 	kvfree(area->user_refs);
183 	if (area->pages) {
184 		unpin_user_pages(area->pages, area->nr_folios);
185 		kvfree(area->pages);
186 	}
187 	kfree(area);
188 }
189 
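/*
 * Pin the user provided memory region and build the area around it: one niov
 * per page plus the freelist and user reference counters that track them.
 */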
190 static int io_zcrx_create_area(struct io_zcrx_ifq *ifq,
191 			       struct io_zcrx_area **res,
192 			       struct io_uring_zcrx_area_reg *area_reg)
193 {
194 	struct io_zcrx_area *area;
195 	int i, ret, nr_pages, nr_iovs;
196 	struct iovec iov;
197 
198 	if (area_reg->flags || area_reg->rq_area_token)
199 		return -EINVAL;
200 	if (area_reg->__resv1 || area_reg->__resv2[0] || area_reg->__resv2[1])
201 		return -EINVAL;
202 	if (area_reg->addr & ~PAGE_MASK || area_reg->len & ~PAGE_MASK)
203 		return -EINVAL;
204 
205 	iov.iov_base = u64_to_user_ptr(area_reg->addr);
206 	iov.iov_len = area_reg->len;
207 	ret = io_buffer_validate(&iov);
208 	if (ret)
209 		return ret;
210 
211 	ret = -ENOMEM;
212 	area = kzalloc(sizeof(*area), GFP_KERNEL);
213 	if (!area)
214 		goto err;
215 
216 	area->pages = io_pin_pages((unsigned long)area_reg->addr, area_reg->len,
217 				   &nr_pages);
218 	if (IS_ERR(area->pages)) {
219 		ret = PTR_ERR(area->pages);
220 		area->pages = NULL;
221 		goto err;
222 	}
223 	area->nr_folios = nr_iovs = nr_pages;
224 	area->nia.num_niovs = nr_iovs;
225 
226 	area->nia.niovs = kvmalloc_array(nr_iovs, sizeof(area->nia.niovs[0]),
227 					 GFP_KERNEL | __GFP_ZERO);
228 	if (!area->nia.niovs)
229 		goto err;
230 
231 	area->freelist = kvmalloc_array(nr_iovs, sizeof(area->freelist[0]),
232 					GFP_KERNEL | __GFP_ZERO);
233 	if (!area->freelist)
234 		goto err;
235 
236 	for (i = 0; i < nr_iovs; i++)
237 		area->freelist[i] = i;
238 
239 	area->user_refs = kvmalloc_array(nr_iovs, sizeof(area->user_refs[0]),
240 					GFP_KERNEL | __GFP_ZERO);
241 	if (!area->user_refs)
242 		goto err;
243 
244 	for (i = 0; i < nr_iovs; i++) {
245 		struct net_iov *niov = &area->nia.niovs[i];
246 
247 		niov->owner = &area->nia;
248 		area->freelist[i] = i;
249 		atomic_set(&area->user_refs[i], 0);
250 	}
251 
252 	area->free_count = nr_iovs;
253 	area->ifq = ifq;
254 	/* we're only supporting one area per ifq for now */
255 	area->area_id = 0;
256 	area_reg->rq_area_token = (u64)area->area_id << IORING_ZCRX_AREA_SHIFT;
257 	spin_lock_init(&area->freelist_lock);
258 	*res = area;
259 	return 0;
260 err:
261 	if (area)
262 		io_zcrx_free_area(area);
263 	return ret;
264 }
265 
266 static struct io_zcrx_ifq *io_zcrx_ifq_alloc(struct io_ring_ctx *ctx)
267 {
268 	struct io_zcrx_ifq *ifq;
269 
270 	ifq = kzalloc(sizeof(*ifq), GFP_KERNEL);
271 	if (!ifq)
272 		return NULL;
273 
274 	ifq->if_rxq = -1;
275 	ifq->ctx = ctx;
276 	spin_lock_init(&ifq->lock);
277 	spin_lock_init(&ifq->rq_lock);
278 	return ifq;
279 }
280 
281 static void io_zcrx_drop_netdev(struct io_zcrx_ifq *ifq)
282 {
283 	spin_lock(&ifq->lock);
284 	if (ifq->netdev) {
285 		netdev_put(ifq->netdev, &ifq->netdev_tracker);
286 		ifq->netdev = NULL;
287 	}
288 	spin_unlock(&ifq->lock);
289 }
290 
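/*
 * Detach the memory provider from the device RX queue and drop the netdev
 * reference. Only acts while if_rxq is valid, so repeated calls are harmless.
 */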
291 static void io_close_queue(struct io_zcrx_ifq *ifq)
292 {
293 	struct net_device *netdev;
294 	netdevice_tracker netdev_tracker;
295 	struct pp_memory_provider_params p = {
296 		.mp_ops = &io_uring_pp_zc_ops,
297 		.mp_priv = ifq,
298 	};
299 
300 	if (ifq->if_rxq == -1)
301 		return;
302 
303 	spin_lock(&ifq->lock);
304 	netdev = ifq->netdev;
305 	netdev_tracker = ifq->netdev_tracker;
306 	ifq->netdev = NULL;
307 	spin_unlock(&ifq->lock);
308 
309 	if (netdev) {
310 		net_mp_close_rxq(netdev, ifq->if_rxq, &p);
311 		netdev_put(netdev, &netdev_tracker);
312 	}
313 	ifq->if_rxq = -1;
314 }
315 
316 static void io_zcrx_ifq_free(struct io_zcrx_ifq *ifq)
317 {
318 	io_close_queue(ifq);
319 	io_zcrx_drop_netdev(ifq);
320 
321 	if (ifq->area)
322 		io_zcrx_free_area(ifq->area);
323 	if (ifq->dev)
324 		put_device(ifq->dev);
325 
326 	io_free_rbuf_ring(ifq);
327 	kfree(ifq);
328 }
329 
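/*
 * Register a zero copy RX interface queue: validate the registration, set up
 * the refill ring and the buffer area, DMA map it and bind the io_uring
 * memory provider to the requested device RX queue.
 */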
330 int io_register_zcrx_ifq(struct io_ring_ctx *ctx,
331 			  struct io_uring_zcrx_ifq_reg __user *arg)
332 {
333 	struct pp_memory_provider_params mp_param = {};
334 	struct io_uring_zcrx_area_reg area;
335 	struct io_uring_zcrx_ifq_reg reg;
336 	struct io_uring_region_desc rd;
337 	struct io_zcrx_ifq *ifq;
338 	int ret;
339 
340 	/*
341 	 * 1. Registration allocates an interface RX queue.
342 	 * 2. The queue can observe data destined for sockets of other tasks.
343 	 */
344 	if (!capable(CAP_NET_ADMIN))
345 		return -EPERM;
346 
347 	/* mandatory io_uring features for zc rx */
348 	if (!(ctx->flags & IORING_SETUP_DEFER_TASKRUN &&
349 	      ctx->flags & IORING_SETUP_CQE32))
350 		return -EINVAL;
351 	if (ctx->ifq)
352 		return -EBUSY;
353 	if (copy_from_user(&reg, arg, sizeof(reg)))
354 		return -EFAULT;
355 	if (copy_from_user(&rd, u64_to_user_ptr(reg.region_ptr), sizeof(rd)))
356 		return -EFAULT;
357 	if (memchr_inv(&reg.__resv, 0, sizeof(reg.__resv)))
358 		return -EINVAL;
359 	if (reg.if_rxq == -1 || !reg.rq_entries || reg.flags)
360 		return -EINVAL;
361 	if (reg.rq_entries > IO_RQ_MAX_ENTRIES) {
362 		if (!(ctx->flags & IORING_SETUP_CLAMP))
363 			return -EINVAL;
364 		reg.rq_entries = IO_RQ_MAX_ENTRIES;
365 	}
366 	reg.rq_entries = roundup_pow_of_two(reg.rq_entries);
367 
368 	if (copy_from_user(&area, u64_to_user_ptr(reg.area_ptr), sizeof(area)))
369 		return -EFAULT;
370 
371 	ifq = io_zcrx_ifq_alloc(ctx);
372 	if (!ifq)
373 		return -ENOMEM;
374 
375 	ret = io_allocate_rbuf_ring(ifq, &reg, &rd);
376 	if (ret)
377 		goto err;
378 
379 	ret = io_zcrx_create_area(ifq, &ifq->area, &area);
380 	if (ret)
381 		goto err;
382 
383 	ifq->rq_entries = reg.rq_entries;
384 
385 	ret = -ENODEV;
386 	ifq->netdev = netdev_get_by_index(current->nsproxy->net_ns, reg.if_idx,
387 					  &ifq->netdev_tracker, GFP_KERNEL);
388 	if (!ifq->netdev)
389 		goto err;
390 
391 	ifq->dev = ifq->netdev->dev.parent;
392 	ret = -EOPNOTSUPP;
393 	if (!ifq->dev)
394 		goto err;
395 	get_device(ifq->dev);
396 
397 	ret = io_zcrx_map_area(ifq, ifq->area);
398 	if (ret)
399 		goto err;
400 
401 	mp_param.mp_ops = &io_uring_pp_zc_ops;
402 	mp_param.mp_priv = ifq;
403 	ret = net_mp_open_rxq(ifq->netdev, reg.if_rxq, &mp_param);
404 	if (ret)
405 		goto err;
406 	ifq->if_rxq = reg.if_rxq;
407 
408 	reg.offsets.rqes = sizeof(struct io_uring);
409 	reg.offsets.head = offsetof(struct io_uring, head);
410 	reg.offsets.tail = offsetof(struct io_uring, tail);
411 
412 	if (copy_to_user(arg, &reg, sizeof(reg)) ||
413 	    copy_to_user(u64_to_user_ptr(reg.region_ptr), &rd, sizeof(rd)) ||
414 	    copy_to_user(u64_to_user_ptr(reg.area_ptr), &area, sizeof(area))) {
415 		ret = -EFAULT;
416 		goto err;
417 	}
418 	ctx->ifq = ifq;
419 	return 0;
420 err:
421 	io_zcrx_ifq_free(ifq);
422 	return ret;
423 }
424 
425 void io_unregister_zcrx_ifqs(struct io_ring_ctx *ctx)
426 {
427 	struct io_zcrx_ifq *ifq = ctx->ifq;
428 
429 	lockdep_assert_held(&ctx->uring_lock);
430 
431 	if (!ifq)
432 		return;
433 
434 	ctx->ifq = NULL;
435 	io_zcrx_ifq_free(ifq);
436 }
437 
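/* Pop a niov off the area freelist; the freelist_lock must be held. */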
438 static struct net_iov *__io_zcrx_get_free_niov(struct io_zcrx_area *area)
439 {
440 	unsigned niov_idx;
441 
442 	lockdep_assert_held(&area->freelist_lock);
443 
444 	niov_idx = area->freelist[--area->free_count];
445 	return &area->nia.niovs[niov_idx];
446 }
447 
448 static void io_zcrx_return_niov_freelist(struct net_iov *niov)
449 {
450 	struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);
451 
452 	spin_lock_bh(&area->freelist_lock);
453 	area->freelist[area->free_count++] = net_iov_idx(niov);
454 	spin_unlock_bh(&area->freelist_lock);
455 }
456 
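/* Return a niov to its page pool, or to the freelist for copy fallback buffers. */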
457 static void io_zcrx_return_niov(struct net_iov *niov)
458 {
459 	netmem_ref netmem = net_iov_to_netmem(niov);
460 
461 	if (!niov->pp) {
462 		/* copy fallback niovs aren't attached to a page pool, return them via the freelist */
463 		io_zcrx_return_niov_freelist(niov);
464 		return;
465 	}
466 	page_pool_put_unrefed_netmem(niov->pp, netmem, -1, false);
467 }
468 
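/*
 * Reclaim buffers still handed out to user space: drop their user references
 * and return any niov for which that was the last page pool reference.
 */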
469 static void io_zcrx_scrub(struct io_zcrx_ifq *ifq)
470 {
471 	struct io_zcrx_area *area = ifq->area;
472 	int i;
473 
474 	if (!area)
475 		return;
476 
477 	/* Reclaim all buffers handed out to user space. */
478 	for (i = 0; i < area->nia.num_niovs; i++) {
479 		struct net_iov *niov = &area->nia.niovs[i];
480 		int nr;
481 
482 		if (!atomic_read(io_get_user_counter(niov)))
483 			continue;
484 		nr = atomic_xchg(io_get_user_counter(niov), 0);
485 		if (nr && !page_pool_unref_netmem(net_iov_to_netmem(niov), nr))
486 			io_zcrx_return_niov(niov);
487 	}
488 }
489 
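/* Reclaim user held buffers and release the device RX queue on ring shutdown. */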
490 void io_shutdown_zcrx_ifqs(struct io_ring_ctx *ctx)
491 {
492 	lockdep_assert_held(&ctx->uring_lock);
493 
494 	if (!ctx->ifq)
495 		return;
496 	io_zcrx_scrub(ctx->ifq);
497 	io_close_queue(ctx->ifq);
498 }
499 
500 static inline u32 io_zcrx_rqring_entries(struct io_zcrx_ifq *ifq)
501 {
502 	u32 entries;
503 
504 	entries = smp_load_acquire(&ifq->rq_ring->tail) - ifq->cached_rq_head;
505 	return min(entries, ifq->rq_entries);
506 }
507 
508 static struct io_uring_zcrx_rqe *io_zcrx_get_rqe(struct io_zcrx_ifq *ifq,
509 						 unsigned mask)
510 {
511 	unsigned int idx = ifq->cached_rq_head++ & mask;
512 
513 	return &ifq->rqes[idx];
514 }
515 
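/*
 * Fast path refill: consume entries user space posted to the refill ring,
 * validate them and place the corresponding niovs into the pool's alloc cache.
 */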
516 static void io_zcrx_ring_refill(struct page_pool *pp,
517 				struct io_zcrx_ifq *ifq)
518 {
519 	unsigned int mask = ifq->rq_entries - 1;
520 	unsigned int entries;
521 	netmem_ref netmem;
522 
523 	spin_lock_bh(&ifq->rq_lock);
524 
525 	entries = io_zcrx_rqring_entries(ifq);
526 	entries = min_t(unsigned, entries, PP_ALLOC_CACHE_REFILL - pp->alloc.count);
527 	if (unlikely(!entries)) {
528 		spin_unlock_bh(&ifq->rq_lock);
529 		return;
530 	}
531 
532 	do {
533 		struct io_uring_zcrx_rqe *rqe = io_zcrx_get_rqe(ifq, mask);
534 		struct io_zcrx_area *area;
535 		struct net_iov *niov;
536 		unsigned niov_idx, area_idx;
537 
538 		area_idx = rqe->off >> IORING_ZCRX_AREA_SHIFT;
539 		niov_idx = (rqe->off & ~IORING_ZCRX_AREA_MASK) >> PAGE_SHIFT;
540 
541 		if (unlikely(rqe->__pad || area_idx))
542 			continue;
543 		area = ifq->area;
544 
545 		if (unlikely(niov_idx >= area->nia.num_niovs))
546 			continue;
547 		niov_idx = array_index_nospec(niov_idx, area->nia.num_niovs);
548 
549 		niov = &area->nia.niovs[niov_idx];
550 		if (!io_zcrx_put_niov_uref(niov))
551 			continue;
552 
553 		netmem = net_iov_to_netmem(niov);
554 		if (page_pool_unref_netmem(netmem, 1) != 0)
555 			continue;
556 
557 		if (unlikely(niov->pp != pp)) {
558 			io_zcrx_return_niov(niov);
559 			continue;
560 		}
561 
562 		io_zcrx_sync_for_device(pp, niov);
563 		net_mp_netmem_place_in_cache(pp, netmem);
564 	} while (--entries);
565 
566 	smp_store_release(&ifq->rq_ring->head, ifq->cached_rq_head);
567 	spin_unlock_bh(&ifq->rq_lock);
568 }
569 
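/* Slow path refill from the area freelist when the refill ring runs dry. */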
570 static void io_zcrx_refill_slow(struct page_pool *pp, struct io_zcrx_ifq *ifq)
571 {
572 	struct io_zcrx_area *area = ifq->area;
573 
574 	spin_lock_bh(&area->freelist_lock);
575 	while (area->free_count && pp->alloc.count < PP_ALLOC_CACHE_REFILL) {
576 		struct net_iov *niov = __io_zcrx_get_free_niov(area);
577 		netmem_ref netmem = net_iov_to_netmem(niov);
578 
579 		net_mp_niov_set_page_pool(pp, niov);
580 		io_zcrx_sync_for_device(pp, niov);
581 		net_mp_netmem_place_in_cache(pp, netmem);
582 	}
583 	spin_unlock_bh(&area->freelist_lock);
584 }
585 
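/* Memory provider hook: hand out a netmem, refilling the alloc cache as needed. */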
586 static netmem_ref io_pp_zc_alloc_netmems(struct page_pool *pp, gfp_t gfp)
587 {
588 	struct io_zcrx_ifq *ifq = pp->mp_priv;
589 
590 	/* the page pool should only call us when its alloc cache is empty */
591 	if (unlikely(pp->alloc.count))
592 		goto out_return;
593 
594 	io_zcrx_ring_refill(pp, ifq);
595 	if (likely(pp->alloc.count))
596 		goto out_return;
597 
598 	io_zcrx_refill_slow(pp, ifq);
599 	if (!pp->alloc.count)
600 		return 0;
601 out_return:
602 	return pp->alloc.cache[--pp->alloc.count];
603 }
604 
605 static bool io_pp_zc_release_netmem(struct page_pool *pp, netmem_ref netmem)
606 {
607 	struct net_iov *niov;
608 
609 	if (WARN_ON_ONCE(!netmem_is_net_iov(netmem)))
610 		return false;
611 
612 	niov = netmem_to_net_iov(netmem);
613 	net_mp_niov_clear_page_pool(niov);
614 	io_zcrx_return_niov_freelist(niov);
615 	return false;
616 }
617 
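/* Validate that the page pool is compatible and pin the io_uring ctx while the pool exists. */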
618 static int io_pp_zc_init(struct page_pool *pp)
619 {
620 	struct io_zcrx_ifq *ifq = pp->mp_priv;
621 
622 	if (WARN_ON_ONCE(!ifq))
623 		return -EINVAL;
624 	if (WARN_ON_ONCE(ifq->dev != pp->p.dev))
625 		return -EINVAL;
626 	if (WARN_ON_ONCE(!pp->dma_map))
627 		return -EOPNOTSUPP;
628 	if (pp->p.order != 0)
629 		return -EOPNOTSUPP;
630 	if (pp->p.dma_dir != DMA_FROM_DEVICE)
631 		return -EOPNOTSUPP;
632 
633 	percpu_ref_get(&ifq->ctx->refs);
634 	return 0;
635 }
636 
637 static void io_pp_zc_destroy(struct page_pool *pp)
638 {
639 	struct io_zcrx_ifq *ifq = pp->mp_priv;
640 	struct io_zcrx_area *area = ifq->area;
641 
642 	if (WARN_ON_ONCE(area->free_count != area->nia.num_niovs))
643 		return;
644 	percpu_ref_put(&ifq->ctx->refs);
645 }
646 
647 static int io_pp_nl_fill(void *mp_priv, struct sk_buff *rsp,
648 			 struct netdev_rx_queue *rxq)
649 {
650 	struct nlattr *nest;
651 	int type;
652 
653 	type = rxq ? NETDEV_A_QUEUE_IO_URING : NETDEV_A_PAGE_POOL_IO_URING;
654 	nest = nla_nest_start(rsp, type);
655 	if (!nest)
656 		return -EMSGSIZE;
657 	nla_nest_end(rsp, nest);
658 
659 	return 0;
660 }
661 
662 static void io_pp_uninstall(void *mp_priv, struct netdev_rx_queue *rxq)
663 {
664 	struct pp_memory_provider_params *p = &rxq->mp_params;
665 	struct io_zcrx_ifq *ifq = mp_priv;
666 
667 	io_zcrx_drop_netdev(ifq);
668 	p->mp_ops = NULL;
669 	p->mp_priv = NULL;
670 }
671 
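/* Page pool memory provider backed by the registered zcrx area. */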
672 static const struct memory_provider_ops io_uring_pp_zc_ops = {
673 	.alloc_netmems		= io_pp_zc_alloc_netmems,
674 	.release_netmem		= io_pp_zc_release_netmem,
675 	.init			= io_pp_zc_init,
676 	.destroy		= io_pp_zc_destroy,
677 	.nl_fill		= io_pp_nl_fill,
678 	.uninstall		= io_pp_uninstall,
679 };
680 
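/*
 * Post a zcrx completion: the main CQE carries the length, while the extended
 * part of the 32-byte CQE encodes the area token and the offset of the data.
 */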
681 static bool io_zcrx_queue_cqe(struct io_kiocb *req, struct net_iov *niov,
682 			      struct io_zcrx_ifq *ifq, int off, int len)
683 {
684 	struct io_uring_zcrx_cqe *rcqe;
685 	struct io_zcrx_area *area;
686 	struct io_uring_cqe *cqe;
687 	u64 offset;
688 
689 	if (!io_defer_get_uncommited_cqe(req->ctx, &cqe))
690 		return false;
691 
692 	cqe->user_data = req->cqe.user_data;
693 	cqe->res = len;
694 	cqe->flags = IORING_CQE_F_MORE;
695 
696 	area = io_zcrx_iov_to_area(niov);
697 	offset = off + (net_iov_idx(niov) << PAGE_SHIFT);
698 	rcqe = (struct io_uring_zcrx_cqe *)(cqe + 1);
699 	rcqe->off = offset + ((u64)area->area_id << IORING_ZCRX_AREA_SHIFT);
700 	rcqe->__pad = 0;
701 	return true;
702 }
703 
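/* Take a free niov for the copy fallback path and give it a single pp fragment ref. */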
704 static struct net_iov *io_zcrx_alloc_fallback(struct io_zcrx_area *area)
705 {
706 	struct net_iov *niov = NULL;
707 
708 	spin_lock_bh(&area->freelist_lock);
709 	if (area->free_count)
710 		niov = __io_zcrx_get_free_niov(area);
711 	spin_unlock_bh(&area->freelist_lock);
712 
713 	if (niov)
714 		page_pool_fragment_netmem(net_iov_to_netmem(niov), 1);
715 	return niov;
716 }
717 
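/*
 * Copy fallback: copy data that isn't backed by provider memory (linear skb
 * data or regular page frags) into area buffers and post a CQE per chunk.
 */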
718 static ssize_t io_zcrx_copy_chunk(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
719 				  void *src_base, struct page *src_page,
720 				  unsigned int src_offset, size_t len)
721 {
722 	struct io_zcrx_area *area = ifq->area;
723 	size_t copied = 0;
724 	int ret = 0;
725 
726 	while (len) {
727 		size_t copy_size = min_t(size_t, PAGE_SIZE, len);
728 		const int dst_off = 0;
729 		struct net_iov *niov;
730 		struct page *dst_page;
731 		void *dst_addr;
732 
733 		niov = io_zcrx_alloc_fallback(area);
734 		if (!niov) {
735 			ret = -ENOMEM;
736 			break;
737 		}
738 
739 		dst_page = io_zcrx_iov_page(niov);
740 		dst_addr = kmap_local_page(dst_page);
741 		if (src_page)
742 			src_base = kmap_local_page(src_page);
743 
744 		memcpy(dst_addr, src_base + src_offset, copy_size);
745 
746 		if (src_page)
747 			kunmap_local(src_base);
748 		kunmap_local(dst_addr);
749 
750 		if (!io_zcrx_queue_cqe(req, niov, ifq, dst_off, copy_size)) {
751 			io_zcrx_return_niov(niov);
752 			ret = -ENOSPC;
753 			break;
754 		}
755 
756 		io_zcrx_get_niov_uref(niov);
757 		src_offset += copy_size;
758 		len -= copy_size;
759 		copied += copy_size;
760 	}
761 
762 	return copied ? copied : ret;
763 }
764 
765 static int io_zcrx_copy_frag(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
766 			     const skb_frag_t *frag, int off, int len)
767 {
768 	struct page *page = skb_frag_page(frag);
769 	u32 p_off, p_len, t, copied = 0;
770 	int ret = 0;
771 
772 	off += skb_frag_off(frag);
773 
774 	skb_frag_foreach_page(frag, off, len,
775 			      page, p_off, p_len, t) {
776 		ret = io_zcrx_copy_chunk(req, ifq, NULL, page, p_off, p_len);
777 		if (ret < 0)
778 			return copied ? copied : ret;
779 		copied += ret;
780 	}
781 	return copied;
782 }
783 
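/*
 * Deliver a single skb frag: frags backed by this provider go to user space
 * zero copy, anything else is bounced through the copy fallback.
 */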
784 static int io_zcrx_recv_frag(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
785 			     const skb_frag_t *frag, int off, int len)
786 {
787 	struct net_iov *niov;
788 
789 	if (unlikely(!skb_frag_is_net_iov(frag)))
790 		return io_zcrx_copy_frag(req, ifq, frag, off, len);
791 
792 	niov = netmem_to_net_iov(frag->netmem);
793 	if (niov->pp->mp_ops != &io_uring_pp_zc_ops ||
794 	    niov->pp->mp_priv != ifq)
795 		return -EFAULT;
796 
797 	if (!io_zcrx_queue_cqe(req, niov, ifq, off + skb_frag_off(frag), len))
798 		return -ENOSPC;
799 
800 	/*
801 	 * Prevent the niov from being recycled while user space is accessing it.
802 	 * The page pool reference has to be taken before granting the user reference.
803 	 */
804 	page_pool_ref_netmem(net_iov_to_netmem(niov));
805 	io_zcrx_get_niov_uref(niov);
806 	return len;
807 }
808 
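/*
 * tcp_read_sock() callback: walk the skb head, frags and frag list, deliver
 * each piece to user space and consume desc->count for the data delivered.
 */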
809 static int
810 io_zcrx_recv_skb(read_descriptor_t *desc, struct sk_buff *skb,
811 		 unsigned int offset, size_t len)
812 {
813 	struct io_zcrx_args *args = desc->arg.data;
814 	struct io_zcrx_ifq *ifq = args->ifq;
815 	struct io_kiocb *req = args->req;
816 	struct sk_buff *frag_iter;
817 	unsigned start, start_off = offset;
818 	int i, copy, end, off;
819 	int ret = 0;
820 
821 	len = min_t(size_t, len, desc->count);
822 	/*
823 	 * __tcp_read_sock() always calls io_zcrx_recv_skb one last time, even
824 	 * if desc->count is already 0. This is caused by the if (offset + 1 !=
825 	 * skb->len) check. Return early in this case to break out of
826 	 * __tcp_read_sock().
827 	 */
828 	if (!len)
829 		return 0;
830 	if (unlikely(args->nr_skbs++ > IO_SKBS_PER_CALL_LIMIT))
831 		return -EAGAIN;
832 
833 	if (unlikely(offset < skb_headlen(skb))) {
834 		ssize_t copied;
835 		size_t to_copy;
836 
837 		to_copy = min_t(size_t, skb_headlen(skb) - offset, len);
838 		copied = io_zcrx_copy_chunk(req, ifq, skb->data, NULL,
839 					    offset, to_copy);
840 		if (copied < 0) {
841 			ret = copied;
842 			goto out;
843 		}
844 		offset += copied;
845 		len -= copied;
846 		if (!len)
847 			goto out;
848 		if (offset != skb_headlen(skb))
849 			goto out;
850 	}
851 
852 	start = skb_headlen(skb);
853 
854 	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
855 		const skb_frag_t *frag;
856 
857 		if (WARN_ON(start > offset + len))
858 			return -EFAULT;
859 
860 		frag = &skb_shinfo(skb)->frags[i];
861 		end = start + skb_frag_size(frag);
862 
863 		if (offset < end) {
864 			copy = end - offset;
865 			if (copy > len)
866 				copy = len;
867 
868 			off = offset - start;
869 			ret = io_zcrx_recv_frag(req, ifq, frag, off, copy);
870 			if (ret < 0)
871 				goto out;
872 
873 			offset += ret;
874 			len -= ret;
875 			if (len == 0 || ret != copy)
876 				goto out;
877 		}
878 		start = end;
879 	}
880 
881 	skb_walk_frags(skb, frag_iter) {
882 		if (WARN_ON(start > offset + len))
883 			return -EFAULT;
884 
885 		end = start + frag_iter->len;
886 		if (offset < end) {
887 			copy = end - offset;
888 			if (copy > len)
889 				copy = len;
890 
891 			off = offset - start;
892 			ret = io_zcrx_recv_skb(desc, frag_iter, off, copy);
893 			if (ret < 0)
894 				goto out;
895 
896 			offset += ret;
897 			len -= ret;
898 			if (len == 0 || ret != copy)
899 				goto out;
900 		}
901 		start = end;
902 	}
903 
904 out:
905 	if (offset == start_off)
906 		return ret;
907 	desc->count -= (offset - start_off);
908 	return offset - start_off;
909 }
910 
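/*
 * Run tcp_read_sock() under the socket lock and translate the outcome into
 * io_uring return codes (-EAGAIN, IOU_REQUEUE for multishot, etc).
 */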
911 static int io_zcrx_tcp_recvmsg(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
912 				struct sock *sk, int flags,
913 				unsigned issue_flags, unsigned int *outlen)
914 {
915 	unsigned int len = *outlen;
916 	struct io_zcrx_args args = {
917 		.req = req,
918 		.ifq = ifq,
919 		.sock = sk->sk_socket,
920 	};
921 	read_descriptor_t rd_desc = {
922 		.count = len ? len : UINT_MAX,
923 		.arg.data = &args,
924 	};
925 	int ret;
926 
927 	lock_sock(sk);
928 	ret = tcp_read_sock(sk, &rd_desc, io_zcrx_recv_skb);
929 	if (len && ret > 0)
930 		*outlen = len - ret;
931 	if (ret <= 0) {
932 		if (ret < 0 || sock_flag(sk, SOCK_DONE))
933 			goto out;
934 		if (sk->sk_err)
935 			ret = sock_error(sk);
936 		else if (sk->sk_shutdown & RCV_SHUTDOWN)
937 			goto out;
938 		else if (sk->sk_state == TCP_CLOSE)
939 			ret = -ENOTCONN;
940 		else
941 			ret = -EAGAIN;
942 	} else if (unlikely(args.nr_skbs > IO_SKBS_PER_CALL_LIMIT) &&
943 		   (issue_flags & IO_URING_F_MULTISHOT)) {
944 		ret = IOU_REQUEUE;
945 	} else if (sock_flag(sk, SOCK_DONE)) {
946 		/* Let it retry until it finally reads 0. */
947 		if (issue_flags & IO_URING_F_MULTISHOT)
948 			ret = IOU_REQUEUE;
949 		else
950 			ret = -EAGAIN;
951 	}
952 out:
953 	release_sock(sk);
954 	return ret;
955 }
956 
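/* Zero copy receive entry point; only plain TCP sockets are supported. */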
957 int io_zcrx_recv(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
958 		 struct socket *sock, unsigned int flags,
959 		 unsigned issue_flags, unsigned int *len)
960 {
961 	struct sock *sk = sock->sk;
962 	const struct proto *prot = READ_ONCE(sk->sk_prot);
963 
964 	if (prot->recvmsg != tcp_recvmsg)
965 		return -EPROTONOSUPPORT;
966 
967 	sock_rps_record_flow(sk);
968 	return io_zcrx_tcp_recvmsg(req, ifq, sk, flags, issue_flags, len);
969 }
970