xref: /linux/drivers/misc/ntsync.c (revision 180a232ea78003d1dc869b217b4e49106fd58e8f)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * ntsync.c - Kernel driver for NT synchronization primitives
4  *
5  * Copyright (C) 2024 Elizabeth Figura <zfigura@codeweavers.com>
6  */
7 
8 #include <linux/anon_inodes.h>
9 #include <linux/atomic.h>
10 #include <linux/file.h>
11 #include <linux/fs.h>
12 #include <linux/hrtimer.h>
13 #include <linux/ktime.h>
14 #include <linux/miscdevice.h>
15 #include <linux/module.h>
16 #include <linux/mutex.h>
17 #include <linux/overflow.h>
18 #include <linux/sched.h>
19 #include <linux/sched/signal.h>
20 #include <linux/slab.h>
21 #include <linux/spinlock.h>
22 #include <linux/time_namespace.h>
23 #include <uapi/linux/ntsync.h>
24 
25 #define NTSYNC_NAME	"ntsync"
26 
/* Kind of primitive an ntsync_obj represents; selects the member of obj->u. */
enum ntsync_type {
	NTSYNC_TYPE_SEM = 0,	/* semaphore, state in u.sem */
	NTSYNC_TYPE_MUTEX = 1,	/* mutex, state in u.mutex */
	NTSYNC_TYPE_EVENT = 2,	/* event, state in u.event */
};
32 
33 /*
34  * Individual synchronization primitives are represented by
35  * struct ntsync_obj, and each primitive is backed by a file.
36  *
37  * The whole namespace is represented by a struct ntsync_device also
38  * backed by a file.
39  *
40  * Both rely on struct file for reference counting. Individual
41  * ntsync_obj objects take a reference to the device when created.
42  * Wait operations take a reference to each object being waited on for
43  * the duration of the wait.
44  */
45 
46 struct ntsync_obj {
47 	spinlock_t lock;
48 	int dev_locked;
49 
50 	enum ntsync_type type;
51 
52 	struct file *file;
53 	struct ntsync_device *dev;
54 
55 	/* The following fields are protected by the object lock. */
56 	union {
57 		struct {
58 			__u32 count;
59 			__u32 max;
60 		} sem;
61 		struct {
62 			__u32 count;
63 			pid_t owner;
64 			bool ownerdead;
65 		} mutex;
66 		struct {
67 			bool manual;
68 			bool signaled;
69 		} event;
70 	} u;
71 
72 	/*
73 	 * any_waiters is protected by the object lock, but all_waiters is
74 	 * protected by the device wait_all_lock.
75 	 */
76 	struct list_head any_waiters;
77 	struct list_head all_waiters;
78 
79 	/*
80 	 * Hint describing how many tasks are queued on this object in a
81 	 * wait-all operation.
82 	 *
83 	 * Any time we do a wake, we may need to wake "all" waiters as well as
84 	 * "any" waiters. In order to atomically wake "all" waiters, we must
85 	 * lock all of the objects, and that means grabbing the wait_all_lock
86 	 * below (and, due to lock ordering rules, before locking this object).
87 	 * However, wait-all is a rare operation, and grabbing the wait-all
88 	 * lock for every wake would create unnecessary contention.
89 	 * Therefore we first check whether all_hint is zero, and, if it is,
90 	 * we skip trying to wake "all" waiters.
91 	 *
92 	 * Since wait requests must originate from user-space threads, we're
93 	 * limited here by PID_MAX_LIMIT, so there's no risk of overflow.
94 	 */
95 	atomic_t all_hint;
96 };
97 
/*
 * Ties one wait (struct ntsync_q) to one of the objects it is waiting on.
 */
struct ntsync_q_entry {
	/* Link in the object's any_waiters or all_waiters list. */
	struct list_head node;
	/* The wait this entry belongs to. */
	struct ntsync_q *q;
	/* The object being waited on. */
	struct ntsync_obj *obj;
	/*
	 * Position of this entry in q->entries[]; the wait-any wake paths
	 * store it in q->signaled to report which object satisfied the wait.
	 */
	__u32 index;
};
104 
/*
 * State of one in-progress wait-any or wait-all operation.
 */
struct ntsync_q {
	/* The waiting task, woken with wake_up_process(). */
	struct task_struct *task;
	/* Owner identifier supplied by the waiter; used when acquiring mutexes. */
	__u32 owner;

	/*
	 * Protected via atomic_try_cmpxchg(). Only the thread that wins the
	 * compare-and-swap may actually change object states and wake this
	 * task.
	 *
	 * Holds -1 while the wait is still pending.
	 */
	atomic_t signaled;

	/* True for a wait-all operation, false for wait-any. */
	bool all;
	/* Set if a mutex acquired by this wait had been marked owner-dead. */
	bool ownerdead;
	/* Number of regular objects in entries[], not counting the alert. */
	__u32 count;
	/* One entry per object, followed by the optional alert event. */
	struct ntsync_q_entry entries[];
};
121 
/*
 * Per-open state of the ntsync character device.
 */
struct ntsync_device {
	/*
	 * Wait-all operations must atomically grab all objects, and be totally
	 * ordered with respect to each other and wait-any operations.
	 * If one thread is trying to acquire several objects, another thread
	 * cannot touch the object at the same time.
	 *
	 * This device-wide lock is used to serialize wait-for-all
	 * operations, and operations on an object that is involved in a
	 * wait-for-all.
	 */
	struct mutex wait_all_lock;

