1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Copyright (C) 2021 Oracle Corporation 4 */ 5 #include <linux/slab.h> 6 #include <linux/completion.h> 7 #include <linux/sched/task.h> 8 #include <linux/sched/vhost_task.h> 9 #include <linux/sched/signal.h> 10 11 enum vhost_task_flags { 12 VHOST_TASK_FLAGS_STOP, 13 VHOST_TASK_FLAGS_KILLED, 14 }; 15 16 struct vhost_task { 17 bool (*fn)(void *data); 18 void (*handle_sigkill)(void *data); 19 void *data; 20 struct completion exited; 21 unsigned long flags; 22 struct task_struct *task; 23 /* serialize SIGKILL and vhost_task_stop calls */ 24 struct mutex exit_mutex; 25 }; 26 27 static int vhost_task_fn(void *data) 28 { 29 struct vhost_task *vtsk = data; 30 31 for (;;) { 32 bool did_work; 33 34 if (signal_pending(current)) { 35 struct ksignal ksig; 36 37 if (get_signal(&ksig)) 38 break; 39 } 40 41 /* mb paired w/ vhost_task_stop */ 42 set_current_state(TASK_INTERRUPTIBLE); 43 44 if (test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) { 45 __set_current_state(TASK_RUNNING); 46 break; 47 } 48 49 did_work = vtsk->fn(vtsk->data); 50 if (!did_work) 51 schedule(); 52 } 53 54 mutex_lock(&vtsk->exit_mutex); 55 /* 56 * If a vhost_task_stop and SIGKILL race, we can ignore the SIGKILL. 57 * When the vhost layer has called vhost_task_stop it's already stopped 58 * new work and flushed. 59 */ 60 if (!test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) { 61 set_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags); 62 vtsk->handle_sigkill(vtsk->data); 63 } 64 mutex_unlock(&vtsk->exit_mutex); 65 complete(&vtsk->exited); 66 67 do_exit(0); 68 } 69 70 /** 71 * vhost_task_wake - wakeup the vhost_task 72 * @vtsk: vhost_task to wake 73 * 74 * wake up the vhost_task worker thread 75 */ 76 void vhost_task_wake(struct vhost_task *vtsk) 77 { 78 wake_up_process(vtsk->task); 79 } 80 EXPORT_SYMBOL_GPL(vhost_task_wake); 81 82 /** 83 * vhost_task_stop - stop a vhost_task 84 * @vtsk: vhost_task to stop 85 * 86 * vhost_task_fn ensures the worker thread exits after 87 * VHOST_TASK_FLAGS_STOP becomes true. 88 */ 89 void vhost_task_stop(struct vhost_task *vtsk) 90 { 91 mutex_lock(&vtsk->exit_mutex); 92 if (!test_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags)) { 93 set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags); 94 vhost_task_wake(vtsk); 95 } 96 mutex_unlock(&vtsk->exit_mutex); 97 98 /* 99 * Make sure vhost_task_fn is no longer accessing the vhost_task before 100 * freeing it below. 101 */ 102 wait_for_completion(&vtsk->exited); 103 put_task_struct(vtsk->task); 104 kfree(vtsk); 105 } 106 EXPORT_SYMBOL_GPL(vhost_task_stop); 107 108 /** 109 * vhost_task_create - create a copy of a task to be used by the kernel 110 * @fn: vhost worker function 111 * @handle_sigkill: vhost function to handle when we are killed 112 * @arg: data to be passed to fn and handled_kill 113 * @name: the thread's name 114 * 115 * This returns a specialized task for use by the vhost layer or ERR_PTR() on 116 * failure. The returned task is inactive, and the caller must fire it up 117 * through vhost_task_start(). 118 */ 119 struct vhost_task *vhost_task_create(bool (*fn)(void *), 120 void (*handle_sigkill)(void *), void *arg, 121 const char *name) 122 { 123 struct kernel_clone_args args = { 124 .flags = CLONE_FS | CLONE_UNTRACED | CLONE_VM | 125 CLONE_THREAD | CLONE_SIGHAND, 126 .exit_signal = 0, 127 .fn = vhost_task_fn, 128 .name = name, 129 .user_worker = 1, 130 .no_files = 1, 131 }; 132 struct vhost_task *vtsk; 133 struct task_struct *tsk; 134 135 vtsk = kzalloc(sizeof(*vtsk), GFP_KERNEL); 136 if (!vtsk) 137 return ERR_PTR(-ENOMEM); 138 init_completion(&vtsk->exited); 139 mutex_init(&vtsk->exit_mutex); 140 vtsk->data = arg; 141 vtsk->fn = fn; 142 vtsk->handle_sigkill = handle_sigkill; 143 144 args.fn_arg = vtsk; 145 146 tsk = copy_process(NULL, 0, NUMA_NO_NODE, &args); 147 if (IS_ERR(tsk)) { 148 kfree(vtsk); 149 return ERR_CAST(tsk); 150 } 151 152 vtsk->task = get_task_struct(tsk); 153 return vtsk; 154 } 155 EXPORT_SYMBOL_GPL(vhost_task_create); 156 157 /** 158 * vhost_task_start - start a vhost_task created with vhost_task_create 159 * @vtsk: vhost_task to wake up 160 */ 161 void vhost_task_start(struct vhost_task *vtsk) 162 { 163 wake_up_new_task(vtsk->task); 164 } 165 EXPORT_SYMBOL_GPL(vhost_task_start); 166