xref: /linux/drivers/misc/ntsync.c (revision 2cc699b3c2fe5ea84081059c88c46373cbebe630)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * ntsync.c - Kernel driver for NT synchronization primitives
4  *
5  * Copyright (C) 2024 Elizabeth Figura <zfigura@codeweavers.com>
6  */
7 
8 #include <linux/anon_inodes.h>
9 #include <linux/atomic.h>
10 #include <linux/file.h>
11 #include <linux/fs.h>
12 #include <linux/hrtimer.h>
13 #include <linux/ktime.h>
14 #include <linux/miscdevice.h>
15 #include <linux/module.h>
16 #include <linux/mutex.h>
17 #include <linux/overflow.h>
18 #include <linux/sched.h>
19 #include <linux/sched/signal.h>
20 #include <linux/slab.h>
21 #include <linux/spinlock.h>
22 #include <uapi/linux/ntsync.h>
23 
24 #define NTSYNC_NAME	"ntsync"
25 
26 enum ntsync_type {
27 	NTSYNC_TYPE_SEM,
28 	NTSYNC_TYPE_MUTEX,
29 	NTSYNC_TYPE_EVENT,
30 };
31 
32 /*
33  * Individual synchronization primitives are represented by
34  * struct ntsync_obj, and each primitive is backed by a file.
35  *
36  * The whole namespace is represented by a struct ntsync_device also
37  * backed by a file.
38  *
39  * Both rely on struct file for reference counting. Individual
40  * ntsync_obj objects take a reference to the device when created.
41  * Wait operations take a reference to each object being waited on for
42  * the duration of the wait.
43  */
44 
45 struct ntsync_obj {
46 	spinlock_t lock;
47 	int dev_locked;
48 
49 	enum ntsync_type type;
50 
51 	struct file *file;
52 	struct ntsync_device *dev;
53 
54 	/* The following fields are protected by the object lock. */
55 	union {
56 		struct {
57 			__u32 count;
58 			__u32 max;
59 		} sem;
60 		struct {
61 			__u32 count;
62 			pid_t owner;
63 			bool ownerdead;
64 		} mutex;
65 		struct {
66 			bool manual;
67 			bool signaled;
68 		} event;
69 	} u;
70 
71 	/*
72 	 * any_waiters is protected by the object lock, but all_waiters is
73 	 * protected by the device wait_all_lock.
74 	 */
75 	struct list_head any_waiters;
76 	struct list_head all_waiters;
77 
78 	/*
79 	 * Hint describing how many tasks are queued on this object in a
80 	 * wait-all operation.
81 	 *
82 	 * Any time we do a wake, we may need to wake "all" waiters as well as
83 	 * "any" waiters. In order to atomically wake "all" waiters, we must
84 	 * lock all of the objects, and that means grabbing the wait_all_lock
85 	 * below (and, due to lock ordering rules, before locking this object).
86 	 * However, wait-all is a rare operation, and grabbing the wait-all
87 	 * lock for every wake would create unnecessary contention.
88 	 * Therefore we first check whether all_hint is zero, and, if it is,
89 	 * we skip trying to wake "all" waiters.
90 	 *
91 	 * Since wait requests must originate from user-space threads, we're
92 	 * limited here by PID_MAX_LIMIT, so there's no risk of overflow.
93 	 */
94 	atomic_t all_hint;
95 };
96 
97 struct ntsync_q_entry {
98 	struct list_head node;
99 	struct ntsync_q *q;
100 	struct ntsync_obj *obj;
101 	__u32 index;
102 };
103 
104 struct ntsync_q {
105 	struct task_struct *task;
106 	__u32 owner;
107 
108 	/*
109 	 * Protected via atomic_try_cmpxchg(). Only the thread that wins the
110 	 * compare-and-swap may actually change object states and wake this
111 	 * task.
112 	 */
113 	atomic_t signaled;
114 
115 	bool all;
116 	bool ownerdead;
117 	__u32 count;
118 	struct ntsync_q_entry entries[];
119 };
120 
121 struct ntsync_device {
122 	/*
123 	 * Wait-all operations must atomically grab all objects, and be totally
124 	 * ordered with respect to each other and wait-any operations.
125 	 * If one thread is trying to acquire several objects, another thread
126 	 * cannot touch any of those objects at the same time.
127 	 *
128 	 * This device-wide lock is used to serialize wait-for-all
129 	 * operations, and operations on an object that is involved in a
130 	 * wait-for-all.
131 	 */
132 	struct mutex wait_all_lock;
133 
134 	struct file *file;
135 };
136 
137 /*
138  * Single objects are locked using obj->lock.
139  *
140  * Multiple objects are 'locked' while holding dev->wait_all_lock.
141  * In this case however, individual objects are not locked by holding
142  * obj->lock, but by setting obj->dev_locked.
143  *
144  * This means that in order to lock a single object, the sequence is slightly
145  * more complicated than usual. Specifically, it needs to check obj->dev_locked
146  * after acquiring obj->lock; if it is set, it must drop the lock and acquire
147  * dev->wait_all_lock in order to serialize against the multi-object operation.
148  */
149 
150 static void dev_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
151 {
152 	lockdep_assert_held(&dev->wait_all_lock);
153 	lockdep_assert(obj->dev == dev);
154 	spin_lock(&obj->lock);
155 	/*
156 	 * By setting obj->dev_locked inside obj->lock, we ensure that anyone
157 	 * who subsequently takes obj->lock will see the value.
158 	 */
159 	obj->dev_locked = 1;
160 	spin_unlock(&obj->lock);
161 }
162 
163 static void dev_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
164 {
165 	lockdep_assert_held(&dev->wait_all_lock);
166 	lockdep_assert(obj->dev == dev);
167 	spin_lock(&obj->lock);
168 	obj->dev_locked = 0;
169 	spin_unlock(&obj->lock);
170 }
171 
172 static void obj_lock(struct ntsync_obj *obj)
173 {
174 	struct ntsync_device *dev = obj->dev;
175 
176 	for (;;) {
177 		spin_lock(&obj->lock);
178 		if (likely(!obj->dev_locked))
179 			break;
180 
181 		spin_unlock(&obj->lock);
182 		mutex_lock(&dev->wait_all_lock);
183 		spin_lock(&obj->lock);
184 		/*
185 		 * obj->dev_locked is set and cleared within the same
186 		 * wait_all_lock section; since we now own that lock, it must
187 		 * already be clear.
188 		 */
189 		lockdep_assert(!obj->dev_locked);
190 		spin_unlock(&obj->lock);
191 		mutex_unlock(&dev->wait_all_lock);
192 	}
193 }
194 
195 static void obj_unlock(struct ntsync_obj *obj)
196 {
197 	spin_unlock(&obj->lock);
198 }
199 
200 static bool ntsync_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
201 {
202 	bool all;
203 
204 	obj_lock(obj);
205 	all = atomic_read(&obj->all_hint);
206 	if (unlikely(all)) {
207 		obj_unlock(obj);
208 		mutex_lock(&dev->wait_all_lock);
209 		dev_lock_obj(dev, obj);
210 	}
211 
212 	return all;
213 }
214 
215 static void ntsync_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj, bool all)
216 {
217 	if (all) {
218 		dev_unlock_obj(dev, obj);
219 		mutex_unlock(&dev->wait_all_lock);
220 	} else {
221 		obj_unlock(obj);
222 	}
223 }
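/*
 * Typical usage of the helpers above, as seen in the ioctl handlers below
 * (a sketch of the common pattern, not additional functionality):
 *
 *	all = ntsync_lock_obj(dev, obj);
 *	... read or modify obj->u ...
 *	if (all)
 *		try_wake_all_obj(dev, obj);
 *	try_wake_any_sem(obj);		(or the mutex/event variant)
 *	ntsync_unlock_obj(dev, obj, all);
 */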
224 
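/*
 * Assert that the object is locked, either directly (obj->lock held) or
 * indirectly (dev->wait_all_lock held and obj->dev_locked set).
 */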
225 #define ntsync_assert_held(obj) \
226 	lockdep_assert((lockdep_is_held(&(obj)->lock) != LOCK_STATE_NOT_HELD) || \
227 		       ((lockdep_is_held(&(obj)->dev->wait_all_lock) != LOCK_STATE_NOT_HELD) && \
228 			(obj)->dev_locked))
229 
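/*
 * Check whether "owner" could consume the object in its current state: a
 * semaphore with a nonzero count, a mutex that is unowned or already owned
 * by "owner" and not at its maximum recursion count, or a signaled event.
 */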
230 static bool is_signaled(struct ntsync_obj *obj, __u32 owner)
231 {
232 	ntsync_assert_held(obj);
233 
234 	switch (obj->type) {
235 	case NTSYNC_TYPE_SEM:
236 		return !!obj->u.sem.count;
237 	case NTSYNC_TYPE_MUTEX:
238 		if (obj->u.mutex.owner && obj->u.mutex.owner != owner)
239 			return false;
240 		return obj->u.mutex.count < UINT_MAX;
241 	case NTSYNC_TYPE_EVENT:
242 		return obj->u.event.signaled;
243 	}
244 
245 	WARN(1, "bad object type %#x\n", obj->type);
246 	return false;
247 }
248 
249 /*
250  * "locked_obj" is an optional pointer to an object which is already locked and
251  * should not be locked again. This is necessary so that changing an object's
252  * state and waking it can be a single atomic operation.
253  */
254 static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q,
255 			 struct ntsync_obj *locked_obj)
256 {
257 	__u32 count = q->count;
258 	bool can_wake = true;
259 	int signaled = -1;
260 	__u32 i;
261 
262 	lockdep_assert_held(&dev->wait_all_lock);
263 	if (locked_obj)
264 		lockdep_assert(locked_obj->dev_locked);
265 
266 	for (i = 0; i < count; i++) {
267 		if (q->entries[i].obj != locked_obj)
268 			dev_lock_obj(dev, q->entries[i].obj);
269 	}
270 
271 	for (i = 0; i < count; i++) {
272 		if (!is_signaled(q->entries[i].obj, q->owner)) {
273 			can_wake = false;
274 			break;
275 		}
276 	}
277 
278 	if (can_wake && atomic_try_cmpxchg(&q->signaled, &signaled, 0)) {
279 		for (i = 0; i < count; i++) {
280 			struct ntsync_obj *obj = q->entries[i].obj;
281 
282 			switch (obj->type) {
283 			case NTSYNC_TYPE_SEM:
284 				obj->u.sem.count--;
285 				break;
286 			case NTSYNC_TYPE_MUTEX:
287 				if (obj->u.mutex.ownerdead)
288 					q->ownerdead = true;
289 				obj->u.mutex.ownerdead = false;
290 				obj->u.mutex.count++;
291 				obj->u.mutex.owner = q->owner;
292 				break;
293 			case NTSYNC_TYPE_EVENT:
294 				if (!obj->u.event.manual)
295 					obj->u.event.signaled = false;
296 				break;
297 			}
298 		}
299 		wake_up_process(q->task);
300 	}
301 
302 	for (i = 0; i < count; i++) {
303 		if (q->entries[i].obj != locked_obj)
304 			dev_unlock_obj(dev, q->entries[i].obj);
305 	}
306 }
307 
308 static void try_wake_all_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
309 {
310 	struct ntsync_q_entry *entry;
311 
312 	lockdep_assert_held(&dev->wait_all_lock);
313 	lockdep_assert(obj->dev_locked);
314 
315 	list_for_each_entry(entry, &obj->all_waiters, node)
316 		try_wake_all(dev, entry->q, obj);
317 }
318 
319 static void try_wake_any_sem(struct ntsync_obj *sem)
320 {
321 	struct ntsync_q_entry *entry;
322 
323 	ntsync_assert_held(sem);
324 	lockdep_assert(sem->type == NTSYNC_TYPE_SEM);
325 
326 	list_for_each_entry(entry, &sem->any_waiters, node) {
327 		struct ntsync_q *q = entry->q;
328 		int signaled = -1;
329 
330 		if (!sem->u.sem.count)
331 			break;
332 
333 		if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
334 			sem->u.sem.count--;
335 			wake_up_process(q->task);
336 		}
337 	}
338 }
339 
340 static void try_wake_any_mutex(struct ntsync_obj *mutex)
341 {
342 	struct ntsync_q_entry *entry;
343 
344 	ntsync_assert_held(mutex);
345 	lockdep_assert(mutex->type == NTSYNC_TYPE_MUTEX);
346 
347 	list_for_each_entry(entry, &mutex->any_waiters, node) {
348 		struct ntsync_q *q = entry->q;
349 		int signaled = -1;
350 
351 		if (mutex->u.mutex.count == UINT_MAX)
352 			break;
353 		if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner)
354 			continue;
355 
356 		if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
357 			if (mutex->u.mutex.ownerdead)
358 				q->ownerdead = true;
359 			mutex->u.mutex.ownerdead = false;
360 			mutex->u.mutex.count++;
361 			mutex->u.mutex.owner = q->owner;
362 			wake_up_process(q->task);
363 		}
364 	}
365 }
366 
367 static void try_wake_any_event(struct ntsync_obj *event)
368 {
369 	struct ntsync_q_entry *entry;
370 
371 	ntsync_assert_held(event);
372 	lockdep_assert(event->type == NTSYNC_TYPE_EVENT);
373 
374 	list_for_each_entry(entry, &event->any_waiters, node) {
375 		struct ntsync_q *q = entry->q;
376 		int signaled = -1;
377 
378 		if (!event->u.event.signaled)
379 			break;
380 
381 		if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
382 			if (!event->u.event.manual)
383 				event->u.event.signaled = false;
384 			wake_up_process(q->task);
385 		}
386 	}
387 }
388 
389 /*
390  * Actually change the semaphore state, returning -EOVERFLOW if it is made
391  * invalid.
392  */
393 static int release_sem_state(struct ntsync_obj *sem, __u32 count)
394 {
395 	__u32 sum;
396 
397 	ntsync_assert_held(sem);
398 
399 	if (check_add_overflow(sem->u.sem.count, count, &sum) ||
400 	    sum > sem->u.sem.max)
401 		return -EOVERFLOW;
402 
403 	sem->u.sem.count = sum;
404 	return 0;
405 }
406 
407 static int ntsync_sem_release(struct ntsync_obj *sem, void __user *argp)
408 {
409 	struct ntsync_device *dev = sem->dev;
410 	__u32 __user *user_args = argp;
411 	__u32 prev_count;
412 	__u32 args;
413 	bool all;
414 	int ret;
415 
416 	if (copy_from_user(&args, argp, sizeof(args)))
417 		return -EFAULT;
418 
419 	if (sem->type != NTSYNC_TYPE_SEM)
420 		return -EINVAL;
421 
422 	all = ntsync_lock_obj(dev, sem);
423 
424 	prev_count = sem->u.sem.count;
425 	ret = release_sem_state(sem, args);
426 	if (!ret) {
427 		if (all)
428 			try_wake_all_obj(dev, sem);
429 		try_wake_any_sem(sem);
430 	}
431 
432 	ntsync_unlock_obj(dev, sem, all);
433 
434 	if (!ret && put_user(prev_count, user_args))
435 		ret = -EFAULT;
436 
437 	return ret;
438 }
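/*
 * Illustrative user-space call for NTSYNC_IOC_SEM_RELEASE (a sketch, not part
 * of the driver; error handling omitted). "sem_fd" is assumed to be a file
 * descriptor returned by NTSYNC_IOC_CREATE_SEM:
 *
 *	__u32 count = 1;
 *	ret = ioctl(sem_fd, NTSYNC_IOC_SEM_RELEASE, &count);
 *	(on success, "count" now holds the pre-release semaphore count)
 */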
439 
440 /*
441  * Actually change the mutex state, returning -EPERM if not the owner.
442  */
443 static int unlock_mutex_state(struct ntsync_obj *mutex,
444 			      const struct ntsync_mutex_args *args)
445 {
446 	ntsync_assert_held(mutex);
447 
448 	if (mutex->u.mutex.owner != args->owner)
449 		return -EPERM;
450 
451 	if (!--mutex->u.mutex.count)
452 		mutex->u.mutex.owner = 0;
453 	return 0;
454 }
455 
456 static int ntsync_mutex_unlock(struct ntsync_obj *mutex, void __user *argp)
457 {
458 	struct ntsync_mutex_args __user *user_args = argp;
459 	struct ntsync_device *dev = mutex->dev;
460 	struct ntsync_mutex_args args;
461 	__u32 prev_count;
462 	bool all;
463 	int ret;
464 
465 	if (copy_from_user(&args, argp, sizeof(args)))
466 		return -EFAULT;
467 	if (!args.owner)
468 		return -EINVAL;
469 
470 	if (mutex->type != NTSYNC_TYPE_MUTEX)
471 		return -EINVAL;
472 
473 	all = ntsync_lock_obj(dev, mutex);
474 
475 	prev_count = mutex->u.mutex.count;
476 	ret = unlock_mutex_state(mutex, &args);
477 	if (!ret) {
478 		if (all)
479 			try_wake_all_obj(dev, mutex);
480 		try_wake_any_mutex(mutex);
481 	}
482 
483 	ntsync_unlock_obj(dev, mutex, all);
484 
485 	if (!ret && put_user(prev_count, &user_args->count))
486 		ret = -EFAULT;
487 
488 	return ret;
489 }
490 
491 /*
492  * Actually change the mutex state to mark its owner as dead,
493  * returning -EPERM if not the owner.
494  */
495 static int kill_mutex_state(struct ntsync_obj *mutex, __u32 owner)
496 {
497 	ntsync_assert_held(mutex);
498 
499 	if (mutex->u.mutex.owner != owner)
500 		return -EPERM;
501 
502 	mutex->u.mutex.ownerdead = true;
503 	mutex->u.mutex.owner = 0;
504 	mutex->u.mutex.count = 0;
505 	return 0;
506 }
507 
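/*
 * NTSYNC_IOC_MUTEX_KILL: mark the mutex as abandoned on behalf of "owner",
 * typically because the owning thread has died. The next task to acquire the
 * mutex sees this as an -EOWNERDEAD result from its wait.
 */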
508 static int ntsync_mutex_kill(struct ntsync_obj *mutex, void __user *argp)
509 {
510 	struct ntsync_device *dev = mutex->dev;
511 	__u32 owner;
512 	bool all;
513 	int ret;
514 
515 	if (get_user(owner, (__u32 __user *)argp))
516 		return -EFAULT;
517 	if (!owner)
518 		return -EINVAL;
519 
520 	if (mutex->type != NTSYNC_TYPE_MUTEX)
521 		return -EINVAL;
522 
523 	all = ntsync_lock_obj(dev, mutex);
524 
525 	ret = kill_mutex_state(mutex, owner);
526 	if (!ret) {
527 		if (all)
528 			try_wake_all_obj(dev, mutex);
529 		try_wake_any_mutex(mutex);
530 	}
531 
532 	ntsync_unlock_obj(dev, mutex, all);
533 
534 	return ret;
535 }
536 
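/*
 * Shared by NTSYNC_IOC_EVENT_SET and NTSYNC_IOC_EVENT_PULSE: both signal the
 * event and wake waiters, but "pulse" resets the event to non-signaled before
 * dropping the lock, so only waiters queued at the time of the call can
 * consume it.
 */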
537 static int ntsync_event_set(struct ntsync_obj *event, void __user *argp, bool pulse)
538 {
539 	struct ntsync_device *dev = event->dev;
540 	__u32 prev_state;
541 	bool all;
542 
543 	if (event->type != NTSYNC_TYPE_EVENT)
544 		return -EINVAL;
545 
546 	all = ntsync_lock_obj(dev, event);
547 
548 	prev_state = event->u.event.signaled;
549 	event->u.event.signaled = true;
550 	if (all)
551 		try_wake_all_obj(dev, event);
552 	try_wake_any_event(event);
553 	if (pulse)
554 		event->u.event.signaled = false;
555 
556 	ntsync_unlock_obj(dev, event, all);
557 
558 	if (put_user(prev_state, (__u32 __user *)argp))
559 		return -EFAULT;
560 
561 	return 0;
562 }
563 
564 static int ntsync_event_reset(struct ntsync_obj *event, void __user *argp)
565 {
566 	struct ntsync_device *dev = event->dev;
567 	__u32 prev_state;
568 	bool all;
569 
570 	if (event->type != NTSYNC_TYPE_EVENT)
571 		return -EINVAL;
572 
573 	all = ntsync_lock_obj(dev, event);
574 
575 	prev_state = event->u.event.signaled;
576 	event->u.event.signaled = false;
577 
578 	ntsync_unlock_obj(dev, event, all);
579 
580 	if (put_user(prev_state, (__u32 __user *)argp))
581 		return -EFAULT;
582 
583 	return 0;
584 }
585 
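/*
 * The *_READ ioctls below snapshot the object state under the object lock and
 * copy it to user-space without changing it; ntsync_mutex_read() additionally
 * reports -EOWNERDEAD for an abandoned mutex.
 */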
586 static int ntsync_sem_read(struct ntsync_obj *sem, void __user *argp)
587 {
588 	struct ntsync_sem_args __user *user_args = argp;
589 	struct ntsync_device *dev = sem->dev;
590 	struct ntsync_sem_args args;
591 	bool all;
592 
593 	if (sem->type != NTSYNC_TYPE_SEM)
594 		return -EINVAL;
595 
596 	all = ntsync_lock_obj(dev, sem);
597 
598 	args.count = sem->u.sem.count;
599 	args.max = sem->u.sem.max;
600 
601 	ntsync_unlock_obj(dev, sem, all);
602 
603 	if (copy_to_user(user_args, &args, sizeof(args)))
604 		return -EFAULT;
605 	return 0;
606 }
607 
608 static int ntsync_mutex_read(struct ntsync_obj *mutex, void __user *argp)
609 {
610 	struct ntsync_mutex_args __user *user_args = argp;
611 	struct ntsync_device *dev = mutex->dev;
612 	struct ntsync_mutex_args args;
613 	bool all;
614 	int ret;
615 
616 	if (mutex->type != NTSYNC_TYPE_MUTEX)
617 		return -EINVAL;
618 
619 	all = ntsync_lock_obj(dev, mutex);
620 
621 	args.count = mutex->u.mutex.count;
622 	args.owner = mutex->u.mutex.owner;
623 	ret = mutex->u.mutex.ownerdead ? -EOWNERDEAD : 0;
624 
625 	ntsync_unlock_obj(dev, mutex, all);
626 
627 	if (copy_to_user(user_args, &args, sizeof(args)))
628 		return -EFAULT;
629 	return ret;
630 }
631 
632 static int ntsync_event_read(struct ntsync_obj *event, void __user *argp)
633 {
634 	struct ntsync_event_args __user *user_args = argp;
635 	struct ntsync_device *dev = event->dev;
636 	struct ntsync_event_args args;
637 	bool all;
638 
639 	if (event->type != NTSYNC_TYPE_EVENT)
640 		return -EINVAL;
641 
642 	all = ntsync_lock_obj(dev, event);
643 
644 	args.manual = event->u.event.manual;
645 	args.signaled = event->u.event.signaled;
646 
647 	ntsync_unlock_obj(dev, event, all);
648 
649 	if (copy_to_user(user_args, &args, sizeof(args)))
650 		return -EFAULT;
651 	return 0;
652 }
653 
654 static void ntsync_free_obj(struct ntsync_obj *obj)
655 {
656 	fput(obj->dev->file);
657 	kfree(obj);
658 }
659 
660 static int ntsync_obj_release(struct inode *inode, struct file *file)
661 {
662 	ntsync_free_obj(file->private_data);
663 	return 0;
664 }
665 
666 static long ntsync_obj_ioctl(struct file *file, unsigned int cmd,
667 			     unsigned long parm)
668 {
669 	struct ntsync_obj *obj = file->private_data;
670 	void __user *argp = (void __user *)parm;
671 
672 	switch (cmd) {
673 	case NTSYNC_IOC_SEM_RELEASE:
674 		return ntsync_sem_release(obj, argp);
675 	case NTSYNC_IOC_SEM_READ:
676 		return ntsync_sem_read(obj, argp);
677 	case NTSYNC_IOC_MUTEX_UNLOCK:
678 		return ntsync_mutex_unlock(obj, argp);
679 	case NTSYNC_IOC_MUTEX_KILL:
680 		return ntsync_mutex_kill(obj, argp);
681 	case NTSYNC_IOC_MUTEX_READ:
682 		return ntsync_mutex_read(obj, argp);
683 	case NTSYNC_IOC_EVENT_SET:
684 		return ntsync_event_set(obj, argp, false);
685 	case NTSYNC_IOC_EVENT_RESET:
686 		return ntsync_event_reset(obj, argp);
687 	case NTSYNC_IOC_EVENT_PULSE:
688 		return ntsync_event_set(obj, argp, true);
689 	case NTSYNC_IOC_EVENT_READ:
690 		return ntsync_event_read(obj, argp);
691 	default:
692 		return -ENOIOCTLCMD;
693 	}
694 }
695 
696 static const struct file_operations ntsync_obj_fops = {
697 	.owner		= THIS_MODULE,
698 	.release	= ntsync_obj_release,
699 	.unlocked_ioctl	= ntsync_obj_ioctl,
700 	.compat_ioctl	= compat_ptr_ioctl,
701 };
702 
703 static struct ntsync_obj *ntsync_alloc_obj(struct ntsync_device *dev,
704 					   enum ntsync_type type)
705 {
706 	struct ntsync_obj *obj;
707 
708 	obj = kzalloc(sizeof(*obj), GFP_KERNEL);
709 	if (!obj)
710 		return NULL;
711 	obj->type = type;
712 	obj->dev = dev;
713 	get_file(dev->file);
714 	spin_lock_init(&obj->lock);
715 	INIT_LIST_HEAD(&obj->any_waiters);
716 	INIT_LIST_HEAD(&obj->all_waiters);
717 	atomic_set(&obj->all_hint, 0);
718 
719 	return obj;
720 }
721 
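/*
 * Wrap the object in an anonymous file and install it in the caller's file
 * descriptor table. On failure nothing has been installed and the caller is
 * still responsible for freeing the object (see the ntsync_create_*()
 * callers below).
 */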
722 static int ntsync_obj_get_fd(struct ntsync_obj *obj)
723 {
724 	struct file *file;
725 	int fd;
726 
727 	fd = get_unused_fd_flags(O_CLOEXEC);
728 	if (fd < 0)
729 		return fd;
730 	file = anon_inode_getfile("ntsync", &ntsync_obj_fops, obj, O_RDWR);
731 	if (IS_ERR(file)) {
732 		put_unused_fd(fd);
733 		return PTR_ERR(file);
734 	}
735 	obj->file = file;
736 	fd_install(fd, file);
737 
738 	return fd;
739 }
740 
741 static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp)
742 {
743 	struct ntsync_sem_args args;
744 	struct ntsync_obj *sem;
745 	int fd;
746 
747 	if (copy_from_user(&args, argp, sizeof(args)))
748 		return -EFAULT;
749 
750 	if (args.count > args.max)
751 		return -EINVAL;
752 
753 	sem = ntsync_alloc_obj(dev, NTSYNC_TYPE_SEM);
754 	if (!sem)
755 		return -ENOMEM;
756 	sem->u.sem.count = args.count;
757 	sem->u.sem.max = args.max;
758 	fd = ntsync_obj_get_fd(sem);
759 	if (fd < 0)
760 		ntsync_free_obj(sem);
761 
762 	return fd;
763 }
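/*
 * Illustrative user-space creation of a semaphore (a sketch only; "dev_fd" is
 * assumed to be an open ntsync character device and error handling is
 * omitted):
 *
 *	struct ntsync_sem_args args = { .count = 0, .max = 2 };
 *	int sem_fd = ioctl(dev_fd, NTSYNC_IOC_CREATE_SEM, &args);
 */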
764 
765 static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp)
766 {
767 	struct ntsync_mutex_args args;
768 	struct ntsync_obj *mutex;
769 	int fd;
770 
771 	if (copy_from_user(&args, argp, sizeof(args)))
772 		return -EFAULT;
773 
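	/* A mutex is owned if and only if its recursion count is nonzero. */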
774 	if (!args.owner != !args.count)
775 		return -EINVAL;
776 
777 	mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX);
778 	if (!mutex)
779 		return -ENOMEM;
780 	mutex->u.mutex.count = args.count;
781 	mutex->u.mutex.owner = args.owner;
782 	fd = ntsync_obj_get_fd(mutex);
783 	if (fd < 0)
784 		ntsync_free_obj(mutex);
785 
786 	return fd;
787 }
788 
789 static int ntsync_create_event(struct ntsync_device *dev, void __user *argp)
790 {
791 	struct ntsync_event_args args;
792 	struct ntsync_obj *event;
793 	int fd;
794 
795 	if (copy_from_user(&args, argp, sizeof(args)))
796 		return -EFAULT;
797 
798 	event = ntsync_alloc_obj(dev, NTSYNC_TYPE_EVENT);
799 	if (!event)
800 		return -ENOMEM;
801 	event->u.event.manual = args.manual;
802 	event->u.event.signaled = args.signaled;
803 	fd = ntsync_obj_get_fd(event);
804 	if (fd < 0)
805 		ntsync_free_obj(event);
806 
807 	return fd;
808 }
809 
810 static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd)
811 {
812 	struct file *file = fget(fd);
813 	struct ntsync_obj *obj;
814 
815 	if (!file)
816 		return NULL;
817 
818 	if (file->f_op != &ntsync_obj_fops) {
819 		fput(file);
820 		return NULL;
821 	}
822 
823 	obj = file->private_data;
824 	if (obj->dev != dev) {
825 		fput(file);
826 		return NULL;
827 	}
828 
829 	return obj;
830 }
831 
832 static void put_obj(struct ntsync_obj *obj)
833 {
834 	fput(obj->file);
835 }
836 
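/*
 * Sleep until this wait is signaled, the absolute timeout expires, or a
 * signal is pending. A timeout of U64_MAX means wait forever. Returns 0 when
 * signaled or when the timeout expires (the caller distinguishes the two via
 * q->signaled) and -ERESTARTSYS if interrupted by a signal.
 */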
837 static int ntsync_schedule(const struct ntsync_q *q, const struct ntsync_wait_args *args)
838 {
839 	ktime_t timeout = ns_to_ktime(args->timeout);
840 	clockid_t clock = CLOCK_MONOTONIC;
841 	ktime_t *timeout_ptr;
842 	int ret = 0;
843 
844 	timeout_ptr = (args->timeout == U64_MAX ? NULL : &timeout);
845 
846 	if (args->flags & NTSYNC_WAIT_REALTIME)
847 		clock = CLOCK_REALTIME;
848 
849 	do {
850 		if (signal_pending(current)) {
851 			ret = -ERESTARTSYS;
852 			break;
853 		}
854 
855 		set_current_state(TASK_INTERRUPTIBLE);
856 		if (atomic_read(&q->signaled) != -1) {
857 			ret = 0;
858 			break;
859 		}
860 		ret = schedule_hrtimeout_range_clock(timeout_ptr, 0, HRTIMER_MODE_ABS, clock);
861 	} while (ret < 0);
862 	__set_current_state(TASK_RUNNING);
863 
864 	return ret;
865 }
866 
867 /*
868  * Allocate and initialize the ntsync_q structure, but do not queue us yet.
869  */
870 static int setup_wait(struct ntsync_device *dev,
871 		      const struct ntsync_wait_args *args, bool all,
872 		      struct ntsync_q **ret_q)
873 {
874 	int fds[NTSYNC_MAX_WAIT_COUNT + 1];
875 	const __u32 count = args->count;
876 	size_t size = array_size(count, sizeof(fds[0]));
877 	struct ntsync_q *q;
878 	__u32 total_count;
879 	__u32 i, j;
880 
881 	if (args->pad || (args->flags & ~NTSYNC_WAIT_REALTIME))
882 		return -EINVAL;
883 
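	/* ">=" rather than ">" leaves room in fds[] for the optional alert object. */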
884 	if (size >= sizeof(fds))
885 		return -EINVAL;
886 
887 	total_count = count;
888 	if (args->alert)
889 		total_count++;
890 
891 	if (copy_from_user(fds, u64_to_user_ptr(args->objs), size))
892 		return -EFAULT;
893 	if (args->alert)
894 		fds[count] = args->alert;
895 
896 	q = kmalloc(struct_size(q, entries, total_count), GFP_KERNEL);
897 	if (!q)
898 		return -ENOMEM;
899 	q->task = current;
900 	q->owner = args->owner;
901 	atomic_set(&q->signaled, -1);
902 	q->all = all;
903 	q->ownerdead = false;
904 	q->count = count;
905 
906 	for (i = 0; i < total_count; i++) {
907 		struct ntsync_q_entry *entry = &q->entries[i];
908 		struct ntsync_obj *obj = get_obj(dev, fds[i]);
909 
910 		if (!obj)
911 			goto err;
912 
913 		if (all) {
914 			/* Check that the objects are all distinct. */
915 			for (j = 0; j < i; j++) {
916 				if (obj == q->entries[j].obj) {
917 					put_obj(obj);
918 					goto err;
919 				}
920 			}
921 		}
922 
923 		entry->obj = obj;
924 		entry->q = q;
925 		entry->index = i;
926 	}
927 
928 	*ret_q = q;
929 	return 0;
930 
931 err:
932 	for (j = 0; j < i; j++)
933 		put_obj(q->entries[j].obj);
934 	kfree(q);
935 	return -EINVAL;
936 }
937 
938 static void try_wake_any_obj(struct ntsync_obj *obj)
939 {
940 	switch (obj->type) {
941 	case NTSYNC_TYPE_SEM:
942 		try_wake_any_sem(obj);
943 		break;
944 	case NTSYNC_TYPE_MUTEX:
945 		try_wake_any_mutex(obj);
946 		break;
947 	case NTSYNC_TYPE_EVENT:
948 		try_wake_any_event(obj);
949 		break;
950 	}
951 }
952 
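/*
 * Illustrative user-space use of NTSYNC_IOC_WAIT_ANY (a sketch, not part of
 * the driver; "dev_fd", "sem_fd", "mutex_fd" and "owner_tid" are assumed to
 * exist and error handling is omitted):
 *
 *	int fds[2] = { sem_fd, mutex_fd };
 *	struct ntsync_wait_args wait = {
 *		.timeout = U64_MAX,
 *		.objs = (uintptr_t)fds,
 *		.count = 2,
 *		.owner = owner_tid,
 *	};
 *	ret = ioctl(dev_fd, NTSYNC_IOC_WAIT_ANY, &wait);
 *	(on success, wait.index identifies the object that was acquired)
 */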
953 static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
954 {
955 	struct ntsync_wait_args args;
956 	__u32 i, total_count;
957 	struct ntsync_q *q;
958 	int signaled;
959 	bool all;
960 	int ret;
961 
962 	if (copy_from_user(&args, argp, sizeof(args)))
963 		return -EFAULT;
964 
965 	ret = setup_wait(dev, &args, false, &q);
966 	if (ret < 0)
967 		return ret;
968 
969 	total_count = args.count;
970 	if (args.alert)
971 		total_count++;
972 
973 	/* queue ourselves */
974 
975 	for (i = 0; i < total_count; i++) {
976 		struct ntsync_q_entry *entry = &q->entries[i];
977 		struct ntsync_obj *obj = entry->obj;
978 
979 		all = ntsync_lock_obj(dev, obj);
980 		list_add_tail(&entry->node, &obj->any_waiters);
981 		ntsync_unlock_obj(dev, obj, all);
982 	}
983 
984 	/*
985 	 * Check if we are already signaled.
986 	 *
987 	 * Note that the API requires that normal objects are checked before
988 	 * the alert event. Hence we queue the alert event last, and check
989 	 * objects in order.
990 	 */
991 
992 	for (i = 0; i < total_count; i++) {
993 		struct ntsync_obj *obj = q->entries[i].obj;
994 
995 		if (atomic_read(&q->signaled) != -1)
996 			break;
997 
998 		all = ntsync_lock_obj(dev, obj);
999 		try_wake_any_obj(obj);
1000 		ntsync_unlock_obj(dev, obj, all);
1001 	}
1002 
1003 	/* sleep */
1004 
1005 	ret = ntsync_schedule(q, &args);
1006 
1007 	/* and finally, unqueue */
1008 
1009 	for (i = 0; i < total_count; i++) {
1010 		struct ntsync_q_entry *entry = &q->entries[i];
1011 		struct ntsync_obj *obj = entry->obj;
1012 
1013 		all = ntsync_lock_obj(dev, obj);
1014 		list_del(&entry->node);
1015 		ntsync_unlock_obj(dev, obj, all);
1016 
1017 		put_obj(obj);
1018 	}
1019 
1020 	signaled = atomic_read(&q->signaled);
1021 	if (signaled != -1) {
1022 		struct ntsync_wait_args __user *user_args = argp;
1023 
1024 		/* even if we caught a signal, we need to communicate success */
1025 		ret = q->ownerdead ? -EOWNERDEAD : 0;
1026 
1027 		if (put_user(signaled, &user_args->index))
1028 			ret = -EFAULT;
1029 	} else if (!ret) {
1030 		ret = -ETIMEDOUT;
1031 	}
1032 
1033 	kfree(q);
1034 	return ret;
1035 }
1036 
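/*
 * NTSYNC_IOC_WAIT_ALL: wait until every listed object can be acquired at
 * once. All of the state changes and the wakeup happen atomically under
 * dev->wait_all_lock; the optional alert object is treated as an "any"
 * waiter and is only checked after the main set.
 */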
1037 static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
1038 {
1039 	struct ntsync_wait_args args;
1040 	struct ntsync_q *q;
1041 	int signaled;
1042 	__u32 i;
1043 	int ret;
1044 
1045 	if (copy_from_user(&args, argp, sizeof(args)))
1046 		return -EFAULT;
1047 
1048 	ret = setup_wait(dev, &args, true, &q);
1049 	if (ret < 0)
1050 		return ret;
1051 
1052 	/* queue ourselves */
1053 
1054 	mutex_lock(&dev->wait_all_lock);
1055 
1056 	for (i = 0; i < args.count; i++) {
1057 		struct ntsync_q_entry *entry = &q->entries[i];
1058 		struct ntsync_obj *obj = entry->obj;
1059 
1060 		atomic_inc(&obj->all_hint);
1061 
1062 		/*
1063 		 * obj->all_waiters is protected by dev->wait_all_lock rather
1064 		 * than obj->lock, so there is no need to acquire obj->lock
1065 		 * here.
1066 		 */
1067 		list_add_tail(&entry->node, &obj->all_waiters);
1068 	}
1069 	if (args.alert) {
1070 		struct ntsync_q_entry *entry = &q->entries[args.count];
1071 		struct ntsync_obj *obj = entry->obj;
1072 
1073 		dev_lock_obj(dev, obj);
1074 		list_add_tail(&entry->node, &obj->any_waiters);
1075 		dev_unlock_obj(dev, obj);
1076 	}
1077 
1078 	/* check if we are already signaled */
1079 
1080 	try_wake_all(dev, q, NULL);
1081 
1082 	mutex_unlock(&dev->wait_all_lock);
1083 
1084 	/*
1085 	 * Check if the alert event is signaled, making sure to do so only
1086 	 * after checking if the other objects are signaled.
1087 	 */
1088 
1089 	if (args.alert) {
1090 		struct ntsync_obj *obj = q->entries[args.count].obj;
1091 
1092 		if (atomic_read(&q->signaled) == -1) {
1093 			bool all = ntsync_lock_obj(dev, obj);
1094 			try_wake_any_obj(obj);
1095 			ntsync_unlock_obj(dev, obj, all);
1096 		}
1097 	}
1098 
1099 	/* sleep */
1100 
1101 	ret = ntsync_schedule(q, &args);
1102 
1103 	/* and finally, unqueue */
1104 
1105 	mutex_lock(&dev->wait_all_lock);
1106 
1107 	for (i = 0; i < args.count; i++) {
1108 		struct ntsync_q_entry *entry = &q->entries[i];
1109 		struct ntsync_obj *obj = entry->obj;
1110 
1111 		/*
1112 		 * obj->all_waiters is protected by dev->wait_all_lock rather
1113 		 * than obj->lock, so there is no need to acquire it here.
1114 		 */
1115 		list_del(&entry->node);
1116 
1117 		atomic_dec(&obj->all_hint);
1118 
1119 		put_obj(obj);
1120 	}
1121 
1122 	mutex_unlock(&dev->wait_all_lock);
1123 
1124 	if (args.alert) {
1125 		struct ntsync_q_entry *entry = &q->entries[args.count];
1126 		struct ntsync_obj *obj = entry->obj;
1127 		bool all;
1128 
1129 		all = ntsync_lock_obj(dev, obj);
1130 		list_del(&entry->node);
1131 		ntsync_unlock_obj(dev, obj, all);
1132 
1133 		put_obj(obj);
1134 	}
1135 
1136 	signaled = atomic_read(&q->signaled);
1137 	if (signaled != -1) {
1138 		struct ntsync_wait_args __user *user_args = argp;
1139 
1140 		/* even if we caught a signal, we need to communicate success */
1141 		ret = q->ownerdead ? -EOWNERDEAD : 0;
1142 
1143 		if (put_user(signaled, &user_args->index))
1144 			ret = -EFAULT;
1145 	} else if (!ret) {
1146 		ret = -ETIMEDOUT;
1147 	}
1148 
1149 	kfree(q);
1150 	return ret;
1151 }
1152 
1153 static int ntsync_char_open(struct inode *inode, struct file *file)
1154 {
1155 	struct ntsync_device *dev;
1156 
1157 	dev = kzalloc(sizeof(*dev), GFP_KERNEL);
1158 	if (!dev)
1159 		return -ENOMEM;
1160 
1161 	mutex_init(&dev->wait_all_lock);
1162 
1163 	file->private_data = dev;
1164 	dev->file = file;
1165 	return nonseekable_open(inode, file);
1166 }
1167 
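/*
 * Each object created from this device holds a reference on dev->file (see
 * ntsync_alloc_obj()), so this release callback cannot run until every such
 * object has been freed; freeing the device here is therefore safe.
 */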
1168 static int ntsync_char_release(struct inode *inode, struct file *file)
1169 {
1170 	struct ntsync_device *dev = file->private_data;
1171 
1172 	kfree(dev);
1173 
1174 	return 0;
1175 }
1176 
1177 static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
1178 			      unsigned long parm)
1179 {
1180 	struct ntsync_device *dev = file->private_data;
1181 	void __user *argp = (void __user *)parm;
1182 
1183 	switch (cmd) {
1184 	case NTSYNC_IOC_CREATE_EVENT:
1185 		return ntsync_create_event(dev, argp);
1186 	case NTSYNC_IOC_CREATE_MUTEX:
1187 		return ntsync_create_mutex(dev, argp);
1188 	case NTSYNC_IOC_CREATE_SEM:
1189 		return ntsync_create_sem(dev, argp);
1190 	case NTSYNC_IOC_WAIT_ALL:
1191 		return ntsync_wait_all(dev, argp);
1192 	case NTSYNC_IOC_WAIT_ANY:
1193 		return ntsync_wait_any(dev, argp);
1194 	default:
1195 		return -ENOIOCTLCMD;
1196 	}
1197 }
1198 
1199 static const struct file_operations ntsync_fops = {
1200 	.owner		= THIS_MODULE,
1201 	.open		= ntsync_char_open,
1202 	.release	= ntsync_char_release,
1203 	.unlocked_ioctl	= ntsync_char_ioctl,
1204 	.compat_ioctl	= compat_ptr_ioctl,
1205 };
1206 
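/*
 * Registered as a dynamic misc character device ("ntsync", typically
 * /dev/ntsync). The 0666 mode lets unprivileged processes open it; each
 * open() creates a private ntsync_device namespace.
 */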
1207 static struct miscdevice ntsync_misc = {
1208 	.minor		= MISC_DYNAMIC_MINOR,
1209 	.name		= NTSYNC_NAME,
1210 	.fops		= &ntsync_fops,
1211 	.mode		= 0666,
1212 };
1213 
1214 module_misc_device(ntsync_misc);
1215 
1216 MODULE_AUTHOR("Elizabeth Figura <zfigura@codeweavers.com>");
1217 MODULE_DESCRIPTION("Kernel driver for NT synchronization primitives");
1218 MODULE_LICENSE("GPL");
1219