	/* The open file backing this device; each object holds a reference. */
	struct file *file;
};
137 
/*
 * Single objects are locked using obj->lock.
 *
 * Multiple objects are 'locked' while holding dev->wait_all_lock.
 * In this case, however, individual objects are not locked by holding
 * obj->lock, but by setting obj->dev_locked.
 *
 * This means that in order to lock a single object, the sequence is slightly
 * more complicated than usual. Specifically, it needs to check obj->dev_locked
 * after acquiring obj->lock; if it is set, it needs to drop the lock and
 * acquire dev->wait_all_lock in order to serialize against the multi-object
 * operation.
 */
150 
151 static void dev_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
152 {
153 	lockdep_assert_held(&dev->wait_all_lock);
154 	lockdep_assert(obj->dev == dev);
155 	spin_lock(&obj->lock);
156 	/*
157 	 * By setting obj->dev_locked inside obj->lock, it is ensured that
158 	 * anyone holding obj->lock must see the value.
159 	 */
160 	obj->dev_locked = 1;
161 	spin_unlock(&obj->lock);
162 }
163 
164 static void dev_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
165 {
166 	lockdep_assert_held(&dev->wait_all_lock);
167 	lockdep_assert(obj->dev == dev);
168 	spin_lock(&obj->lock);
169 	obj->dev_locked = 0;
170 	spin_unlock(&obj->lock);
171 }
172 
173 static void obj_lock(struct ntsync_obj *obj)
174 {
175 	struct ntsync_device *dev = obj->dev;
176 
177 	for (;;) {
178 		spin_lock(&obj->lock);
179 		if (likely(!obj->dev_locked))
180 			break;
181 
182 		spin_unlock(&obj->lock);
183 		mutex_lock(&dev->wait_all_lock);
184 		spin_lock(&obj->lock);
185 		/*
186 		 * obj->dev_locked should be set and released under the same
187 		 * wait_all_lock section, since we now own this lock, it should
188 		 * be clear.
189 		 */
190 		lockdep_assert(!obj->dev_locked);
191 		spin_unlock(&obj->lock);
192 		mutex_unlock(&dev->wait_all_lock);
193 	}
194 }
195 
/* Release obj->lock taken by obj_lock(). */
static void obj_unlock(struct ntsync_obj *obj)
{
	spin_unlock(&obj->lock);
}
200 
201 static bool ntsync_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
202 {
203 	bool all;
204 
205 	obj_lock(obj);
206 	all = atomic_read(&obj->all_hint);
207 	if (unlikely(all)) {
208 		obj_unlock(obj);
209 		mutex_lock(&dev->wait_all_lock);
210 		dev_lock_obj(dev, obj);
211 	}
212 
213 	return all;
214 }
215 
216 static void ntsync_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj, bool all)
217 {
218 	if (all) {
219 		dev_unlock_obj(dev, obj);
220 		mutex_unlock(&dev->wait_all_lock);
221 	} else {
222 		obj_unlock(obj);
223 	}
224 }
225 
/*
 * Assert that @obj is locked in one of the two supported ways: either
 * obj->lock is held, or dev->wait_all_lock is held and obj->dev_locked is
 * set.
 */
#define ntsync_assert_held(obj) \
	lockdep_assert((lockdep_is_held(&(obj)->lock) != LOCK_STATE_NOT_HELD) || \
		       ((lockdep_is_held(&(obj)->dev->wait_all_lock) != LOCK_STATE_NOT_HELD) && \
			(obj)->dev_locked))
