xref: /linux/drivers/accel/rocket/rocket_drv.c (revision 658ebeac33517bd3169d4b65ed801e9065d0211a)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright 2024-2025 Tomeu Vizoso <tomeu@tomeuvizoso.net> */
3 
4 #include <drm/drm_accel.h>
5 #include <drm/drm_drv.h>
6 #include <drm/drm_gem.h>
7 #include <drm/drm_ioctl.h>
8 #include <drm/rocket_accel.h>
9 #include <linux/clk.h>
10 #include <linux/err.h>
11 #include <linux/iommu.h>
12 #include <linux/of.h>
13 #include <linux/platform_device.h>
14 #include <linux/pm_runtime.h>
15 
16 #include "rocket_drv.h"
17 #include "rocket_gem.h"
18 
/*
 * Facade platform device. It backs the single DRM device that is exposed
 * to userspace, through which jobs are scheduled to any RKNN core in the
 * system.
 */
static struct platform_device *drm_dev;

/*
 * Driver-wide state shared by all cores: set up when the first core is
 * probed and reset to NULL once the last core has been removed.
 */
static struct rocket_device *rdev;
25 
26 static void
27 rocket_iommu_domain_destroy(struct kref *kref)
28 {
29 	struct rocket_iommu_domain *domain = container_of(kref, struct rocket_iommu_domain, kref);
30 
31 	iommu_domain_free(domain->domain);
32 	domain->domain = NULL;
33 	kfree(domain);
34 }
35 
36 static struct rocket_iommu_domain*
37 rocket_iommu_domain_create(struct device *dev)
38 {
39 	struct rocket_iommu_domain *domain = kmalloc(sizeof(*domain), GFP_KERNEL);
40 	void *err;
41 
42 	if (!domain)
43 		return ERR_PTR(-ENOMEM);
44 
45 	domain->domain = iommu_paging_domain_alloc(dev);
46 	if (IS_ERR(domain->domain)) {
47 		err = ERR_CAST(domain->domain);
48 		kfree(domain);
49 		return err;
50 	}
51 	kref_init(&domain->kref);
52 
53 	return domain;
54 }
55 
56 struct rocket_iommu_domain *
57 rocket_iommu_domain_get(struct rocket_file_priv *rocket_priv)
58 {
59 	kref_get(&rocket_priv->domain->kref);
60 	return rocket_priv->domain;
61 }
62 
63 void
64 rocket_iommu_domain_put(struct rocket_iommu_domain *domain)
65 {
66 	kref_put(&domain->kref, rocket_iommu_domain_destroy);
67 }
68 
69 static int
70 rocket_open(struct drm_device *dev, struct drm_file *file)
71 {
72 	struct rocket_device *rdev = to_rocket_device(dev);
73 	struct rocket_file_priv *rocket_priv;
74 	u64 start, end;
75 	int ret;
76 
77 	if (!try_module_get(THIS_MODULE))
78 		return -EINVAL;
79 
80 	rocket_priv = kzalloc(sizeof(*rocket_priv), GFP_KERNEL);
81 	if (!rocket_priv) {
82 		ret = -ENOMEM;
83 		goto err_put_mod;
84 	}
85 
86 	rocket_priv->rdev = rdev;
87 	rocket_priv->domain = rocket_iommu_domain_create(rdev->cores[0].dev);
88 	if (IS_ERR(rocket_priv->domain)) {
89 		ret = PTR_ERR(rocket_priv->domain);
90 		goto err_free;
91 	}
92 
93 	file->driver_priv = rocket_priv;
94 
95 	start = rocket_priv->domain->domain->geometry.aperture_start;
96 	end = rocket_priv->domain->domain->geometry.aperture_end;
97 	drm_mm_init(&rocket_priv->mm, start, end - start + 1);
98 	mutex_init(&rocket_priv->mm_lock);
99 
100 	return 0;
101 
102 err_free:
103 	kfree(rocket_priv);
104 err_put_mod:
105 	module_put(THIS_MODULE);
106 	return ret;
107 }
108 
109 static void
110 rocket_postclose(struct drm_device *dev, struct drm_file *file)
111 {
112 	struct rocket_file_priv *rocket_priv = file->driver_priv;
113 
114 	mutex_destroy(&rocket_priv->mm_lock);
115 	drm_mm_takedown(&rocket_priv->mm);
116 	rocket_iommu_domain_put(rocket_priv->domain);
117 	kfree(rocket_priv);
118 	module_put(THIS_MODULE);
119 }
120 
121 static const struct drm_ioctl_desc rocket_drm_driver_ioctls[] = {
122 #define ROCKET_IOCTL(n, func) \
123 	DRM_IOCTL_DEF_DRV(ROCKET_##n, rocket_ioctl_##func, 0)
124 
125 	ROCKET_IOCTL(CREATE_BO, create_bo),
126 };
127 
128 DEFINE_DRM_ACCEL_FOPS(rocket_accel_driver_fops);
129 
130 /*
131  * Rocket driver version:
132  * - 1.0 - initial interface
133  */
134 static const struct drm_driver rocket_drm_driver = {
135 	.driver_features	= DRIVER_COMPUTE_ACCEL | DRIVER_GEM,
136 	.open			= rocket_open,
137 	.postclose		= rocket_postclose,
138 	.gem_create_object	= rocket_gem_create_object,
139 	.ioctls			= rocket_drm_driver_ioctls,
140 	.num_ioctls		= ARRAY_SIZE(rocket_drm_driver_ioctls),
141 	.fops			= &rocket_accel_driver_fops,
142 	.name			= "rocket",
143 	.desc			= "rocket DRM",
144 };
145 
146 static int rocket_probe(struct platform_device *pdev)
147 {
148 	if (rdev == NULL) {
149 		/* First core probing, initialize DRM device. */
150 		rdev = rocket_device_init(drm_dev, &rocket_drm_driver);
151 		if (IS_ERR(rdev)) {
152 			dev_err(&pdev->dev, "failed to initialize rocket device\n");
153 			return PTR_ERR(rdev);
154 		}
155 	}
156 
157 	unsigned int core = rdev->num_cores;
158 
159 	dev_set_drvdata(&pdev->dev, rdev);
160 
161 	rdev->cores[core].rdev = rdev;
162 	rdev->cores[core].dev = &pdev->dev;
163 	rdev->cores[core].index = core;
164 
165 	rdev->num_cores++;
166 
167 	return rocket_core_init(&rdev->cores[core]);
168 }
169 
170 static void rocket_remove(struct platform_device *pdev)
171 {
172 	struct device *dev = &pdev->dev;
173 
174 	for (unsigned int core = 0; core < rdev->num_cores; core++) {
175 		if (rdev->cores[core].dev == dev) {
176 			rocket_core_fini(&rdev->cores[core]);
177 			rdev->num_cores--;
178 			break;
179 		}
180 	}
181 
182 	if (rdev->num_cores == 0) {
183 		/* Last core removed, deinitialize DRM device. */
184 		rocket_device_fini(rdev);
185 		rdev = NULL;
186 	}
187 }
188 
189 static const struct of_device_id dt_match[] = {
190 	{ .compatible = "rockchip,rk3588-rknn-core" },
191 	{}
192 };
193 MODULE_DEVICE_TABLE(of, dt_match);
194 
195 static int find_core_for_dev(struct device *dev)
196 {
197 	struct rocket_device *rdev = dev_get_drvdata(dev);
198 
199 	for (unsigned int core = 0; core < rdev->num_cores; core++) {
200 		if (dev == rdev->cores[core].dev)
201 			return core;
202 	}
203 
204 	return -1;
205 }
206 
207 static int rocket_device_runtime_resume(struct device *dev)
208 {
209 	struct rocket_device *rdev = dev_get_drvdata(dev);
210 	int core = find_core_for_dev(dev);
211 	int err = 0;
212 
213 	if (core < 0)
214 		return -ENODEV;
215 
216 	err = clk_bulk_prepare_enable(ARRAY_SIZE(rdev->cores[core].clks), rdev->cores[core].clks);
217 	if (err) {
218 		dev_err(dev, "failed to enable (%d) clocks for core %d\n", err, core);
219 		return err;
220 	}
221 
222 	return 0;
223 }
224 
225 static int rocket_device_runtime_suspend(struct device *dev)
226 {
227 	struct rocket_device *rdev = dev_get_drvdata(dev);
228 	int core = find_core_for_dev(dev);
229 
230 	if (core < 0)
231 		return -ENODEV;
232 
233 	clk_bulk_disable_unprepare(ARRAY_SIZE(rdev->cores[core].clks), rdev->cores[core].clks);
234 
235 	return 0;
236 }
237 
238 EXPORT_GPL_DEV_PM_OPS(rocket_pm_ops) = {
239 	RUNTIME_PM_OPS(rocket_device_runtime_suspend, rocket_device_runtime_resume, NULL)
240 	SYSTEM_SLEEP_PM_OPS(pm_runtime_force_suspend, pm_runtime_force_resume)
241 };
242 
243 static struct platform_driver rocket_driver = {
244 	.probe = rocket_probe,
245 	.remove = rocket_remove,
246 	.driver	 = {
247 		.name = "rocket",
248 		.pm = pm_ptr(&rocket_pm_ops),
249 		.of_match_table = dt_match,
250 	},
251 };
252 
253 static int __init rocket_register(void)
254 {
255 	drm_dev = platform_device_register_simple("rknn", -1, NULL, 0);
256 	if (IS_ERR(drm_dev))
257 		return PTR_ERR(drm_dev);
258 
259 	return platform_driver_register(&rocket_driver);
260 }
261 
262 static void __exit rocket_unregister(void)
263 {
264 	platform_driver_unregister(&rocket_driver);
265 
266 	platform_device_unregister(drm_dev);
267 }
268 
269 module_init(rocket_register);
270 module_exit(rocket_unregister);
271 
272 MODULE_LICENSE("GPL");
273 MODULE_DESCRIPTION("DRM driver for the Rockchip NPU IP");
274 MODULE_AUTHOR("Tomeu Vizoso");
275