xref: /linux/io_uring/tctx.c (revision 364eeb79a213fcf9164208b53764223ad522d6b3)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/file.h>
5 #include <linux/mm.h>
6 #include <linux/slab.h>
7 #include <linux/nospec.h>
8 #include <linux/io_uring.h>
9 
10 #include <uapi/linux/io_uring.h>
11 
12 #include "io_uring.h"
13 #include "tctx.h"
14 
15 static struct io_wq *io_init_wq_offload(struct io_ring_ctx *ctx,
16 					struct task_struct *task)
17 {
18 	struct io_wq_hash *hash;
19 	struct io_wq_data data;
20 	unsigned int concurrency;
21 
22 	mutex_lock(&ctx->uring_lock);
23 	hash = ctx->hash_map;
24 	if (!hash) {
25 		hash = kzalloc(sizeof(*hash), GFP_KERNEL);
26 		if (!hash) {
27 			mutex_unlock(&ctx->uring_lock);
28 			return ERR_PTR(-ENOMEM);
29 		}
30 		refcount_set(&hash->refs, 1);
31 		init_waitqueue_head(&hash->wait);
32 		ctx->hash_map = hash;
33 	}
34 	mutex_unlock(&ctx->uring_lock);
35 
36 	data.hash = hash;
37 	data.task = task;
38 	data.free_work = io_wq_free_work;
39 	data.do_work = io_wq_submit_work;
40 
41 	/* Do QD, or 4 * CPUS, whatever is smallest */
42 	concurrency = min(ctx->sq_entries, 4 * num_online_cpus());
43 
44 	return io_wq_create(concurrency, &data);
45 }
46 
47 void __io_uring_free(struct task_struct *tsk)
48 {
49 	struct io_uring_task *tctx = tsk->io_uring;
50 
51 	WARN_ON_ONCE(!xa_empty(&tctx->xa));
52 	WARN_ON_ONCE(tctx->io_wq);
53 	WARN_ON_ONCE(tctx->cached_refs);
54 
55 	percpu_counter_destroy(&tctx->inflight);
56 	kfree(tctx);
57 	tsk->io_uring = NULL;
58 }
59 
60 __cold int io_uring_alloc_task_context(struct task_struct *task,
61 				       struct io_ring_ctx *ctx)
62 {
63 	struct io_uring_task *tctx;
64 	int ret;
65 
66 	tctx = kzalloc(sizeof(*tctx), GFP_KERNEL);
67 	if (unlikely(!tctx))
68 		return -ENOMEM;
69 
70 	ret = percpu_counter_init(&tctx->inflight, 0, GFP_KERNEL);
71 	if (unlikely(ret)) {
72 		kfree(tctx);
73 		return ret;
74 	}
75 
76 	tctx->io_wq = io_init_wq_offload(ctx, task);
77 	if (IS_ERR(tctx->io_wq)) {
78 		ret = PTR_ERR(tctx->io_wq);
79 		percpu_counter_destroy(&tctx->inflight);
80 		kfree(tctx);
81 		return ret;
82 	}
83 
84 	tctx->task = task;
85 	xa_init(&tctx->xa);
86 	init_waitqueue_head(&tctx->wait);
87 	atomic_set(&tctx->in_cancel, 0);
88 	atomic_set(&tctx->inflight_tracked, 0);
89 	task->io_uring = tctx;
90 	init_llist_head(&tctx->task_list);
91 	init_task_work(&tctx->task_work, tctx_task_work);
92 	return 0;
93 }
94 
95 int __io_uring_add_tctx_node(struct io_ring_ctx *ctx)
96 {
97 	struct io_uring_task *tctx = current->io_uring;
98 	struct io_tctx_node *node;
99 	int ret;
100 
101 	if (unlikely(!tctx)) {
102 		ret = io_uring_alloc_task_context(current, ctx);
103 		if (unlikely(ret))
104 			return ret;
105 
106 		tctx = current->io_uring;
107 		if (ctx->iowq_limits_set) {
108 			unsigned int limits[2] = { ctx->iowq_limits[0],
109 						   ctx->iowq_limits[1], };
110 
111 			ret = io_wq_max_workers(tctx->io_wq, limits);
112 			if (ret)
113 				return ret;
114 		}
115 	}
116 	if (!xa_load(&tctx->xa, (unsigned long)ctx)) {
117 		node = kmalloc(sizeof(*node), GFP_KERNEL);
118 		if (!node)
119 			return -ENOMEM;
120 		node->ctx = ctx;
121 		node->task = current;
122 
123 		ret = xa_err(xa_store(&tctx->xa, (unsigned long)ctx,
124 					node, GFP_KERNEL));
125 		if (ret) {
126 			kfree(node);
127 			return ret;
128 		}
129 
130 		mutex_lock(&ctx->uring_lock);
131 		list_add(&node->ctx_node, &ctx->tctx_list);
132 		mutex_unlock(&ctx->uring_lock);
133 	}
134 	return 0;
135 }
136 
137 int __io_uring_add_tctx_node_from_submit(struct io_ring_ctx *ctx)
138 {
139 	int ret;
140 
141 	if (ctx->flags & IORING_SETUP_SINGLE_ISSUER
142 	    && ctx->submitter_task != current)
143 		return -EEXIST;
144 
145 	ret = __io_uring_add_tctx_node(ctx);
146 	if (ret)
147 		return ret;
148 
149 	current->io_uring->last = ctx;
150 	return 0;
151 }
152 
153 /*
154  * Remove this io_uring_file -> task mapping.
155  */
156 __cold void io_uring_del_tctx_node(unsigned long index)
157 {
158 	struct io_uring_task *tctx = current->io_uring;
159 	struct io_tctx_node *node;
160 
161 	if (!tctx)
162 		return;
163 	node = xa_erase(&tctx->xa, index);
164 	if (!node)
165 		return;
166 
167 	WARN_ON_ONCE(current != node->task);
168 	WARN_ON_ONCE(list_empty(&node->ctx_node));
169 
170 	mutex_lock(&node->ctx->uring_lock);
171 	list_del(&node->ctx_node);
172 	mutex_unlock(&node->ctx->uring_lock);
173 
174 	if (tctx->last == node->ctx)
175 		tctx->last = NULL;
176 	kfree(node);
177 }
178 
179 __cold void io_uring_clean_tctx(struct io_uring_task *tctx)
180 {
181 	struct io_wq *wq = tctx->io_wq;
182 	struct io_tctx_node *node;
183 	unsigned long index;
184 
185 	xa_for_each(&tctx->xa, index, node) {
186 		io_uring_del_tctx_node(index);
187 		cond_resched();
188 	}
189 	if (wq) {
190 		/*
191 		 * Must be after io_uring_del_tctx_node() (removes nodes under
192 		 * uring_lock) to avoid race with io_uring_try_cancel_iowq().
193 		 */
194 		io_wq_put_and_exit(wq);
195 		tctx->io_wq = NULL;
196 	}
197 }
198 
199 void io_uring_unreg_ringfd(void)
200 {
201 	struct io_uring_task *tctx = current->io_uring;
202 	int i;
203 
204 	for (i = 0; i < IO_RINGFD_REG_MAX; i++) {
205 		if (tctx->registered_rings[i]) {
206 			fput(tctx->registered_rings[i]);
207 			tctx->registered_rings[i] = NULL;
208 		}
209 	}
210 }
211 
212 int io_ring_add_registered_file(struct io_uring_task *tctx, struct file *file,
213 				     int start, int end)
214 {
215 	int offset;
216 	for (offset = start; offset < end; offset++) {
217 		offset = array_index_nospec(offset, IO_RINGFD_REG_MAX);
218 		if (tctx->registered_rings[offset])
219 			continue;
220 
221 		tctx->registered_rings[offset] = file;
222 		return offset;
223 	}
224 	return -EBUSY;
225 }
226 
227 static int io_ring_add_registered_fd(struct io_uring_task *tctx, int fd,
228 				     int start, int end)
229 {
230 	struct file *file;
231 	int offset;
232 
233 	file = fget(fd);
234 	if (!file) {
235 		return -EBADF;
236 	} else if (!io_is_uring_fops(file)) {
237 		fput(file);
238 		return -EOPNOTSUPP;
239 	}
240 	offset = io_ring_add_registered_file(tctx, file, start, end);
241 	if (offset < 0)
242 		fput(file);
243 	return offset;
244 }
245 
246 /*
247  * Register a ring fd to avoid fdget/fdput for each io_uring_enter()
248  * invocation. User passes in an array of struct io_uring_rsrc_update
249  * with ->data set to the ring_fd, and ->offset given for the desired
250  * index. If no index is desired, application may set ->offset == -1U
251  * and we'll find an available index. Returns number of entries
252  * successfully processed, or < 0 on error if none were processed.
253  */
254 int io_ringfd_register(struct io_ring_ctx *ctx, void __user *__arg,
255 		       unsigned nr_args)
256 {
257 	struct io_uring_rsrc_update __user *arg = __arg;
258 	struct io_uring_rsrc_update reg;
259 	struct io_uring_task *tctx;
260 	int ret, i;
261 
262 	if (!nr_args || nr_args > IO_RINGFD_REG_MAX)
263 		return -EINVAL;
264 
265 	mutex_unlock(&ctx->uring_lock);
266 	ret = __io_uring_add_tctx_node(ctx);
267 	mutex_lock(&ctx->uring_lock);
268 	if (ret)
269 		return ret;
270 
271 	tctx = current->io_uring;
272 	for (i = 0; i < nr_args; i++) {
273 		int start, end;
274 
275 		if (copy_from_user(&reg, &arg[i], sizeof(reg))) {
276 			ret = -EFAULT;
277 			break;
278 		}
279 
280 		if (reg.resv) {
281 			ret = -EINVAL;
282 			break;
283 		}
284 
285 		if (reg.offset == -1U) {
286 			start = 0;
287 			end = IO_RINGFD_REG_MAX;
288 		} else {
289 			if (reg.offset >= IO_RINGFD_REG_MAX) {
290 				ret = -EINVAL;
291 				break;
292 			}
293 			start = reg.offset;
294 			end = start + 1;
295 		}
296 
297 		ret = io_ring_add_registered_fd(tctx, reg.data, start, end);
298 		if (ret < 0)
299 			break;
300 
301 		reg.offset = ret;
302 		if (copy_to_user(&arg[i], &reg, sizeof(reg))) {
303 			fput(tctx->registered_rings[reg.offset]);
304 			tctx->registered_rings[reg.offset] = NULL;
305 			ret = -EFAULT;
306 			break;
307 		}
308 	}
309 
310 	return i ? i : ret;
311 }
312 
313 int io_ringfd_unregister(struct io_ring_ctx *ctx, void __user *__arg,
314 			 unsigned nr_args)
315 {
316 	struct io_uring_rsrc_update __user *arg = __arg;
317 	struct io_uring_task *tctx = current->io_uring;
318 	struct io_uring_rsrc_update reg;
319 	int ret = 0, i;
320 
321 	if (!nr_args || nr_args > IO_RINGFD_REG_MAX)
322 		return -EINVAL;
323 	if (!tctx)
324 		return 0;
325 
326 	for (i = 0; i < nr_args; i++) {
327 		if (copy_from_user(&reg, &arg[i], sizeof(reg))) {
328 			ret = -EFAULT;
329 			break;
330 		}
331 		if (reg.resv || reg.data || reg.offset >= IO_RINGFD_REG_MAX) {
332 			ret = -EINVAL;
333 			break;
334 		}
335 
336 		reg.offset = array_index_nospec(reg.offset, IO_RINGFD_REG_MAX);
337 		if (tctx->registered_rings[reg.offset]) {
338 			fput(tctx->registered_rings[reg.offset]);
339 			tctx->registered_rings[reg.offset] = NULL;
340 		}
341 	}
342 
343 	return i ? i : ret;
344 }
345