230 
231 static bool is_signaled(struct ntsync_obj *obj, __u32 owner)
232 {
233 	ntsync_assert_held(obj);
234 
235 	switch (obj->type) {
236 	case NTSYNC_TYPE_SEM:
237 		return !!obj->u.sem.count;
238 	case NTSYNC_TYPE_MUTEX:
239 		if (obj->u.mutex.owner && obj->u.mutex.owner != owner)
240 			return false;
241 		return obj->u.mutex.count < UINT_MAX;
242 	case NTSYNC_TYPE_EVENT:
243 		return obj->u.event.signaled;
244 	}
245 
246 	WARN(1, "bad object type %#x\n", obj->type);
247 	return false;
248 }
249 
250 /*
251  * "locked_obj" is an optional pointer to an object which is already locked and
252  * should not be locked again. This is necessary so that changing an object's
253  * state and waking it can be a single atomic operation.
254  */
255 static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q,
256 			 struct ntsync_obj *locked_obj)
257 {
258 	__u32 count = q->count;
259 	bool can_wake = true;
260 	int signaled = -1;
261 	__u32 i;
262 
263 	lockdep_assert_held(&dev->wait_all_lock);
264 	if (locked_obj)
265 		lockdep_assert(locked_obj->dev_locked);
266 
267 	for (i = 0; i < count; i++) {
268 		if (q->entries[i].obj != locked_obj)
269 			dev_lock_obj(dev, q->entries[i].obj);
270 	}
271 
272 	for (i = 0; i < count; i++) {
273 		if (!is_signaled(q->entries[i].obj, q->owner)) {
274 			can_wake = false;
275 			break;
276 		}
277 	}
278 
279 	if (can_wake && atomic_try_cmpxchg(&q->signaled, &signaled, 0)) {
280 		for (i = 0; i < count; i++) {
281 			struct ntsync_obj *obj = q->entries[i].obj;
282 
283 			switch (obj->type) {
284 			case NTSYNC_TYPE_SEM:
285 				obj->u.sem.count--;
286 				break;
287 			case NTSYNC_TYPE_MUTEX:
288 				if (obj->u.mutex.ownerdead)
289 					q->ownerdead = true;
290 				obj->u.mutex.ownerdead = false;
291 				obj->u.mutex.count++;
292 				obj->u.mutex.owner = q->owner;
293 				break;
294 			case NTSYNC_TYPE_EVENT:
295 				if (!obj->u.event.manual)
296 					obj->u.event.signaled = false;
297 				break;
298 			}
299 		}
300 		wake_up_process(q->task);
301 	}
302 
303 	for (i = 0; i < count; i++) {
304 		if (q->entries[i].obj != locked_obj)
305 			dev_unlock_obj(dev, q->entries[i].obj);
306 	}
307 }
308 
309 static void try_wake_all_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
310 {
311 	struct ntsync_q_entry *entry;
312 
313 	lockdep_assert_held(&dev->wait_all_lock);
314 	lockdep_assert(obj->dev_locked);
315 
316 	list_for_each_entry(entry, &obj->all_waiters, node)
317 		try_wake_all(dev, entry->q, obj);
318 }
319 
320 static void try_wake_any_sem(struct ntsync_obj *sem)
321 {
322 	struct ntsync_q_entry *entry;
323 
324 	ntsync_assert_held(sem);
325 	lockdep_assert(sem->type == NTSYNC_TYPE_SEM);
326 
327 	list_for_each_entry(entry, &sem->any_waiters, node) {
328 		struct ntsync_q *q = entry->q;
329 		int signaled = -1;
330 
331 		if (!sem->u.sem.count)
332 			break;
333 
334 		if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
335 			sem->u.sem.count--;
336 			wake_up_process(q->task);
337 		}
338 	}
339 }
340 
341 static void try_wake_any_mutex(struct ntsync_obj *mutex)
342 {
343 	struct ntsync_q_entry *entry;
344 
345 	ntsync_assert_held(mutex);
346 	lockdep_assert(mutex->type == NTSYNC_TYPE_MUTEX);
347 
348 	list_for_each_entry(entry, &mutex->any_waiters, node) {
349 		struct ntsync_q *q = entry->q;
350 		int signaled = -1;
351 
352 		if (mutex->u.mutex.count == UINT_MAX)
353 			break;
354 		if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner)
355 			continue;
356 
357 		if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
358 			if (mutex->u.mutex.ownerdead)
359 				q->ownerdead = true;
360 			mutex->u.mutex.ownerdead = false;
361 			mutex->u.mutex.count++;
362 			mutex->u.mutex.owner = q->owner;
363 			wake_up_process(q->task);
364 		}
365 	}
366 }
367 
368 static void try_wake_any_event(struct ntsync_obj *event)
369 {
370 	struct ntsync_q_entry *entry;
371 
372 	ntsync_assert_held(event);
373 	lockdep_assert(event->type == NTSYNC_TYPE_EVENT);
374 
375 	list_for_each_entry(entry, &event->any_waiters, node) {
376 		struct ntsync_q *q = entry->q;
377 		int signaled = -1;
378 
379 		if (!event->u.event.signaled)
380 			break;
381 
382 		if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
383 			if (!event->u.event.manual)
384 				event->u.event.signaled = false;
385 			wake_up_process(q->task);
386 		}
387 	}
388 }
389 
390 /*
391  * Actually change the semaphore state, returning -EOVERFLOW if it is made
392  * invalid.
393  */
394 static int release_sem_state(struct ntsync_obj *sem, __u32 count)
395 {
396 	__u32 sum;
397 
398 	ntsync_assert_held(sem);
399 
400 	if (check_add_overflow(sem->u.sem.count, count, &sum) ||
401 	    sum > sem->u.sem.max)
402 		return -EOVERFLOW;
403 
404 	sem->u.sem.count = sum;
405 	return 0;
406 }
407 
408 static int ntsync_sem_release(struct ntsync_obj *sem, void __user *argp)
409 {
410 	struct ntsync_device *dev = sem->dev;
411 	__u32 __user *user_args = argp;
412 	__u32 prev_count;
413 	__u32 args;
414 	bool all;
415 	int ret;
416 
417 	if (copy_from_user(&args, argp, sizeof(args)))
418 		return -EFAULT;
419 
420 	if (sem->type != NTSYNC_TYPE_SEM)
421 		return -EINVAL;
422 
423 	all = ntsync_lock_obj(dev, sem);
424 
425 	prev_count = sem->u.sem.count;
426 	ret = release_sem_state(sem, args);
427 	if (!ret) {
428 		if (all)
429 			try_wake_all_obj(dev, sem);
430 		try_wake_any_sem(sem);
431 	}
432 
433 	ntsync_unlock_obj(dev, sem, all);
434 
435 	if (!ret && put_user(prev_count, user_args))
436 		ret = -EFAULT;
437 
438 	return ret;
439 }
440 
441 /*
442  * Actually change the mutex state, returning -EPERM if not the owner.
443  */
444 static int unlock_mutex_state(struct ntsync_obj *mutex,
445 			      const struct ntsync_mutex_args *args)
446 {
447 	ntsync_assert_held(mutex);
448 
449 	if (mutex->u.mutex.owner != args->owner)
450 		return -EPERM;
451 
452 	if (!--mutex->u.mutex.count)
453 		mutex->u.mutex.owner = 0;
454 	return 0;
455 }
456 
457 static int ntsync_mutex_unlock(struct ntsync_obj *mutex, void __user *argp)
458 {
459 	struct ntsync_mutex_args __user *user_args = argp;
460 	struct ntsync_device *dev = mutex->dev;
461 	struct ntsync_mutex_args args;
462 	__u32 prev_count;
463 	bool all;
464 	int ret;
465 
466 	if (copy_from_user(&args, argp, sizeof(args)))
467 		return -EFAULT;
468 	if (!args.owner)
469 		return -EINVAL;
470 
471 	if (mutex->type != NTSYNC_TYPE_MUTEX)
472 		return -EINVAL;
473 
474 	all = ntsync_lock_obj(dev, mutex);
475 
476 	prev_count = mutex->u.mutex.count;
477 	ret = unlock_mutex_state(mutex, &args);
478 	if (!ret) {
479 		if (all)
480 			try_wake_all_obj(dev, mutex);
481 		try_wake_any_mutex(mutex);
482 	}
483 
484 	ntsync_unlock_obj(dev, mutex, all);
485 
486 	if (!ret && put_user(prev_count, &user_args->count))
487 		ret = -EFAULT;
488 
489 	return ret;
490 }
491 
492 /*
493  * Actually change the mutex state to mark its owner as dead,
494  * returning -EPERM if not the owner.
495  */
496 static int kill_mutex_state(struct ntsync_obj *mutex, __u32 owner)
497 {
498 	ntsync_assert_held(mutex);
499 
500 	if (mutex->u.mutex.owner != owner)
501 		return -EPERM;
502 
503 	mutex->u.mutex.ownerdead = true;
504 	mutex->u.mutex.owner = 0;
505 	mutex->u.mutex.count = 0;
506 	return 0;
507 }
508 
509 static int ntsync_mutex_kill(struct ntsync_obj *mutex, void __user *argp)
510 {
511 	struct ntsync_device *dev = mutex->dev;
512 	__u32 owner;
513 	bool all;
514 	int ret;
515 
516 	if (get_user(owner, (__u32 __user *)argp))
517 		return -EFAULT;
518 	if (!owner)
519 		return -EINVAL;
520 
521 	if (mutex->type != NTSYNC_TYPE_MUTEX)
522 		return -EINVAL;
523 
524 	all = ntsync_lock_obj(dev, mutex);
525 
526 	ret = kill_mutex_state(mutex, owner);
527 	if (!ret) {
528 		if (all)
529 			try_wake_all_obj(dev, mutex);
530 		try_wake_any_mutex(mutex);
531 	}
532 
533 	ntsync_unlock_obj(dev, mutex, all);
534 
535 	return ret;
536 }
537 
538 static int ntsync_event_set(struct ntsync_obj *event, void __user *argp, bool pulse)
539 {
540 	struct ntsync_device *dev = event->dev;
541 	__u32 prev_state;
542 	bool all;
543 
544 	if (event->type != NTSYNC_TYPE_EVENT)
545 		return -EINVAL;
546 
547 	all = ntsync_lock_obj(dev, event);
548 
549 	prev_state = event->u.event.signaled;
550 	event->u.event.signaled = true;
551 	if (all)
552 		try_wake_all_obj(dev, event);
553 	try_wake_any_event(event);
554 	if (pulse)
555 		event->u.event.signaled = false;
556 
557 	ntsync_unlock_obj(dev, event, all);
558 
559 	if (put_user(prev_state, (__u32 __user *)argp))
560 		return -EFAULT;
561 
562 	return 0;
563 }
564 
565 static int ntsync_event_reset(struct ntsync_obj *event, void __user *argp)
566 {
567 	struct ntsync_device *dev = event->dev;
568 	__u32 prev_state;
569 	bool all;
570 
571 	if (event->type != NTSYNC_TYPE_EVENT)
572 		return -EINVAL;
573 
574 	all = ntsync_lock_obj(dev, event);
575 
576 	prev_state = event->u.event.signaled;
577 	event->u.event.signaled = false;
578 
579 	ntsync_unlock_obj(dev, event, all);
580 
581 	if (put_user(prev_state, (__u32 __user *)argp))
582 		return -EFAULT;
583 
584 	return 0;
585 }
586 
587 static int ntsync_sem_read(struct ntsync_obj *sem, void __user *argp)
588 {
589 	struct ntsync_sem_args __user *user_args = argp;
590 	struct ntsync_device *dev = sem->dev;
591 	struct ntsync_sem_args args;
592 	bool all;
593 
594 	if (sem->type != NTSYNC_TYPE_SEM)
595 		return -EINVAL;
596 
597 	all = ntsync_lock_obj(dev, sem);
598 
599 	args.count = sem->u.sem.count;
600 	args.max = sem->u.sem.max;
601 
602 	ntsync_unlock_obj(dev, sem, all);
603 
604 	if (copy_to_user(user_args, &args, sizeof(args)))
605 		return -EFAULT;
606 	return 0;
607 }
608 
609 static int ntsync_mutex_read(struct ntsync_obj *mutex, void __user *argp)
610 {
611 	struct ntsync_mutex_args __user *user_args = argp;
612 	struct ntsync_device *dev = mutex->dev;
613 	struct ntsync_mutex_args args;
614 	bool all;
615 	int ret;
616 
617 	if (mutex->type != NTSYNC_TYPE_MUTEX)
618 		return -EINVAL;
619 
620 	all = ntsync_lock_obj(dev, mutex);
621 
622 	args.count = mutex->u.mutex.count;
623 	args.owner = mutex->u.mutex.owner;
624 	ret = mutex->u.mutex.ownerdead ? -EOWNERDEAD : 0;
625 
626 	ntsync_unlock_obj(dev, mutex, all);
627 
628 	if (copy_to_user(user_args, &args, sizeof(args)))
629 		return -EFAULT;
630 	return ret;
631 }
632 
633 static int ntsync_event_read(struct ntsync_obj *event, void __user *argp)
634 {
635 	struct ntsync_event_args __user *user_args = argp;
636 	struct ntsync_device *dev = event->dev;
637 	struct ntsync_event_args args;
638 	bool all;
639 
640 	if (event->type != NTSYNC_TYPE_EVENT)
641 		return -EINVAL;
642 
643 	all = ntsync_lock_obj(dev, event);
644 
645 	args.manual = event->u.event.manual;
646 	args.signaled = event->u.event.signaled;
647 
648 	ntsync_unlock_obj(dev, event, all);
649 
650 	if (copy_to_user(user_args, &args, sizeof(args)))
651 		return -EFAULT;
652 	return 0;
653 }
654 
655 static void ntsync_free_obj(struct ntsync_obj *obj)
656 {
657 	fput(obj->dev->file);
658 	kfree(obj);
659 }
660 
661 static int ntsync_obj_release(struct inode *inode, struct file *file)
662 {
663 	ntsync_free_obj(file->private_data);
664 	return 0;
665 }
666 
667 static long ntsync_obj_ioctl(struct file *file, unsigned int cmd,
668 			     unsigned long parm)
669 {
670 	struct ntsync_obj *obj = file->private_data;
671 	void __user *argp = (void __user *)parm;
672 
673 	switch (cmd) {
674 	case NTSYNC_IOC_SEM_RELEASE:
675 		return ntsync_sem_release(obj, argp);
676 	case NTSYNC_IOC_SEM_READ:
677 		return ntsync_sem_read(obj, argp);
678 	case NTSYNC_IOC_MUTEX_UNLOCK:
679 		return ntsync_mutex_unlock(obj, argp);
680 	case NTSYNC_IOC_MUTEX_KILL:
681 		return ntsync_mutex_kill(obj, argp);
682 	case NTSYNC_IOC_MUTEX_READ:
683 		return ntsync_mutex_read(obj, argp);
684 	case NTSYNC_IOC_EVENT_SET:
685 		return ntsync_event_set(obj, argp, false);
686 	case NTSYNC_IOC_EVENT_RESET:
687 		return ntsync_event_reset(obj, argp);
688 	case NTSYNC_IOC_EVENT_PULSE:
689 		return ntsync_event_set(obj, argp, true);
690 	case NTSYNC_IOC_EVENT_READ:
691 		return ntsync_event_read(obj, argp);
692 	default:
693 		return -ENOIOCTLCMD;
694 	}
695 }
696 
697 static const struct file_operations ntsync_obj_fops = {
698 	.owner		= THIS_MODULE,
699 	.release	= ntsync_obj_release,
700 	.unlocked_ioctl	= ntsync_obj_ioctl,
701 	.compat_ioctl	= compat_ptr_ioctl,
702 };
703 
704 static struct ntsync_obj *ntsync_alloc_obj(struct ntsync_device *dev,
705 					   enum ntsync_type type)
706 {
707 	struct ntsync_obj *obj;
708 
709 	obj = kzalloc_obj(*obj);
710 	if (!obj)
711 		return NULL;
712 	obj->type = type;
713 	obj->dev = dev;
714 	get_file(dev->file);
715 	spin_lock_init(&obj->lock);
716 	INIT_LIST_HEAD(&obj->any_waiters);
717 	INIT_LIST_HEAD(&obj->all_waiters);
718 	atomic_set(&obj->all_hint, 0);
719 
720 	return obj;
721 }
722 
/*
 * Create an anonymous-inode file for @obj, using ntsync_obj_fops with @obj
 * as private data, and hand it out through a new O_CLOEXEC descriptor.
 *
 * Returns the value of fd_publish() on success (callers treat a
 * non-negative result as the new file descriptor), or the error recorded by
 * FD_PREPARE(). On failure the caller still owns @obj.
 */
