xref: /linux/drivers/fwctl/main.c (revision fb39e9092be5a18eaab05b5a2492741fe6e395fe)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
4  */
5 #define pr_fmt(fmt) "fwctl: " fmt
6 #include <linux/fwctl.h>
7 
8 #include <linux/container_of.h>
9 #include <linux/fs.h>
10 #include <linux/module.h>
11 #include <linux/slab.h>
12 
13 #include <uapi/fwctl/fwctl.h>
14 
15 enum {
16 	FWCTL_MAX_DEVICES = 4096,
17 };
18 static_assert(FWCTL_MAX_DEVICES < (1U << MINORBITS));
19 
20 static dev_t fwctl_dev;
21 static DEFINE_IDA(fwctl_ida);
22 
23 struct fwctl_ucmd {
24 	struct fwctl_uctx *uctx;
25 	void __user *ubuffer;
26 	void *cmd;
27 	u32 user_size;
28 };
29 
30 static int ucmd_respond(struct fwctl_ucmd *ucmd, size_t cmd_len)
31 {
32 	if (copy_to_user(ucmd->ubuffer, ucmd->cmd,
33 			 min_t(size_t, ucmd->user_size, cmd_len)))
34 		return -EFAULT;
35 	return 0;
36 }
37 
38 static int copy_to_user_zero_pad(void __user *to, const void *from,
39 				 size_t from_len, size_t user_len)
40 {
41 	size_t copy_len;
42 
43 	copy_len = min(from_len, user_len);
44 	if (copy_to_user(to, from, copy_len))
45 		return -EFAULT;
46 	if (copy_len < user_len) {
47 		if (clear_user(to + copy_len, user_len - copy_len))
48 			return -EFAULT;
49 	}
50 	return 0;
51 }
52 
53 static int fwctl_cmd_info(struct fwctl_ucmd *ucmd)
54 {
55 	struct fwctl_device *fwctl = ucmd->uctx->fwctl;
56 	struct fwctl_info *cmd = ucmd->cmd;
57 	size_t driver_info_len = 0;
58 
59 	if (cmd->flags)
60 		return -EOPNOTSUPP;
61 
62 	if (!fwctl->ops->info && cmd->device_data_len) {
63 		if (clear_user(u64_to_user_ptr(cmd->out_device_data),
64 			       cmd->device_data_len))
65 			return -EFAULT;
66 	} else if (cmd->device_data_len) {
67 		void *driver_info __free(kfree) =
68 			fwctl->ops->info(ucmd->uctx, &driver_info_len);
69 		if (IS_ERR(driver_info))
70 			return PTR_ERR(driver_info);
71 
72 		if (copy_to_user_zero_pad(u64_to_user_ptr(cmd->out_device_data),
73 					  driver_info, driver_info_len,
74 					  cmd->device_data_len))
75 			return -EFAULT;
76 	}
77 
78 	cmd->out_device_type = fwctl->ops->device_type;
79 	cmd->device_data_len = driver_info_len;
80 	return ucmd_respond(ucmd, sizeof(*cmd));
81 }
82 
83 /* On stack memory for the ioctl structs */
84 union fwctl_ucmd_buffer {
85 	struct fwctl_info info;
86 };
87 
88 struct fwctl_ioctl_op {
89 	unsigned int size;
90 	unsigned int min_size;
91 	unsigned int ioctl_num;
92 	int (*execute)(struct fwctl_ucmd *ucmd);
93 };
94 
95 #define IOCTL_OP(_ioctl, _fn, _struct, _last)                               \
96 	[_IOC_NR(_ioctl) - FWCTL_CMD_BASE] = {                              \
97 		.size = sizeof(_struct) +                                   \
98 			BUILD_BUG_ON_ZERO(sizeof(union fwctl_ucmd_buffer) < \
99 					  sizeof(_struct)),                 \
100 		.min_size = offsetofend(_struct, _last),                    \
101 		.ioctl_num = _ioctl,                                        \
102 		.execute = _fn,                                             \
103 	}
104 static const struct fwctl_ioctl_op fwctl_ioctl_ops[] = {
105 	IOCTL_OP(FWCTL_INFO, fwctl_cmd_info, struct fwctl_info, out_device_data),
106 };
107 
108 static long fwctl_fops_ioctl(struct file *filp, unsigned int cmd,
109 			       unsigned long arg)
110 {
111 	struct fwctl_uctx *uctx = filp->private_data;
112 	const struct fwctl_ioctl_op *op;
113 	struct fwctl_ucmd ucmd = {};
114 	union fwctl_ucmd_buffer buf;
115 	unsigned int nr;
116 	int ret;
117 
118 	nr = _IOC_NR(cmd);
119 	if ((nr - FWCTL_CMD_BASE) >= ARRAY_SIZE(fwctl_ioctl_ops))
120 		return -ENOIOCTLCMD;
121 
122 	op = &fwctl_ioctl_ops[nr - FWCTL_CMD_BASE];
123 	if (op->ioctl_num != cmd)
124 		return -ENOIOCTLCMD;
125 
126 	ucmd.uctx = uctx;
127 	ucmd.cmd = &buf;
128 	ucmd.ubuffer = (void __user *)arg;
129 	ret = get_user(ucmd.user_size, (u32 __user *)ucmd.ubuffer);
130 	if (ret)
131 		return ret;
132 
133 	if (ucmd.user_size < op->min_size)
134 		return -EINVAL;
135 
136 	ret = copy_struct_from_user(ucmd.cmd, op->size, ucmd.ubuffer,
137 				    ucmd.user_size);
138 	if (ret)
139 		return ret;
140 
141 	guard(rwsem_read)(&uctx->fwctl->registration_lock);
142 	if (!uctx->fwctl->ops)
143 		return -ENODEV;
144 	return op->execute(&ucmd);
145 }
146 
147 static int fwctl_fops_open(struct inode *inode, struct file *filp)
148 {
149 	struct fwctl_device *fwctl =
150 		container_of(inode->i_cdev, struct fwctl_device, cdev);
151 	int ret;
152 
153 	guard(rwsem_read)(&fwctl->registration_lock);
154 	if (!fwctl->ops)
155 		return -ENODEV;
156 
157 	struct fwctl_uctx *uctx __free(kfree) =
158 		kzalloc(fwctl->ops->uctx_size, GFP_KERNEL_ACCOUNT);
159 	if (!uctx)
160 		return -ENOMEM;
161 
162 	uctx->fwctl = fwctl;
163 	ret = fwctl->ops->open_uctx(uctx);
164 	if (ret)
165 		return ret;
166 
167 	scoped_guard(mutex, &fwctl->uctx_list_lock) {
168 		list_add_tail(&uctx->uctx_list_entry, &fwctl->uctx_list);
169 	}
170 
171 	get_device(&fwctl->dev);
172 	filp->private_data = no_free_ptr(uctx);
173 	return 0;
174 }
175 
176 static void fwctl_destroy_uctx(struct fwctl_uctx *uctx)
177 {
178 	lockdep_assert_held(&uctx->fwctl->uctx_list_lock);
179 	list_del(&uctx->uctx_list_entry);
180 	uctx->fwctl->ops->close_uctx(uctx);
181 }
182 
183 static int fwctl_fops_release(struct inode *inode, struct file *filp)
184 {
185 	struct fwctl_uctx *uctx = filp->private_data;
186 	struct fwctl_device *fwctl = uctx->fwctl;
187 
188 	scoped_guard(rwsem_read, &fwctl->registration_lock) {
189 		/*
190 		 * NULL ops means fwctl_unregister() has already removed the
191 		 * driver and destroyed the uctx.
192 		 */
193 		if (fwctl->ops) {
194 			guard(mutex)(&fwctl->uctx_list_lock);
195 			fwctl_destroy_uctx(uctx);
196 		}
197 	}
198 
199 	kfree(uctx);
200 	fwctl_put(fwctl);
201 	return 0;
202 }
203 
204 static const struct file_operations fwctl_fops = {
205 	.owner = THIS_MODULE,
206 	.open = fwctl_fops_open,
207 	.release = fwctl_fops_release,
208 	.unlocked_ioctl = fwctl_fops_ioctl,
209 };
210 
211 static void fwctl_device_release(struct device *device)
212 {
213 	struct fwctl_device *fwctl =
214 		container_of(device, struct fwctl_device, dev);
215 
216 	ida_free(&fwctl_ida, fwctl->dev.devt - fwctl_dev);
217 	mutex_destroy(&fwctl->uctx_list_lock);
218 	kfree(fwctl);
219 }
220 
221 static char *fwctl_devnode(const struct device *dev, umode_t *mode)
222 {
223 	return kasprintf(GFP_KERNEL, "fwctl/%s", dev_name(dev));
224 }
225 
226 static struct class fwctl_class = {
227 	.name = "fwctl",
228 	.dev_release = fwctl_device_release,
229 	.devnode = fwctl_devnode,
230 };
231 
232 static struct fwctl_device *
233 _alloc_device(struct device *parent, const struct fwctl_ops *ops, size_t size)
234 {
235 	struct fwctl_device *fwctl __free(kfree) = kzalloc(size, GFP_KERNEL);
236 	int devnum;
237 
238 	if (!fwctl)
239 		return NULL;
240 
241 	devnum = ida_alloc_max(&fwctl_ida, FWCTL_MAX_DEVICES - 1, GFP_KERNEL);
242 	if (devnum < 0)
243 		return NULL;
244 
245 	fwctl->dev.devt = fwctl_dev + devnum;
246 	fwctl->dev.class = &fwctl_class;
247 	fwctl->dev.parent = parent;
248 
249 	init_rwsem(&fwctl->registration_lock);
250 	mutex_init(&fwctl->uctx_list_lock);
251 	INIT_LIST_HEAD(&fwctl->uctx_list);
252 
253 	device_initialize(&fwctl->dev);
254 	return_ptr(fwctl);
255 }
256 
257 /* Drivers use the fwctl_alloc_device() wrapper */
258 struct fwctl_device *_fwctl_alloc_device(struct device *parent,
259 					 const struct fwctl_ops *ops,
260 					 size_t size)
261 {
262 	struct fwctl_device *fwctl __free(fwctl) =
263 		_alloc_device(parent, ops, size);
264 
265 	if (!fwctl)
266 		return NULL;
267 
268 	cdev_init(&fwctl->cdev, &fwctl_fops);
269 	/*
270 	 * The driver module is protected by fwctl_register/unregister(),
271 	 * unregister won't complete until we are done with the driver's module.
272 	 */
273 	fwctl->cdev.owner = THIS_MODULE;
274 
275 	if (dev_set_name(&fwctl->dev, "fwctl%d", fwctl->dev.devt - fwctl_dev))
276 		return NULL;
277 
278 	fwctl->ops = ops;
279 	return_ptr(fwctl);
280 }
281 EXPORT_SYMBOL_NS_GPL(_fwctl_alloc_device, "FWCTL");
282 
283 /**
284  * fwctl_register - Register a new device to the subsystem
285  * @fwctl: Previously allocated fwctl_device
286  *
287  * On return the device is visible through sysfs and /dev, driver ops may be
288  * called.
289  */
290 int fwctl_register(struct fwctl_device *fwctl)
291 {
292 	return cdev_device_add(&fwctl->cdev, &fwctl->dev);
293 }
294 EXPORT_SYMBOL_NS_GPL(fwctl_register, "FWCTL");
295 
296 /**
297  * fwctl_unregister - Unregister a device from the subsystem
298  * @fwctl: Previously allocated and registered fwctl_device
299  *
300  * Undoes fwctl_register(). On return no driver ops will be called. The
301  * caller must still call fwctl_put() to free the fwctl.
302  *
303  * Unregister will return even if userspace still has file descriptors open.
304  * This will call ops->close_uctx() on any open FDs and after return no driver
305  * op will be called. The FDs remain open but all fops will return -ENODEV.
306  *
307  * The design of fwctl allows this sort of disassociation of the driver from the
308  * subsystem primarily by keeping memory allocations owned by the core subsytem.
309  * The fwctl_device and fwctl_uctx can both be freed without requiring a driver
310  * callback. This allows the module to remain unlocked while FDs are open.
311  */
312 void fwctl_unregister(struct fwctl_device *fwctl)
313 {
314 	struct fwctl_uctx *uctx;
315 
316 	cdev_device_del(&fwctl->cdev, &fwctl->dev);
317 
318 	/* Disable and free the driver's resources for any still open FDs. */
319 	guard(rwsem_write)(&fwctl->registration_lock);
320 	guard(mutex)(&fwctl->uctx_list_lock);
321 	while ((uctx = list_first_entry_or_null(&fwctl->uctx_list,
322 						struct fwctl_uctx,
323 						uctx_list_entry)))
324 		fwctl_destroy_uctx(uctx);
325 
326 	/*
327 	 * The driver module may unload after this returns, the op pointer will
328 	 * not be valid.
329 	 */
330 	fwctl->ops = NULL;
331 }
332 EXPORT_SYMBOL_NS_GPL(fwctl_unregister, "FWCTL");
333 
334 static int __init fwctl_init(void)
335 {
336 	int ret;
337 
338 	ret = alloc_chrdev_region(&fwctl_dev, 0, FWCTL_MAX_DEVICES, "fwctl");
339 	if (ret)
340 		return ret;
341 
342 	ret = class_register(&fwctl_class);
343 	if (ret)
344 		goto err_chrdev;
345 	return 0;
346 
347 err_chrdev:
348 	unregister_chrdev_region(fwctl_dev, FWCTL_MAX_DEVICES);
349 	return ret;
350 }
351 
352 static void __exit fwctl_exit(void)
353 {
354 	class_unregister(&fwctl_class);
355 	unregister_chrdev_region(fwctl_dev, FWCTL_MAX_DEVICES);
356 }
357 
358 module_init(fwctl_init);
359 module_exit(fwctl_exit);
360 MODULE_DESCRIPTION("fwctl device firmware access framework");
361 MODULE_LICENSE("GPL");
362