xref: /linux/io_uring/zcrx.c (revision a7ddedc84c59a645ef970b992f7cda5bffc70cc0)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/dma-map-ops.h>
5 #include <linux/mm.h>
6 #include <linux/nospec.h>
7 #include <linux/io_uring.h>
8 #include <linux/netdevice.h>
9 #include <linux/rtnetlink.h>
10 #include <linux/skbuff_ref.h>
11 
12 #include <net/page_pool/helpers.h>
13 #include <net/page_pool/memory_provider.h>
14 #include <net/netlink.h>
15 #include <net/netdev_queues.h>
16 #include <net/netdev_rx_queue.h>
17 #include <net/tcp.h>
18 #include <net/rps.h>
19 
20 #include <trace/events/page_pool.h>
21 
22 #include <uapi/linux/io_uring.h>
23 
24 #include "io_uring.h"
25 #include "kbuf.h"
26 #include "memmap.h"
27 #include "zcrx.h"
28 #include "rsrc.h"
29 
30 #define IO_DMA_ATTR (DMA_ATTR_SKIP_CPU_SYNC | DMA_ATTR_WEAK_ORDERING)
31 
32 static inline struct io_zcrx_ifq *io_pp_to_ifq(struct page_pool *pp)
33 {
34 	return pp->mp_priv;
35 }
36 
37 static inline struct io_zcrx_area *io_zcrx_iov_to_area(const struct net_iov *niov)
38 {
39 	struct net_iov_area *owner = net_iov_owner(niov);
40 
41 	return container_of(owner, struct io_zcrx_area, nia);
42 }
43 
44 static inline struct page *io_zcrx_iov_page(const struct net_iov *niov)
45 {
46 	struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);
47 
48 	lockdep_assert(!area->mem.is_dmabuf);
49 
50 	return area->mem.pages[net_iov_idx(niov)];
51 }
52 
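/*
 * Assign DMA addresses from a mapped scatterlist to the area's net_iovs,
 * one PAGE_SIZE chunk per niov, skipping @off bytes at the start.
 */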
53 static int io_populate_area_dma(struct io_zcrx_ifq *ifq,
54 				struct io_zcrx_area *area,
55 				struct sg_table *sgt, unsigned long off)
56 {
57 	struct scatterlist *sg;
58 	unsigned i, niov_idx = 0;
59 
60 	for_each_sgtable_dma_sg(sgt, sg, i) {
61 		dma_addr_t dma = sg_dma_address(sg);
62 		unsigned long sg_len = sg_dma_len(sg);
63 		unsigned long sg_off = min(sg_len, off);
64 
65 		off -= sg_off;
66 		sg_len -= sg_off;
67 		dma += sg_off;
68 
69 		while (sg_len && niov_idx < area->nia.num_niovs) {
70 			struct net_iov *niov = &area->nia.niovs[niov_idx];
71 
72 			if (net_mp_niov_set_dma_addr(niov, dma))
73 				return -EFAULT;
74 			sg_len -= PAGE_SIZE;
75 			dma += PAGE_SIZE;
76 			niov_idx++;
77 		}
78 	}
79 	return 0;
80 }
81 
82 static void io_release_dmabuf(struct io_zcrx_mem *mem)
83 {
84 	if (!IS_ENABLED(CONFIG_DMA_SHARED_BUFFER))
85 		return;
86 
87 	if (mem->sgt)
88 		dma_buf_unmap_attachment_unlocked(mem->attach, mem->sgt,
89 						  DMA_FROM_DEVICE);
90 	if (mem->attach)
91 		dma_buf_detach(mem->dmabuf, mem->attach);
92 	if (mem->dmabuf)
93 		dma_buf_put(mem->dmabuf);
94 
95 	mem->sgt = NULL;
96 	mem->attach = NULL;
97 	mem->dmabuf = NULL;
98 }
99 
100 static int io_import_dmabuf(struct io_zcrx_ifq *ifq,
101 			    struct io_zcrx_mem *mem,
102 			    struct io_uring_zcrx_area_reg *area_reg)
103 {
104 	unsigned long off = (unsigned long)area_reg->addr;
105 	unsigned long len = (unsigned long)area_reg->len;
106 	unsigned long total_size = 0;
107 	struct scatterlist *sg;
108 	int dmabuf_fd = area_reg->dmabuf_fd;
109 	int i, ret;
110 
111 	if (off)
112 		return -EINVAL;
113 	if (WARN_ON_ONCE(!ifq->dev))
114 		return -EFAULT;
115 	if (!IS_ENABLED(CONFIG_DMA_SHARED_BUFFER))
116 		return -EINVAL;
117 
118 	mem->is_dmabuf = true;
119 	mem->dmabuf = dma_buf_get(dmabuf_fd);
120 	if (IS_ERR(mem->dmabuf)) {
121 		ret = PTR_ERR(mem->dmabuf);
122 		mem->dmabuf = NULL;
123 		goto err;
124 	}
125 
126 	mem->attach = dma_buf_attach(mem->dmabuf, ifq->dev);
127 	if (IS_ERR(mem->attach)) {
128 		ret = PTR_ERR(mem->attach);
129 		mem->attach = NULL;
130 		goto err;
131 	}
132 
133 	mem->sgt = dma_buf_map_attachment_unlocked(mem->attach, DMA_FROM_DEVICE);
134 	if (IS_ERR(mem->sgt)) {
135 		ret = PTR_ERR(mem->sgt);
136 		mem->sgt = NULL;
137 		goto err;
138 	}
139 
140 	for_each_sgtable_dma_sg(mem->sgt, sg, i)
141 		total_size += sg_dma_len(sg);
142 
143 	if (total_size != len) {
144 		ret = -EINVAL;
145 		goto err;
146 	}
147 
148 	mem->dmabuf_offset = off;
149 	mem->size = len;
150 	return 0;
151 err:
152 	io_release_dmabuf(mem);
153 	return ret;
154 }
155 
156 static int io_zcrx_map_area_dmabuf(struct io_zcrx_ifq *ifq, struct io_zcrx_area *area)
157 {
158 	if (!IS_ENABLED(CONFIG_DMA_SHARED_BUFFER))
159 		return -EINVAL;
160 	return io_populate_area_dma(ifq, area, area->mem.sgt,
161 				    area->mem.dmabuf_offset);
162 }
163 
164 static unsigned long io_count_account_pages(struct page **pages, unsigned nr_pages)
165 {
166 	struct folio *last_folio = NULL;
167 	unsigned long res = 0;
168 	int i;
169 
170 	for (i = 0; i < nr_pages; i++) {
171 		struct folio *folio = page_folio(pages[i]);
172 
173 		if (folio == last_folio)
174 			continue;
175 		last_folio = folio;
176 		res += 1UL << folio_order(folio);
177 	}
178 	return res;
179 }
180 
181 static int io_import_umem(struct io_zcrx_ifq *ifq,
182 			  struct io_zcrx_mem *mem,
183 			  struct io_uring_zcrx_area_reg *area_reg)
184 {
185 	struct page **pages;
186 	int nr_pages, ret;
187 
188 	if (area_reg->dmabuf_fd)
189 		return -EINVAL;
190 	if (!area_reg->addr)
191 		return -EFAULT;
192 	pages = io_pin_pages((unsigned long)area_reg->addr, area_reg->len,
193 				   &nr_pages);
194 	if (IS_ERR(pages))
195 		return PTR_ERR(pages);
196 
197 	ret = sg_alloc_table_from_pages(&mem->page_sg_table, pages, nr_pages,
198 					0, nr_pages << PAGE_SHIFT,
199 					GFP_KERNEL_ACCOUNT);
200 	if (ret) {
201 		unpin_user_pages(pages, nr_pages);
202 		return ret;
203 	}
204 
205 	mem->account_pages = io_count_account_pages(pages, nr_pages);
206 	ret = io_account_mem(ifq->ctx, mem->account_pages);
207 	if (ret < 0)
208 		mem->account_pages = 0;
209 
210 	mem->pages = pages;
211 	mem->nr_folios = nr_pages;
212 	mem->size = area_reg->len;
213 	return ret;
214 }
215 
216 static void io_release_area_mem(struct io_zcrx_mem *mem)
217 {
218 	if (mem->is_dmabuf) {
219 		io_release_dmabuf(mem);
220 		return;
221 	}
222 	if (mem->pages) {
223 		unpin_user_pages(mem->pages, mem->nr_folios);
224 		sg_free_table(&mem->page_sg_table);
225 		kvfree(mem->pages);
226 	}
227 }
228 
229 static int io_import_area(struct io_zcrx_ifq *ifq,
230 			  struct io_zcrx_mem *mem,
231 			  struct io_uring_zcrx_area_reg *area_reg)
232 {
233 	int ret;
234 
235 	ret = io_validate_user_buf_range(area_reg->addr, area_reg->len);
236 	if (ret)
237 		return ret;
238 	if (area_reg->addr & ~PAGE_MASK || area_reg->len & ~PAGE_MASK)
239 		return -EINVAL;
240 
241 	if (area_reg->flags & IORING_ZCRX_AREA_DMABUF)
242 		return io_import_dmabuf(ifq, mem, area_reg);
243 	return io_import_umem(ifq, mem, area_reg);
244 }
245 
246 static void io_zcrx_unmap_area(struct io_zcrx_ifq *ifq,
247 				struct io_zcrx_area *area)
248 {
249 	int i;
250 
251 	guard(mutex)(&ifq->dma_lock);
252 	if (!area->is_mapped)
253 		return;
254 	area->is_mapped = false;
255 
256 	for (i = 0; i < area->nia.num_niovs; i++)
257 		net_mp_niov_set_dma_addr(&area->nia.niovs[i], 0);
258 
259 	if (area->mem.is_dmabuf) {
260 		io_release_dmabuf(&area->mem);
261 	} else {
262 		dma_unmap_sgtable(ifq->dev, &area->mem.page_sg_table,
263 				  DMA_FROM_DEVICE, IO_DMA_ATTR);
264 	}
265 }
266 
267 static int io_zcrx_map_area_umem(struct io_zcrx_ifq *ifq, struct io_zcrx_area *area)
268 {
269 	int ret;
270 
271 	ret = dma_map_sgtable(ifq->dev, &area->mem.page_sg_table,
272 				DMA_FROM_DEVICE, IO_DMA_ATTR);
273 	if (ret < 0)
274 		return ret;
275 	return io_populate_area_dma(ifq, area, &area->mem.page_sg_table, 0);
276 }
277 
278 static int io_zcrx_map_area(struct io_zcrx_ifq *ifq, struct io_zcrx_area *area)
279 {
280 	int ret;
281 
282 	guard(mutex)(&ifq->dma_lock);
283 	if (area->is_mapped)
284 		return 0;
285 
286 	if (area->mem.is_dmabuf)
287 		ret = io_zcrx_map_area_dmabuf(ifq, area);
288 	else
289 		ret = io_zcrx_map_area_umem(ifq, area);
290 
291 	if (ret == 0)
292 		area->is_mapped = true;
293 	return ret;
294 }
295 
296 static void io_zcrx_sync_for_device(const struct page_pool *pool,
297 				    struct net_iov *niov)
298 {
299 #if defined(CONFIG_HAS_DMA) && defined(CONFIG_DMA_NEED_SYNC)
300 	dma_addr_t dma_addr;
301 
302 	if (!dma_dev_need_sync(pool->p.dev))
303 		return;
304 
305 	dma_addr = page_pool_get_dma_addr_netmem(net_iov_to_netmem(niov));
306 	__dma_sync_single_for_device(pool->p.dev, dma_addr + pool->p.offset,
307 				     PAGE_SIZE, pool->p.dma_dir);
308 #endif
309 }
310 
311 #define IO_RQ_MAX_ENTRIES		32768
312 
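/*
 * Cap the number of skbs processed per receive invocation so a single
 * request can't monopolise the task; multishot requests get requeued.
 */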
313 #define IO_SKBS_PER_CALL_LIMIT	20
314 
315 struct io_zcrx_args {
316 	struct io_kiocb		*req;
317 	struct io_zcrx_ifq	*ifq;
318 	struct socket		*sock;
319 	unsigned		nr_skbs;
320 };
321 
322 static const struct memory_provider_ops io_uring_pp_zc_ops;
323 
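/*
 * Per-niov user reference count: incremented when a CQE hands the buffer
 * to userspace, dropped when userspace returns it through the refill ring
 * (or when the ifq is scrubbed at shutdown).
 */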
324 static inline atomic_t *io_get_user_counter(struct net_iov *niov)
325 {
326 	struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);
327 
328 	return &area->user_refs[net_iov_idx(niov)];
329 }
330 
331 static bool io_zcrx_put_niov_uref(struct net_iov *niov)
332 {
333 	atomic_t *uref = io_get_user_counter(niov);
334 
335 	if (unlikely(!atomic_read(uref)))
336 		return false;
337 	atomic_dec(uref);
338 	return true;
339 }
340 
341 static void io_zcrx_get_niov_uref(struct net_iov *niov)
342 {
343 	atomic_inc(io_get_user_counter(niov));
344 }
345 
346 static int io_allocate_rbuf_ring(struct io_zcrx_ifq *ifq,
347 				 struct io_uring_zcrx_ifq_reg *reg,
348 				 struct io_uring_region_desc *rd,
349 				 u32 id)
350 {
351 	u64 mmap_offset;
352 	size_t off, size;
353 	void *ptr;
354 	int ret;
355 
356 	off = sizeof(struct io_uring);
357 	size = off + sizeof(struct io_uring_zcrx_rqe) * reg->rq_entries;
358 	if (size > rd->size)
359 		return -EINVAL;
360 
361 	mmap_offset = IORING_MAP_OFF_ZCRX_REGION;
362 	mmap_offset += id << IORING_OFF_PBUF_SHIFT;
363 
364 	ret = io_create_region(ifq->ctx, &ifq->region, rd, mmap_offset);
365 	if (ret < 0)
366 		return ret;
367 
368 	ptr = io_region_get_ptr(&ifq->region);
369 	ifq->rq_ring = (struct io_uring *)ptr;
370 	ifq->rqes = (struct io_uring_zcrx_rqe *)(ptr + off);
371 	return 0;
372 }
373 
374 static void io_free_rbuf_ring(struct io_zcrx_ifq *ifq)
375 {
376 	io_free_region(ifq->ctx, &ifq->region);
377 	ifq->rq_ring = NULL;
378 	ifq->rqes = NULL;
379 }
380 
381 static void io_zcrx_free_area(struct io_zcrx_area *area)
382 {
383 	io_zcrx_unmap_area(area->ifq, area);
384 	io_release_area_mem(&area->mem);
385 
386 	if (area->mem.account_pages)
387 		io_unaccount_mem(area->ifq->ctx, area->mem.account_pages);
388 
389 	kvfree(area->freelist);
390 	kvfree(area->nia.niovs);
391 	kvfree(area->user_refs);
392 	kfree(area);
393 }
394 
395 #define IO_ZCRX_AREA_SUPPORTED_FLAGS	(IORING_ZCRX_AREA_DMABUF)
396 
397 static int io_zcrx_create_area(struct io_zcrx_ifq *ifq,
398 			       struct io_zcrx_area **res,
399 			       struct io_uring_zcrx_area_reg *area_reg)
400 {
401 	struct io_zcrx_area *area;
402 	unsigned nr_iovs;
403 	int i, ret;
404 
405 	if (area_reg->flags & ~IO_ZCRX_AREA_SUPPORTED_FLAGS)
406 		return -EINVAL;
407 	if (area_reg->rq_area_token)
408 		return -EINVAL;
409 	if (area_reg->__resv2[0] || area_reg->__resv2[1])
410 		return -EINVAL;
411 
412 	ret = -ENOMEM;
413 	area = kzalloc(sizeof(*area), GFP_KERNEL);
414 	if (!area)
415 		goto err;
416 	area->ifq = ifq;
417 
418 	ret = io_import_area(ifq, &area->mem, area_reg);
419 	if (ret)
420 		goto err;
421 
422 	nr_iovs = area->mem.size >> PAGE_SHIFT;
423 	area->nia.num_niovs = nr_iovs;
424 
425 	ret = -ENOMEM;
426 	area->nia.niovs = kvmalloc_array(nr_iovs, sizeof(area->nia.niovs[0]),
427 					 GFP_KERNEL | __GFP_ZERO);
428 	if (!area->nia.niovs)
429 		goto err;
430 
431 	area->freelist = kvmalloc_array(nr_iovs, sizeof(area->freelist[0]),
432 					GFP_KERNEL | __GFP_ZERO);
433 	if (!area->freelist)
434 		goto err;
435 
436 	area->user_refs = kvmalloc_array(nr_iovs, sizeof(area->user_refs[0]),
437 					GFP_KERNEL | __GFP_ZERO);
438 	if (!area->user_refs)
439 		goto err;
440 
441 	for (i = 0; i < nr_iovs; i++) {
442 		struct net_iov *niov = &area->nia.niovs[i];
443 
444 		niov->owner = &area->nia;
445 		area->freelist[i] = i;
446 		atomic_set(&area->user_refs[i], 0);
447 		niov->type = NET_IOV_IOURING;
448 	}
449 
450 	area->free_count = nr_iovs;
451 	/* we're only supporting one area per ifq for now */
452 	area->area_id = 0;
453 	area_reg->rq_area_token = (u64)area->area_id << IORING_ZCRX_AREA_SHIFT;
454 	spin_lock_init(&area->freelist_lock);
455 	*res = area;
456 	return 0;
457 err:
458 	if (area)
459 		io_zcrx_free_area(area);
460 	return ret;
461 }
462 
463 static struct io_zcrx_ifq *io_zcrx_ifq_alloc(struct io_ring_ctx *ctx)
464 {
465 	struct io_zcrx_ifq *ifq;
466 
467 	ifq = kzalloc(sizeof(*ifq), GFP_KERNEL);
468 	if (!ifq)
469 		return NULL;
470 
471 	ifq->if_rxq = -1;
472 	ifq->ctx = ctx;
473 	spin_lock_init(&ifq->lock);
474 	spin_lock_init(&ifq->rq_lock);
475 	mutex_init(&ifq->dma_lock);
476 	return ifq;
477 }
478 
479 static void io_zcrx_drop_netdev(struct io_zcrx_ifq *ifq)
480 {
481 	spin_lock(&ifq->lock);
482 	if (ifq->netdev) {
483 		netdev_put(ifq->netdev, &ifq->netdev_tracker);
484 		ifq->netdev = NULL;
485 	}
486 	spin_unlock(&ifq->lock);
487 }
488 
489 static void io_close_queue(struct io_zcrx_ifq *ifq)
490 {
491 	struct net_device *netdev;
492 	netdevice_tracker netdev_tracker;
493 	struct pp_memory_provider_params p = {
494 		.mp_ops = &io_uring_pp_zc_ops,
495 		.mp_priv = ifq,
496 	};
497 
498 	if (ifq->if_rxq == -1)
499 		return;
500 
501 	spin_lock(&ifq->lock);
502 	netdev = ifq->netdev;
503 	netdev_tracker = ifq->netdev_tracker;
504 	ifq->netdev = NULL;
505 	spin_unlock(&ifq->lock);
506 
507 	if (netdev) {
508 		net_mp_close_rxq(netdev, ifq->if_rxq, &p);
509 		netdev_put(netdev, &netdev_tracker);
510 	}
511 	ifq->if_rxq = -1;
512 }
513 
514 static void io_zcrx_ifq_free(struct io_zcrx_ifq *ifq)
515 {
516 	io_close_queue(ifq);
517 	io_zcrx_drop_netdev(ifq);
518 
519 	if (ifq->area)
520 		io_zcrx_free_area(ifq->area);
521 	if (ifq->dev)
522 		put_device(ifq->dev);
523 
524 	io_free_rbuf_ring(ifq);
525 	mutex_destroy(&ifq->dma_lock);
526 	kfree(ifq);
527 }
528 
529 struct io_mapped_region *io_zcrx_get_region(struct io_ring_ctx *ctx,
530 					    unsigned int id)
531 {
532 	struct io_zcrx_ifq *ifq = xa_load(&ctx->zcrx_ctxs, id);
533 
534 	lockdep_assert_held(&ctx->mmap_lock);
535 
536 	return ifq ? &ifq->region : NULL;
537 }
538 
539 int io_register_zcrx_ifq(struct io_ring_ctx *ctx,
540 			  struct io_uring_zcrx_ifq_reg __user *arg)
541 {
542 	struct pp_memory_provider_params mp_param = {};
543 	struct io_uring_zcrx_area_reg area;
544 	struct io_uring_zcrx_ifq_reg reg;
545 	struct io_uring_region_desc rd;
546 	struct io_zcrx_ifq *ifq;
547 	int ret;
548 	u32 id;
549 
550 	/*
551 	 * Needs CAP_NET_ADMIN: this binds a hardware interface receive queue
552 	 * and can observe data destined for sockets of other tasks.
553 	 */
554 	if (!capable(CAP_NET_ADMIN))
555 		return -EPERM;
556 
557 	/* mandatory io_uring features for zc rx */
558 	if (!(ctx->flags & IORING_SETUP_DEFER_TASKRUN &&
559 	      ctx->flags & IORING_SETUP_CQE32))
560 		return -EINVAL;
561 	if (copy_from_user(&reg, arg, sizeof(reg)))
562 		return -EFAULT;
563 	if (copy_from_user(&rd, u64_to_user_ptr(reg.region_ptr), sizeof(rd)))
564 		return -EFAULT;
565 	if (memchr_inv(&reg.__resv, 0, sizeof(reg.__resv)) ||
566 	    reg.__resv2 || reg.zcrx_id)
567 		return -EINVAL;
568 	if (reg.if_rxq == -1 || !reg.rq_entries || reg.flags)
569 		return -EINVAL;
570 	if (reg.rq_entries > IO_RQ_MAX_ENTRIES) {
571 		if (!(ctx->flags & IORING_SETUP_CLAMP))
572 			return -EINVAL;
573 		reg.rq_entries = IO_RQ_MAX_ENTRIES;
574 	}
575 	reg.rq_entries = roundup_pow_of_two(reg.rq_entries);
576 
577 	if (copy_from_user(&area, u64_to_user_ptr(reg.area_ptr), sizeof(area)))
578 		return -EFAULT;
579 
580 	ifq = io_zcrx_ifq_alloc(ctx);
581 	if (!ifq)
582 		return -ENOMEM;
583 	ifq->rq_entries = reg.rq_entries;
584 
585 	scoped_guard(mutex, &ctx->mmap_lock) {
586 		/* preallocate id */
587 		ret = xa_alloc(&ctx->zcrx_ctxs, &id, NULL, xa_limit_31b, GFP_KERNEL);
588 		if (ret)
589 			goto ifq_free;
590 	}
591 
592 	ret = io_allocate_rbuf_ring(ifq, &reg, &rd, id);
593 	if (ret)
594 		goto err;
595 
596 	ifq->netdev = netdev_get_by_index(current->nsproxy->net_ns, reg.if_idx,
597 					  &ifq->netdev_tracker, GFP_KERNEL);
598 	if (!ifq->netdev) {
599 		ret = -ENODEV;
600 		goto err;
601 	}
602 
603 	ifq->dev = netdev_queue_get_dma_dev(ifq->netdev, reg.if_rxq);
604 	if (!ifq->dev) {
605 		ret = -EOPNOTSUPP;
606 		goto err;
607 	}
608 	get_device(ifq->dev);
609 
610 	ret = io_zcrx_create_area(ifq, &ifq->area, &area);
611 	if (ret)
612 		goto err;
613 
614 	mp_param.mp_ops = &io_uring_pp_zc_ops;
615 	mp_param.mp_priv = ifq;
616 	ret = net_mp_open_rxq(ifq->netdev, reg.if_rxq, &mp_param);
617 	if (ret)
618 		goto err;
619 	ifq->if_rxq = reg.if_rxq;
620 
621 	reg.offsets.rqes = sizeof(struct io_uring);
622 	reg.offsets.head = offsetof(struct io_uring, head);
623 	reg.offsets.tail = offsetof(struct io_uring, tail);
624 	reg.zcrx_id = id;
625 
626 	scoped_guard(mutex, &ctx->mmap_lock) {
627 		/* publish ifq */
628 		ret = -ENOMEM;
629 		if (xa_store(&ctx->zcrx_ctxs, id, ifq, GFP_KERNEL))
630 			goto err;
631 	}
632 
633 	if (copy_to_user(arg, &reg, sizeof(reg)) ||
634 	    copy_to_user(u64_to_user_ptr(reg.region_ptr), &rd, sizeof(rd)) ||
635 	    copy_to_user(u64_to_user_ptr(reg.area_ptr), &area, sizeof(area))) {
636 		ret = -EFAULT;
637 		goto err;
638 	}
639 	return 0;
640 err:
641 	scoped_guard(mutex, &ctx->mmap_lock)
642 		xa_erase(&ctx->zcrx_ctxs, id);
643 ifq_free:
644 	io_zcrx_ifq_free(ifq);
645 	return ret;
646 }
647 
648 void io_unregister_zcrx_ifqs(struct io_ring_ctx *ctx)
649 {
650 	struct io_zcrx_ifq *ifq;
651 
652 	lockdep_assert_held(&ctx->uring_lock);
653 
654 	while (1) {
655 		scoped_guard(mutex, &ctx->mmap_lock) {
656 			unsigned long id = 0;
657 
658 			ifq = xa_find(&ctx->zcrx_ctxs, &id, ULONG_MAX, XA_PRESENT);
659 			if (ifq)
660 				xa_erase(&ctx->zcrx_ctxs, id);
661 		}
662 		if (!ifq)
663 			break;
664 		io_zcrx_ifq_free(ifq);
665 	}
666 
667 	xa_destroy(&ctx->zcrx_ctxs);
668 }
669 
670 static struct net_iov *__io_zcrx_get_free_niov(struct io_zcrx_area *area)
671 {
672 	unsigned niov_idx;
673 
674 	lockdep_assert_held(&area->freelist_lock);
675 
676 	niov_idx = area->freelist[--area->free_count];
677 	return &area->nia.niovs[niov_idx];
678 }
679 
680 static void io_zcrx_return_niov_freelist(struct net_iov *niov)
681 {
682 	struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);
683 
684 	spin_lock_bh(&area->freelist_lock);
685 	area->freelist[area->free_count++] = net_iov_idx(niov);
686 	spin_unlock_bh(&area->freelist_lock);
687 }
688 
689 static void io_zcrx_return_niov(struct net_iov *niov)
690 {
691 	netmem_ref netmem = net_iov_to_netmem(niov);
692 
693 	if (!niov->pp) {
694 		/* copy fallback allocated niovs */
695 		io_zcrx_return_niov_freelist(niov);
696 		return;
697 	}
698 	page_pool_put_unrefed_netmem(niov->pp, netmem, -1, false);
699 }
700 
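/*
 * Reclaim buffers still held by userspace so their page pool references
 * can be dropped and the pool torn down.
 */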
701 static void io_zcrx_scrub(struct io_zcrx_ifq *ifq)
702 {
703 	struct io_zcrx_area *area = ifq->area;
704 	int i;
705 
706 	if (!area)
707 		return;
708 
709 	/* Reclaim all buffers currently handed out to user space. */
710 	for (i = 0; i < area->nia.num_niovs; i++) {
711 		struct net_iov *niov = &area->nia.niovs[i];
712 		int nr;
713 
714 		if (!atomic_read(io_get_user_counter(niov)))
715 			continue;
716 		nr = atomic_xchg(io_get_user_counter(niov), 0);
717 		if (nr && !page_pool_unref_netmem(net_iov_to_netmem(niov), nr))
718 			io_zcrx_return_niov(niov);
719 	}
720 }
721 
722 void io_shutdown_zcrx_ifqs(struct io_ring_ctx *ctx)
723 {
724 	struct io_zcrx_ifq *ifq;
725 	unsigned long index;
726 
727 	lockdep_assert_held(&ctx->uring_lock);
728 
729 	xa_for_each(&ctx->zcrx_ctxs, index, ifq) {
730 		io_zcrx_scrub(ifq);
731 		io_close_queue(ifq);
732 	}
733 }
734 
735 static inline u32 io_zcrx_rqring_entries(struct io_zcrx_ifq *ifq)
736 {
737 	u32 entries;
738 
739 	entries = smp_load_acquire(&ifq->rq_ring->tail) - ifq->cached_rq_head;
740 	return min(entries, ifq->rq_entries);
741 }
742 
743 static struct io_uring_zcrx_rqe *io_zcrx_get_rqe(struct io_zcrx_ifq *ifq,
744 						 unsigned mask)
745 {
746 	unsigned int idx = ifq->cached_rq_head++ & mask;
747 
748 	return &ifq->rqes[idx];
749 }
750 
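/*
 * Fast refill path: consume rqes posted to the refill ring, validate them,
 * drop the user reference and recycle the niovs into the page pool's
 * allocation cache.
 */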
751 static void io_zcrx_ring_refill(struct page_pool *pp,
752 				struct io_zcrx_ifq *ifq)
753 {
754 	unsigned int mask = ifq->rq_entries - 1;
755 	unsigned int entries;
756 	netmem_ref netmem;
757 
758 	spin_lock_bh(&ifq->rq_lock);
759 
760 	entries = io_zcrx_rqring_entries(ifq);
761 	entries = min_t(unsigned, entries, PP_ALLOC_CACHE_REFILL - pp->alloc.count);
762 	if (unlikely(!entries)) {
763 		spin_unlock_bh(&ifq->rq_lock);
764 		return;
765 	}
766 
767 	do {
768 		struct io_uring_zcrx_rqe *rqe = io_zcrx_get_rqe(ifq, mask);
769 		struct io_zcrx_area *area;
770 		struct net_iov *niov;
771 		unsigned niov_idx, area_idx;
772 
773 		area_idx = rqe->off >> IORING_ZCRX_AREA_SHIFT;
774 		niov_idx = (rqe->off & ~IORING_ZCRX_AREA_MASK) >> PAGE_SHIFT;
775 
776 		if (unlikely(rqe->__pad || area_idx))
777 			continue;
778 		area = ifq->area;
779 
780 		if (unlikely(niov_idx >= area->nia.num_niovs))
781 			continue;
782 		niov_idx = array_index_nospec(niov_idx, area->nia.num_niovs);
783 
784 		niov = &area->nia.niovs[niov_idx];
785 		if (!io_zcrx_put_niov_uref(niov))
786 			continue;
787 
788 		netmem = net_iov_to_netmem(niov);
789 		if (page_pool_unref_netmem(netmem, 1) != 0)
790 			continue;
791 
792 		if (unlikely(niov->pp != pp)) {
793 			io_zcrx_return_niov(niov);
794 			continue;
795 		}
796 
797 		io_zcrx_sync_for_device(pp, niov);
798 		net_mp_netmem_place_in_cache(pp, netmem);
799 	} while (--entries);
800 
801 	smp_store_release(&ifq->rq_ring->head, ifq->cached_rq_head);
802 	spin_unlock_bh(&ifq->rq_lock);
803 }
804 
805 static void io_zcrx_refill_slow(struct page_pool *pp, struct io_zcrx_ifq *ifq)
806 {
807 	struct io_zcrx_area *area = ifq->area;
808 
809 	spin_lock_bh(&area->freelist_lock);
810 	while (area->free_count && pp->alloc.count < PP_ALLOC_CACHE_REFILL) {
811 		struct net_iov *niov = __io_zcrx_get_free_niov(area);
812 		netmem_ref netmem = net_iov_to_netmem(niov);
813 
814 		net_mp_niov_set_page_pool(pp, niov);
815 		io_zcrx_sync_for_device(pp, niov);
816 		net_mp_netmem_place_in_cache(pp, netmem);
817 	}
818 	spin_unlock_bh(&area->freelist_lock);
819 }
820 
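/* Memory provider allocation: refill from the ring first, then the freelist. */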
821 static netmem_ref io_pp_zc_alloc_netmems(struct page_pool *pp, gfp_t gfp)
822 {
823 	struct io_zcrx_ifq *ifq = io_pp_to_ifq(pp);
824 
825 	/* pp should already be ensuring that */
826 	if (unlikely(pp->alloc.count))
827 		goto out_return;
828 
829 	io_zcrx_ring_refill(pp, ifq);
830 	if (likely(pp->alloc.count))
831 		goto out_return;
832 
833 	io_zcrx_refill_slow(pp, ifq);
834 	if (!pp->alloc.count)
835 		return 0;
836 out_return:
837 	return pp->alloc.cache[--pp->alloc.count];
838 }
839 
840 static bool io_pp_zc_release_netmem(struct page_pool *pp, netmem_ref netmem)
841 {
842 	struct net_iov *niov;
843 
844 	if (WARN_ON_ONCE(!netmem_is_net_iov(netmem)))
845 		return false;
846 
847 	niov = netmem_to_net_iov(netmem);
848 	net_mp_niov_clear_page_pool(niov);
849 	io_zcrx_return_niov_freelist(niov);
850 	return false;
851 }
852 
853 static int io_pp_zc_init(struct page_pool *pp)
854 {
855 	struct io_zcrx_ifq *ifq = io_pp_to_ifq(pp);
856 	int ret;
857 
858 	if (WARN_ON_ONCE(!ifq))
859 		return -EINVAL;
860 	if (WARN_ON_ONCE(ifq->dev != pp->p.dev))
861 		return -EINVAL;
862 	if (WARN_ON_ONCE(!pp->dma_map))
863 		return -EOPNOTSUPP;
864 	if (pp->p.order != 0)
865 		return -EOPNOTSUPP;
866 	if (pp->p.dma_dir != DMA_FROM_DEVICE)
867 		return -EOPNOTSUPP;
868 
869 	ret = io_zcrx_map_area(ifq, ifq->area);
870 	if (ret)
871 		return ret;
872 
873 	percpu_ref_get(&ifq->ctx->refs);
874 	return 0;
875 }
876 
877 static void io_pp_zc_destroy(struct page_pool *pp)
878 {
879 	struct io_zcrx_ifq *ifq = io_pp_to_ifq(pp);
880 
881 	percpu_ref_put(&ifq->ctx->refs);
882 }
883 
884 static int io_pp_nl_fill(void *mp_priv, struct sk_buff *rsp,
885 			 struct netdev_rx_queue *rxq)
886 {
887 	struct nlattr *nest;
888 	int type;
889 
890 	type = rxq ? NETDEV_A_QUEUE_IO_URING : NETDEV_A_PAGE_POOL_IO_URING;
891 	nest = nla_nest_start(rsp, type);
892 	if (!nest)
893 		return -EMSGSIZE;
894 	nla_nest_end(rsp, nest);
895 
896 	return 0;
897 }
898 
899 static void io_pp_uninstall(void *mp_priv, struct netdev_rx_queue *rxq)
900 {
901 	struct pp_memory_provider_params *p = &rxq->mp_params;
902 	struct io_zcrx_ifq *ifq = mp_priv;
903 
904 	io_zcrx_drop_netdev(ifq);
905 	if (ifq->area)
906 		io_zcrx_unmap_area(ifq, ifq->area);
907 
908 	p->mp_ops = NULL;
909 	p->mp_priv = NULL;
910 }
911 
912 static const struct memory_provider_ops io_uring_pp_zc_ops = {
913 	.alloc_netmems		= io_pp_zc_alloc_netmems,
914 	.release_netmem		= io_pp_zc_release_netmem,
915 	.init			= io_pp_zc_init,
916 	.destroy		= io_pp_zc_destroy,
917 	.nl_fill		= io_pp_nl_fill,
918 	.uninstall		= io_pp_uninstall,
919 };
920 
921 static bool io_zcrx_queue_cqe(struct io_kiocb *req, struct net_iov *niov,
922 			      struct io_zcrx_ifq *ifq, int off, int len)
923 {
924 	struct io_uring_zcrx_cqe *rcqe;
925 	struct io_zcrx_area *area;
926 	struct io_uring_cqe *cqe;
927 	u64 offset;
928 
929 	if (!io_defer_get_uncommited_cqe(req->ctx, &cqe))
930 		return false;
931 
932 	cqe->user_data = req->cqe.user_data;
933 	cqe->res = len;
934 	cqe->flags = IORING_CQE_F_MORE;
935 
936 	area = io_zcrx_iov_to_area(niov);
937 	offset = off + (net_iov_idx(niov) << PAGE_SHIFT);
938 	rcqe = (struct io_uring_zcrx_cqe *)(cqe + 1);
939 	rcqe->off = offset + ((u64)area->area_id << IORING_ZCRX_AREA_SHIFT);
940 	rcqe->__pad = 0;
941 	return true;
942 }
943 
944 static struct net_iov *io_zcrx_alloc_fallback(struct io_zcrx_area *area)
945 {
946 	struct net_iov *niov = NULL;
947 
948 	spin_lock_bh(&area->freelist_lock);
949 	if (area->free_count)
950 		niov = __io_zcrx_get_free_niov(area);
951 	spin_unlock_bh(&area->freelist_lock);
952 
953 	if (niov)
954 		page_pool_fragment_netmem(net_iov_to_netmem(niov), 1);
955 	return niov;
956 }
957 
958 struct io_copy_cache {
959 	struct page		*page;
960 	unsigned long		offset;
961 	size_t			size;
962 };
963 
964 static ssize_t io_copy_page(struct io_copy_cache *cc, struct page *src_page,
965 			    unsigned int src_offset, size_t len)
966 {
967 	size_t copied = 0;
968 
969 	len = min(len, cc->size);
970 
971 	while (len) {
972 		void *src_addr, *dst_addr;
973 		struct page *dst_page = cc->page;
974 		unsigned dst_offset = cc->offset;
975 		size_t n = len;
976 
977 		if (folio_test_partial_kmap(page_folio(dst_page)) ||
978 		    folio_test_partial_kmap(page_folio(src_page))) {
979 			dst_page = nth_page(dst_page, dst_offset / PAGE_SIZE);
980 			dst_offset = offset_in_page(dst_offset);
981 			src_page = nth_page(src_page, src_offset / PAGE_SIZE);
982 			src_offset = offset_in_page(src_offset);
983 			n = min(PAGE_SIZE - src_offset, PAGE_SIZE - dst_offset);
984 			n = min(n, len);
985 		}
986 
987 		dst_addr = kmap_local_page(dst_page) + dst_offset;
988 		src_addr = kmap_local_page(src_page) + src_offset;
989 
990 		memcpy(dst_addr, src_addr, n);
991 
992 		kunmap_local(src_addr);
993 		kunmap_local(dst_addr);
994 
995 		cc->size -= n;
996 		cc->offset += n;
997 		len -= n;
998 		copied += n;
999 	}
1000 	return copied;
1001 }
1002 
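/*
 * Copy fallback for data that didn't land in zcrx buffers (linear skb data
 * or foreign frags): copy into freelist niovs and post CQEs as usual.
 */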
1003 static ssize_t io_zcrx_copy_chunk(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
1004 				  struct page *src_page, unsigned int src_offset,
1005 				  size_t len)
1006 {
1007 	struct io_zcrx_area *area = ifq->area;
1008 	size_t copied = 0;
1009 	int ret = 0;
1010 
1011 	if (area->mem.is_dmabuf)
1012 		return -EFAULT;
1013 
1014 	while (len) {
1015 		struct io_copy_cache cc;
1016 		struct net_iov *niov;
1017 		size_t n;
1018 
1019 		niov = io_zcrx_alloc_fallback(area);
1020 		if (!niov) {
1021 			ret = -ENOMEM;
1022 			break;
1023 		}
1024 
1025 		cc.page = io_zcrx_iov_page(niov);
1026 		cc.offset = 0;
1027 		cc.size = PAGE_SIZE;
1028 
1029 		n = io_copy_page(&cc, src_page, src_offset, len);
1030 
1031 		if (!io_zcrx_queue_cqe(req, niov, ifq, 0, n)) {
1032 			io_zcrx_return_niov(niov);
1033 			ret = -ENOSPC;
1034 			break;
1035 		}
1036 
1037 		io_zcrx_get_niov_uref(niov);
1038 		src_offset += n;
1039 		len -= n;
1040 		copied += n;
1041 	}
1042 
1043 	return copied ? copied : ret;
1044 }
1045 
1046 static int io_zcrx_copy_frag(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
1047 			     const skb_frag_t *frag, int off, int len)
1048 {
1049 	struct page *page = skb_frag_page(frag);
1050 
1051 	return io_zcrx_copy_chunk(req, ifq, page, off + skb_frag_off(frag), len);
1052 }
1053 
1054 static int io_zcrx_recv_frag(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
1055 			     const skb_frag_t *frag, int off, int len)
1056 {
1057 	struct net_iov *niov;
1058 
1059 	if (unlikely(!skb_frag_is_net_iov(frag)))
1060 		return io_zcrx_copy_frag(req, ifq, frag, off, len);
1061 
1062 	niov = netmem_to_net_iov(frag->netmem);
1063 	if (!niov->pp || niov->pp->mp_ops != &io_uring_pp_zc_ops ||
1064 	    io_pp_to_ifq(niov->pp) != ifq)
1065 		return -EFAULT;
1066 
1067 	if (!io_zcrx_queue_cqe(req, niov, ifq, off + skb_frag_off(frag), len))
1068 		return -ENOSPC;
1069 
1070 	/*
1071 	 * Prevent it from being recycled while user is accessing it.
1072 	 * It has to be done before grabbing a user reference.
1073 	 */
1074 	page_pool_ref_netmem(net_iov_to_netmem(niov));
1075 	io_zcrx_get_niov_uref(niov);
1076 	return len;
1077 }
1078 
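/*
 * tcp_read_sock() callback: walk the skb's linear data, fragments and
 * frag_list, posting a CQE for each chunk passed to userspace.
 */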
1079 static int
1080 io_zcrx_recv_skb(read_descriptor_t *desc, struct sk_buff *skb,
1081 		 unsigned int offset, size_t len)
1082 {
1083 	struct io_zcrx_args *args = desc->arg.data;
1084 	struct io_zcrx_ifq *ifq = args->ifq;
1085 	struct io_kiocb *req = args->req;
1086 	struct sk_buff *frag_iter;
1087 	unsigned start, start_off = offset;
1088 	int i, copy, end, off;
1089 	int ret = 0;
1090 
1091 	len = min_t(size_t, len, desc->count);
1092 	/*
1093 	 * __tcp_read_sock() always calls io_zcrx_recv_skb one last time, even
1094 	 * if desc->count is already 0. This is caused by the if (offset + 1 !=
1095 	 * skb->len) check. Return early in this case to break out of
1096 	 * __tcp_read_sock().
1097 	 */
1098 	if (!len)
1099 		return 0;
1100 	if (unlikely(args->nr_skbs++ > IO_SKBS_PER_CALL_LIMIT))
1101 		return -EAGAIN;
1102 
1103 	if (unlikely(offset < skb_headlen(skb))) {
1104 		ssize_t copied;
1105 		size_t to_copy;
1106 
1107 		to_copy = min_t(size_t, skb_headlen(skb) - offset, len);
1108 		copied = io_zcrx_copy_chunk(req, ifq, virt_to_page(skb->data),
1109 					    offset_in_page(skb->data) + offset,
1110 					    to_copy);
1111 		if (copied < 0) {
1112 			ret = copied;
1113 			goto out;
1114 		}
1115 		offset += copied;
1116 		len -= copied;
1117 		if (!len)
1118 			goto out;
1119 		if (offset != skb_headlen(skb))
1120 			goto out;
1121 	}
1122 
1123 	start = skb_headlen(skb);
1124 
1125 	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
1126 		const skb_frag_t *frag;
1127 
1128 		if (WARN_ON(start > offset + len))
1129 			return -EFAULT;
1130 
1131 		frag = &skb_shinfo(skb)->frags[i];
1132 		end = start + skb_frag_size(frag);
1133 
1134 		if (offset < end) {
1135 			copy = end - offset;
1136 			if (copy > len)
1137 				copy = len;
1138 
1139 			off = offset - start;
1140 			ret = io_zcrx_recv_frag(req, ifq, frag, off, copy);
1141 			if (ret < 0)
1142 				goto out;
1143 
1144 			offset += ret;
1145 			len -= ret;
1146 			if (len == 0 || ret != copy)
1147 				goto out;
1148 		}
1149 		start = end;
1150 	}
1151 
1152 	skb_walk_frags(skb, frag_iter) {
1153 		if (WARN_ON(start > offset + len))
1154 			return -EFAULT;
1155 
1156 		end = start + frag_iter->len;
1157 		if (offset < end) {
1158 			copy = end - offset;
1159 			if (copy > len)
1160 				copy = len;
1161 
1162 			off = offset - start;
1163 			ret = io_zcrx_recv_skb(desc, frag_iter, off, copy);
1164 			if (ret < 0)
1165 				goto out;
1166 
1167 			offset += ret;
1168 			len -= ret;
1169 			if (len == 0 || ret != copy)
1170 				goto out;
1171 		}
1172 		start = end;
1173 	}
1174 
1175 out:
1176 	if (offset == start_off)
1177 		return ret;
1178 	desc->count -= (offset - start_off);
1179 	return offset - start_off;
1180 }
1181 
1182 static int io_zcrx_tcp_recvmsg(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
1183 				struct sock *sk, int flags,
1184 				unsigned issue_flags, unsigned int *outlen)
1185 {
1186 	unsigned int len = *outlen;
1187 	struct io_zcrx_args args = {
1188 		.req = req,
1189 		.ifq = ifq,
1190 		.sock = sk->sk_socket,
1191 	};
1192 	read_descriptor_t rd_desc = {
1193 		.count = len ? len : UINT_MAX,
1194 		.arg.data = &args,
1195 	};
1196 	int ret;
1197 
1198 	lock_sock(sk);
1199 	ret = tcp_read_sock(sk, &rd_desc, io_zcrx_recv_skb);
1200 	if (len && ret > 0)
1201 		*outlen = len - ret;
1202 	if (ret <= 0) {
1203 		if (ret < 0 || sock_flag(sk, SOCK_DONE))
1204 			goto out;
1205 		if (sk->sk_err)
1206 			ret = sock_error(sk);
1207 		else if (sk->sk_shutdown & RCV_SHUTDOWN)
1208 			goto out;
1209 		else if (sk->sk_state == TCP_CLOSE)
1210 			ret = -ENOTCONN;
1211 		else
1212 			ret = -EAGAIN;
1213 	} else if (unlikely(args.nr_skbs > IO_SKBS_PER_CALL_LIMIT) &&
1214 		   (issue_flags & IO_URING_F_MULTISHOT)) {
1215 		ret = IOU_REQUEUE;
1216 	} else if (sock_flag(sk, SOCK_DONE)) {
1217 		/* Keep retrying until it finally reads 0. */
1218 		if (issue_flags & IO_URING_F_MULTISHOT)
1219 			ret = IOU_REQUEUE;
1220 		else
1221 			ret = -EAGAIN;
1222 	}
1223 out:
1224 	release_sock(sk);
1225 	return ret;
1226 }
1227 
1228 int io_zcrx_recv(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
1229 		 struct socket *sock, unsigned int flags,
1230 		 unsigned issue_flags, unsigned int *len)
1231 {
1232 	struct sock *sk = sock->sk;
1233 	const struct proto *prot = READ_ONCE(sk->sk_prot);
1234 
1235 	if (prot->recvmsg != tcp_recvmsg)
1236 		return -EPROTONOSUPPORT;
1237 
1238 	sock_rps_record_flow(sk);
1239 	return io_zcrx_tcp_recvmsg(req, ifq, sk, flags, issue_flags, len);
1240 }
1241