static int ntsync_obj_get_fd(struct ntsync_obj *obj)
{
	FD_PREPARE(fdf, O_CLOEXEC,
		   anon_inode_getfile("ntsync", &ntsync_obj_fops, obj, O_RDWR));
	if (fdf.err)
		return fdf.err;
	/* Remembered so that put_obj() can drop references on this file. */
	obj->file = fd_prepare_file(fdf);
	/* NOTE(review): presumably installs the file into the fd table - confirm. */
	return fd_publish(fdf);
}
732 
733 static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp)
734 {
735 	struct ntsync_sem_args args;
736 	struct ntsync_obj *sem;
737 	int fd;
738 
739 	if (copy_from_user(&args, argp, sizeof(args)))
740 		return -EFAULT;
741 
742 	if (args.count > args.max)
743 		return -EINVAL;
744 
745 	sem = ntsync_alloc_obj(dev, NTSYNC_TYPE_SEM);
746 	if (!sem)
747 		return -ENOMEM;
748 	sem->u.sem.count = args.count;
749 	sem->u.sem.max = args.max;
750 	fd = ntsync_obj_get_fd(sem);
751 	if (fd < 0)
752 		ntsync_free_obj(sem);
753 
754 	return fd;
755 }
756 
757 static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp)
758 {
759 	struct ntsync_mutex_args args;
760 	struct ntsync_obj *mutex;
761 	int fd;
762 
763 	if (copy_from_user(&args, argp, sizeof(args)))
764 		return -EFAULT;
765 
766 	if (!args.owner != !args.count)
767 		return -EINVAL;
768 
769 	mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX);
770 	if (!mutex)
771 		return -ENOMEM;
772 	mutex->u.mutex.count = args.count;
773 	mutex->u.mutex.owner = args.owner;
774 	fd = ntsync_obj_get_fd(mutex);
775 	if (fd < 0)
776 		ntsync_free_obj(mutex);
777 
778 	return fd;
779 }
780 
781 static int ntsync_create_event(struct ntsync_device *dev, void __user *argp)
782 {
783 	struct ntsync_event_args args;
784 	struct ntsync_obj *event;
785 	int fd;
786 
787 	if (copy_from_user(&args, argp, sizeof(args)))
788 		return -EFAULT;
789 
790 	event = ntsync_alloc_obj(dev, NTSYNC_TYPE_EVENT);
791 	if (!event)
792 		return -ENOMEM;
793 	event->u.event.manual = args.manual;
794 	event->u.event.signaled = args.signaled;
795 	fd = ntsync_obj_get_fd(event);
796 	if (fd < 0)
797 		ntsync_free_obj(event);
798 
799 	return fd;
800 }
801 
802 static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd)
803 {
804 	struct file *file = fget(fd);
805 	struct ntsync_obj *obj;
806 
807 	if (!file)
808 		return NULL;
809 
810 	if (file->f_op != &ntsync_obj_fops) {
811 		fput(file);
812 		return NULL;
813 	}
814 
815 	obj = file->private_data;
816 	if (obj->dev != dev) {
817 		fput(file);
818 		return NULL;
819 	}
820 
821 	return obj;
822 }
823 
824 static void put_obj(struct ntsync_obj *obj)
825 {
826 	fput(obj->file);
827 }
828 
829 static int ntsync_schedule(const struct ntsync_q *q, const struct ntsync_wait_args *args)
830 {
831 	ktime_t timeout = ns_to_ktime(args->timeout);
832 	clockid_t clock = CLOCK_MONOTONIC;
833 	ktime_t *timeout_ptr;
834 	int ret = 0;
835 
836 	timeout_ptr = (args->timeout == U64_MAX ? NULL : &timeout);
837 
838 	if (args->flags & NTSYNC_WAIT_REALTIME)
839 		clock = CLOCK_REALTIME;
840 	else
841 		timeout = timens_ktime_to_host(clock, timeout);
842 
843 	do {
844 		if (signal_pending(current)) {
845 			ret = -ERESTARTSYS;
846 			break;
847 		}
848 
849 		set_current_state(TASK_INTERRUPTIBLE);
850 		if (atomic_read(&q->signaled) != -1) {
851 			ret = 0;
852 			break;
853 		}
854 		ret = schedule_hrtimeout_range_clock(timeout_ptr, 0, HRTIMER_MODE_ABS, clock);
855 	} while (ret < 0);
856 	__set_current_state(TASK_RUNNING);
857 
858 	return ret;
859 }
860 
861 /*
862  * Allocate and initialize the ntsync_q structure, but do not queue us yet.
863  */
864 static int setup_wait(struct ntsync_device *dev,
865 		      const struct ntsync_wait_args *args, bool all,
866 		      struct ntsync_q **ret_q)
867 {
868 	int fds[NTSYNC_MAX_WAIT_COUNT + 1];
869 	const __u32 count = args->count;
870 	size_t size = array_size(count, sizeof(fds[0]));
871 	struct ntsync_q *q;
872 	__u32 total_count;
873 	__u32 i, j;
874 
875 	if (args->pad || (args->flags & ~NTSYNC_WAIT_REALTIME))
876 		return -EINVAL;
877 
878 	if (size >= sizeof(fds))
879 		return -EINVAL;
880 
881 	total_count = count;
882 	if (args->alert)
883 		total_count++;
884 
885 	if (copy_from_user(fds, u64_to_user_ptr(args->objs), size))
886 		return -EFAULT;
887 	if (args->alert)
888 		fds[count] = args->alert;
889 
890 	q = kmalloc_flex(*q, entries, total_count);
891 	if (!q)
892 		return -ENOMEM;
893 	q->task = current;
894 	q->owner = args->owner;
895 	atomic_set(&q->signaled, -1);
896 	q->all = all;
897 	q->ownerdead = false;
898 	q->count = count;
899 
900 	for (i = 0; i < total_count; i++) {
901 		struct ntsync_q_entry *entry = &q->entries[i];
902 		struct ntsync_obj *obj = get_obj(dev, fds[i]);
903 
904 		if (!obj)
905 			goto err;
906 
907 		if (all) {
908 			/* Check that the objects are all distinct. */
909 			for (j = 0; j < i; j++) {
910 				if (obj == q->entries[j].obj) {
911 					put_obj(obj);
912 					goto err;
913 				}
914 			}
915 		}
916 
917 		entry->obj = obj;
918 		entry->q = q;
919 		entry->index = i;
920 	}
921 
922 	*ret_q = q;
923 	return 0;
924 
925 err:
926 	for (j = 0; j < i; j++)
927 		put_obj(q->entries[j].obj);
928 	kfree(q);
929 	return -EINVAL;
930 }
931 
932 static void try_wake_any_obj(struct ntsync_obj *obj)
933 {
934 	switch (obj->type) {
935 	case NTSYNC_TYPE_SEM:
936 		try_wake_any_sem(obj);
937 		break;
938 	case NTSYNC_TYPE_MUTEX:
939 		try_wake_any_mutex(obj);
940 		break;
941 	case NTSYNC_TYPE_EVENT:
942 		try_wake_any_event(obj);
943 		break;
944 	}
945 }
946 
947 static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
948 {
949 	struct ntsync_wait_args args;
950 	__u32 i, total_count;
951 	struct ntsync_q *q;
952 	int signaled;
953 	bool all;
954 	int ret;
955 
956 	if (copy_from_user(&args, argp, sizeof(args)))
957 		return -EFAULT;
958 
959 	ret = setup_wait(dev, &args, false, &q);
960 	if (ret < 0)
961 		return ret;
962 
963 	total_count = args.count;
964 	if (args.alert)
965 		total_count++;
966 
967 	/* queue ourselves */
968 
969 	for (i = 0; i < total_count; i++) {
970 		struct ntsync_q_entry *entry = &q->entries[i];
971 		struct ntsync_obj *obj = entry->obj;
972 
973 		all = ntsync_lock_obj(dev, obj);
974 		list_add_tail(&entry->node, &obj->any_waiters);
975 		ntsync_unlock_obj(dev, obj, all);
976 	}
977 
978 	/*
979 	 * Check if we are already signaled.
980 	 *
981 	 * Note that the API requires that normal objects are checked before
982 	 * the alert event. Hence we queue the alert event last, and check
983 	 * objects in order.
984 	 */
985 
986 	for (i = 0; i < total_count; i++) {
987 		struct ntsync_obj *obj = q->entries[i].obj;
988 
989 		if (atomic_read(&q->signaled) != -1)
990 			break;
991 
992 		all = ntsync_lock_obj(dev, obj);
993 		try_wake_any_obj(obj);
994 		ntsync_unlock_obj(dev, obj, all);
995 	}
996 
997 	/* sleep */
998 
999 	ret = ntsync_schedule(q, &args);
1000 
1001 	/* and finally, unqueue */
1002 
1003 	for (i = 0; i < total_count; i++) {
1004 		struct ntsync_q_entry *entry = &q->entries[i];
1005 		struct ntsync_obj *obj = entry->obj;
1006 
1007 		all = ntsync_lock_obj(dev, obj);
1008 		list_del(&entry->node);
1009 		ntsync_unlock_obj(dev, obj, all);
1010 
1011 		put_obj(obj);
1012 	}
1013 
1014 	signaled = atomic_read(&q->signaled);
1015 	if (signaled != -1) {
1016 		struct ntsync_wait_args __user *user_args = argp;
1017 
1018 		/* even if we caught a signal, we need to communicate success */
1019 		ret = q->ownerdead ? -EOWNERDEAD : 0;
1020 
1021 		if (put_user(signaled, &user_args->index))
1022 			ret = -EFAULT;
1023 	} else if (!ret) {
1024 		ret = -ETIMEDOUT;
1025 	}
1026 
1027 	kfree(q);
1028 	return ret;
1029 }
1030 
1031 static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
1032 {
1033 	struct ntsync_wait_args args;
1034 	struct ntsync_q *q;
1035 	int signaled;
1036 	__u32 i;
1037 	int ret;
1038 
1039 	if (copy_from_user(&args, argp, sizeof(args)))
1040 		return -EFAULT;
1041 
1042 	ret = setup_wait(dev, &args, true, &q);
1043 	if (ret < 0)
1044 		return ret;
1045 
1046 	/* queue ourselves */
1047 
1048 	mutex_lock(&dev->wait_all_lock);
1049 
1050 	for (i = 0; i < args.count; i++) {
1051 		struct ntsync_q_entry *entry = &q->entries[i];
1052 		struct ntsync_obj *obj = entry->obj;
1053 
1054 		atomic_inc(&obj->all_hint);
1055 
1056 		/*
1057 		 * obj->all_waiters is protected by dev->wait_all_lock rather
1058 		 * than obj->lock, so there is no need to acquire obj->lock
1059 		 * here.
1060 		 */
1061 		list_add_tail(&entry->node, &obj->all_waiters);
1062 	}
1063 	if (args.alert) {
1064 		struct ntsync_q_entry *entry = &q->entries[args.count];
1065 		struct ntsync_obj *obj = entry->obj;
1066 
1067 		dev_lock_obj(dev, obj);
1068 		list_add_tail(&entry->node, &obj->any_waiters);
1069 		dev_unlock_obj(dev, obj);
1070 	}
1071 
1072 	/* check if we are already signaled */
1073 
1074 	try_wake_all(dev, q, NULL);
1075 
1076 	mutex_unlock(&dev->wait_all_lock);
1077 
1078 	/*
1079 	 * Check if the alert event is signaled, making sure to do so only
1080 	 * after checking if the other objects are signaled.
1081 	 */
1082 
1083 	if (args.alert) {
1084 		struct ntsync_obj *obj = q->entries[args.count].obj;
1085 
1086 		if (atomic_read(&q->signaled) == -1) {
1087 			bool all = ntsync_lock_obj(dev, obj);
1088 			try_wake_any_obj(obj);
1089 			ntsync_unlock_obj(dev, obj, all);
1090 		}
1091 	}
1092 
1093 	/* sleep */
1094 
1095 	ret = ntsync_schedule(q, &args);
1096 
1097 	/* and finally, unqueue */
1098 
1099 	mutex_lock(&dev->wait_all_lock);
1100 
1101 	for (i = 0; i < args.count; i++) {
1102 		struct ntsync_q_entry *entry = &q->entries[i];
1103 		struct ntsync_obj *obj = entry->obj;
1104 
1105 		/*
1106 		 * obj->all_waiters is protected by dev->wait_all_lock rather
1107 		 * than obj->lock, so there is no need to acquire it here.
1108 		 */
1109 		list_del(&entry->node);
1110 
1111 		atomic_dec(&obj->all_hint);
1112 
1113 		put_obj(obj);
1114 	}
1115 
1116 	mutex_unlock(&dev->wait_all_lock);
1117 
1118 	if (args.alert) {
1119 		struct ntsync_q_entry *entry = &q->entries[args.count];
1120 		struct ntsync_obj *obj = entry->obj;
1121 		bool all;
1122 
1123 		all = ntsync_lock_obj(dev, obj);
1124 		list_del(&entry->node);
1125 		ntsync_unlock_obj(dev, obj, all);
1126 
1127 		put_obj(obj);
1128 	}
1129 
1130 	signaled = atomic_read(&q->signaled);
1131 	if (signaled != -1) {
1132 		struct ntsync_wait_args __user *user_args = argp;
1133 
1134 		/* even if we caught a signal, we need to communicate success */
1135 		ret = q->ownerdead ? -EOWNERDEAD : 0;
1136 
1137 		if (put_user(signaled, &user_args->index))
1138 			ret = -EFAULT;
1139 	} else if (!ret) {
1140 		ret = -ETIMEDOUT;
1141 	}
1142 
1143 	kfree(q);
1144 	return ret;
1145 }
1146 
1147 static int ntsync_char_open(struct inode *inode, struct file *file)
1148 {
1149 	struct ntsync_device *dev;
1150 
1151 	dev = kzalloc_obj(*dev);
1152 	if (!dev)
1153 		return -ENOMEM;
1154 
1155 	mutex_init(&dev->wait_all_lock);
1156 
1157 	file->private_data = dev;
1158 	dev->file = file;
1159 	return nonseekable_open(inode, file);
1160 }
1161 
1162 static int ntsync_char_release(struct inode *inode, struct file *file)
1163 {
1164 	struct ntsync_device *dev = file->private_data;
1165 
1166 	kfree(dev);
1167 
1168 	return 0;
1169 }
1170 
1171 static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
1172 			      unsigned long parm)
1173 {
1174 	struct ntsync_device *dev = file->private_data;
1175 	void __user *argp = (void __user *)parm;
1176 
1177 	switch (cmd) {
1178 	case NTSYNC_IOC_CREATE_EVENT:
1179 		return ntsync_create_event(dev, argp);
1180 	case NTSYNC_IOC_CREATE_MUTEX:
1181 		return ntsync_create_mutex(dev, argp);
1182 	case NTSYNC_IOC_CREATE_SEM:
1183 		return ntsync_create_sem(dev, argp);
1184 	case NTSYNC_IOC_WAIT_ALL:
1185 		return ntsync_wait_all(dev, argp);
1186 	case NTSYNC_IOC_WAIT_ANY:
1187 		return ntsync_wait_any(dev, argp);
1188 	default:
1189 		return -ENOIOCTLCMD;
1190 	}
1191 }
1192 
1193 static const struct file_operations ntsync_fops = {
1194 	.owner		= THIS_MODULE,
1195 	.open		= ntsync_char_open,
1196 	.release	= ntsync_char_release,
1197 	.unlocked_ioctl	= ntsync_char_ioctl,
1198 	.compat_ioctl	= compat_ptr_ioctl,
1199 };
1200 
1201 static struct miscdevice ntsync_misc = {
1202 	.minor		= MISC_DYNAMIC_MINOR,
1203 	.name		= NTSYNC_NAME,
1204 	.fops		= &ntsync_fops,
1205 	.mode		= 0666,
1206 };
1207 
1208 module_misc_device(ntsync_misc);
1209 
1210 MODULE_AUTHOR("Elizabeth Figura <zfigura@codeweavers.com>");
1211 MODULE_DESCRIPTION("Kernel driver for NT synchronization primitives");
1212 MODULE_LICENSE("GPL");
1213