xref: /linux/drivers/dibs/dibs_loopback.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  *  Functions for dibs loopback/loopback-ism device.
4  *
5  *  Copyright (c) 2024, Alibaba Inc.
6  *
7  *  Author: Wen Gu <guwen@linux.alibaba.com>
8  *          Tony Lu <tonylu@linux.alibaba.com>
9  *
10  */
11 
12 #include <linux/bitops.h>
13 #include <linux/device.h>
14 #include <linux/dibs.h>
15 #include <linux/mm.h>
16 #include <linux/slab.h>
17 #include <linux/spinlock.h>
18 #include <linux/types.h>
19 
20 #include "dibs_loopback.h"
21 
/* Value reported by dibs_lo_support_dmb_nocopy(): no-copy DMB access is supported. */
#define DIBS_LO_SUPPORT_NOCOPY	0x1
/*
 * All-ones DMA address stored in every loopback DMB: the buffers come from
 * folio_alloc() and no DMA mapping is ever created for them.
 */
#define DIBS_DMA_ADDR_INVALID	(~(dma_addr_t)0)

/* Name given to the loopback dibs device via dev_set_name() at probe time. */
static const char dibs_lo_dev_name[] = "lo";
/* global loopback device; set by dibs_lo_dev_probe(), cleared by dibs_lo_dev_remove() */
static struct dibs_lo_dev *lo_dev;
28 
29 static u16 dibs_lo_get_fabric_id(struct dibs_dev *dibs)
30 {
31 	return DIBS_LOOPBACK_FABRIC;
32 }
33 
34 static int dibs_lo_query_rgid(struct dibs_dev *dibs, const uuid_t *rgid,
35 			      u32 vid_valid, u32 vid)
36 {
37 	/* rgid should be the same as lgid */
38 	if (!uuid_equal(rgid, &dibs->gid))
39 		return -ENETUNREACH;
40 	return 0;
41 }
42 
43 static int dibs_lo_max_dmbs(void)
44 {
45 	return DIBS_LO_MAX_DMBS;
46 }
47 
48 static int dibs_lo_register_dmb(struct dibs_dev *dibs, struct dibs_dmb *dmb,
49 				struct dibs_client *client)
50 {
51 	struct dibs_lo_dmb_node *dmb_node, *tmp_node;
52 	struct dibs_lo_dev *ldev;
53 	struct folio *folio;
54 	unsigned long flags;
55 	int sba_idx, rc;
56 
57 	ldev = dibs->drv_priv;
58 	sba_idx = dmb->idx;
59 	/* check space for new dmb */
60 	for_each_clear_bit(sba_idx, ldev->sba_idx_mask, DIBS_LO_MAX_DMBS) {
61 		if (!test_and_set_bit(sba_idx, ldev->sba_idx_mask))
62 			break;
63 	}
64 	if (sba_idx == DIBS_LO_MAX_DMBS)
65 		return -ENOSPC;
66 
67 	dmb_node = kzalloc(sizeof(*dmb_node), GFP_KERNEL);
68 	if (!dmb_node) {
69 		rc = -ENOMEM;
70 		goto err_bit;
71 	}
72 
73 	dmb_node->sba_idx = sba_idx;
74 	dmb_node->len = dmb->dmb_len;
75 
76 	/* not critical; fail under memory pressure and fallback to TCP */
77 	folio = folio_alloc(GFP_KERNEL | __GFP_NOWARN | __GFP_NOMEMALLOC |
78 			    __GFP_NORETRY | __GFP_ZERO,
79 			    get_order(dmb_node->len));
80 	if (!folio) {
81 		rc = -ENOMEM;
82 		goto err_node;
83 	}
84 	dmb_node->cpu_addr = folio_address(folio);
85 	dmb_node->dma_addr = DIBS_DMA_ADDR_INVALID;
86 	refcount_set(&dmb_node->refcnt, 1);
87 
88 again:
89 	/* add new dmb into hash table */
90 	get_random_bytes(&dmb_node->token, sizeof(dmb_node->token));
91 	write_lock_bh(&ldev->dmb_ht_lock);
92 	hash_for_each_possible(ldev->dmb_ht, tmp_node, list, dmb_node->token) {
93 		if (tmp_node->token == dmb_node->token) {
94 			write_unlock_bh(&ldev->dmb_ht_lock);
95 			goto again;
96 		}
97 	}
98 	hash_add(ldev->dmb_ht, &dmb_node->list, dmb_node->token);
99 	write_unlock_bh(&ldev->dmb_ht_lock);
100 	atomic_inc(&ldev->dmb_cnt);
101 
102 	dmb->idx = dmb_node->sba_idx;
103 	dmb->dmb_tok = dmb_node->token;
104 	dmb->cpu_addr = dmb_node->cpu_addr;
105 	dmb->dma_addr = dmb_node->dma_addr;
106 	dmb->dmb_len = dmb_node->len;
107 
108 	spin_lock_irqsave(&dibs->lock, flags);
109 	dibs->dmb_clientid_arr[sba_idx] = client->id;
110 	spin_unlock_irqrestore(&dibs->lock, flags);
111 
112 	return 0;
113 
114 err_node:
115 	kfree(dmb_node);
116 err_bit:
117 	clear_bit(sba_idx, ldev->sba_idx_mask);
118 	return rc;
119 }
120 
121 static void __dibs_lo_unregister_dmb(struct dibs_lo_dev *ldev,
122 				     struct dibs_lo_dmb_node *dmb_node)
123 {
124 	/* remove dmb from hash table */
125 	write_lock_bh(&ldev->dmb_ht_lock);
126 	hash_del(&dmb_node->list);
127 	write_unlock_bh(&ldev->dmb_ht_lock);
128 
129 	clear_bit(dmb_node->sba_idx, ldev->sba_idx_mask);
130 	folio_put(virt_to_folio(dmb_node->cpu_addr));
131 	kfree(dmb_node);
132 
133 	if (atomic_dec_and_test(&ldev->dmb_cnt))
134 		wake_up(&ldev->ldev_release);
135 }
136 
137 static int dibs_lo_unregister_dmb(struct dibs_dev *dibs, struct dibs_dmb *dmb)
138 {
139 	struct dibs_lo_dmb_node *dmb_node = NULL, *tmp_node;
140 	struct dibs_lo_dev *ldev;
141 	unsigned long flags;
142 
143 	ldev = dibs->drv_priv;
144 
145 	/* find dmb from hash table */
146 	read_lock_bh(&ldev->dmb_ht_lock);
147 	hash_for_each_possible(ldev->dmb_ht, tmp_node, list, dmb->dmb_tok) {
148 		if (tmp_node->token == dmb->dmb_tok) {
149 			dmb_node = tmp_node;
150 			break;
151 		}
152 	}
153 	read_unlock_bh(&ldev->dmb_ht_lock);
154 	if (!dmb_node)
155 		return -EINVAL;
156 
157 	if (refcount_dec_and_test(&dmb_node->refcnt)) {
158 		spin_lock_irqsave(&dibs->lock, flags);
159 		dibs->dmb_clientid_arr[dmb_node->sba_idx] = NO_DIBS_CLIENT;
160 		spin_unlock_irqrestore(&dibs->lock, flags);
161 
162 		__dibs_lo_unregister_dmb(ldev, dmb_node);
163 	}
164 	return 0;
165 }
166 
167 static int dibs_lo_support_dmb_nocopy(struct dibs_dev *dibs)
168 {
169 	return DIBS_LO_SUPPORT_NOCOPY;
170 }
171 
172 static int dibs_lo_attach_dmb(struct dibs_dev *dibs, struct dibs_dmb *dmb)
173 {
174 	struct dibs_lo_dmb_node *dmb_node = NULL, *tmp_node;
175 	struct dibs_lo_dev *ldev;
176 
177 	ldev = dibs->drv_priv;
178 
179 	/* find dmb_node according to dmb->dmb_tok */
180 	read_lock_bh(&ldev->dmb_ht_lock);
181 	hash_for_each_possible(ldev->dmb_ht, tmp_node, list, dmb->dmb_tok) {
182 		if (tmp_node->token == dmb->dmb_tok) {
183 			dmb_node = tmp_node;
184 			break;
185 		}
186 	}
187 	if (!dmb_node) {
188 		read_unlock_bh(&ldev->dmb_ht_lock);
189 		return -EINVAL;
190 	}
191 	read_unlock_bh(&ldev->dmb_ht_lock);
192 
193 	if (!refcount_inc_not_zero(&dmb_node->refcnt))
194 		/* the dmb is being unregistered, but has
195 		 * not been removed from the hash table.
196 		 */
197 		return -EINVAL;
198 
199 	/* provide dmb information */
200 	dmb->idx = dmb_node->sba_idx;
201 	dmb->dmb_tok = dmb_node->token;
202 	dmb->cpu_addr = dmb_node->cpu_addr;
203 	dmb->dma_addr = dmb_node->dma_addr;
204 	dmb->dmb_len = dmb_node->len;
205 	return 0;
206 }
207 
208 static int dibs_lo_detach_dmb(struct dibs_dev *dibs, u64 token)
209 {
210 	struct dibs_lo_dmb_node *dmb_node = NULL, *tmp_node;
211 	struct dibs_lo_dev *ldev;
212 
213 	ldev = dibs->drv_priv;
214 
215 	/* find dmb_node according to dmb->dmb_tok */
216 	read_lock_bh(&ldev->dmb_ht_lock);
217 	hash_for_each_possible(ldev->dmb_ht, tmp_node, list, token) {
218 		if (tmp_node->token == token) {
219 			dmb_node = tmp_node;
220 			break;
221 		}
222 	}
223 	if (!dmb_node) {
224 		read_unlock_bh(&ldev->dmb_ht_lock);
225 		return -EINVAL;
226 	}
227 	read_unlock_bh(&ldev->dmb_ht_lock);
228 
229 	if (refcount_dec_and_test(&dmb_node->refcnt))
230 		__dibs_lo_unregister_dmb(ldev, dmb_node);
231 	return 0;
232 }
233 
234 static int dibs_lo_move_data(struct dibs_dev *dibs, u64 dmb_tok,
235 			     unsigned int idx, bool sf, unsigned int offset,
236 			     void *data, unsigned int size)
237 {
238 	struct dibs_lo_dmb_node *rmb_node = NULL, *tmp_node;
239 	struct dibs_lo_dev *ldev;
240 	u16 s_mask;
241 	u8 client_id;
242 	u32 sba_idx;
243 
244 	ldev = dibs->drv_priv;
245 
246 	read_lock_bh(&ldev->dmb_ht_lock);
247 	hash_for_each_possible(ldev->dmb_ht, tmp_node, list, dmb_tok) {
248 		if (tmp_node->token == dmb_tok) {
249 			rmb_node = tmp_node;
250 			break;
251 		}
252 	}
253 	if (!rmb_node) {
254 		read_unlock_bh(&ldev->dmb_ht_lock);
255 		return -EINVAL;
256 	}
257 	memcpy((char *)rmb_node->cpu_addr + offset, data, size);
258 	sba_idx = rmb_node->sba_idx;
259 	read_unlock_bh(&ldev->dmb_ht_lock);
260 
261 	if (!sf)
262 		return 0;
263 
264 	spin_lock(&dibs->lock);
265 	client_id = dibs->dmb_clientid_arr[sba_idx];
266 	s_mask = ror16(0x1000, idx);
267 	if (likely(client_id != NO_DIBS_CLIENT && dibs->subs[client_id]))
268 		dibs->subs[client_id]->ops->handle_irq(dibs, sba_idx, s_mask);
269 	spin_unlock(&dibs->lock);
270 
271 	return 0;
272 }
273 
274 static const struct dibs_dev_ops dibs_lo_ops = {
275 	.get_fabric_id = dibs_lo_get_fabric_id,
276 	.query_remote_gid = dibs_lo_query_rgid,
277 	.max_dmbs = dibs_lo_max_dmbs,
278 	.register_dmb = dibs_lo_register_dmb,
279 	.unregister_dmb = dibs_lo_unregister_dmb,
280 	.move_data = dibs_lo_move_data,
281 	.support_mmapped_rdmb = dibs_lo_support_dmb_nocopy,
282 	.attach_dmb = dibs_lo_attach_dmb,
283 	.detach_dmb = dibs_lo_detach_dmb,
284 };
285 
286 static void dibs_lo_dev_init(struct dibs_lo_dev *ldev)
287 {
288 	rwlock_init(&ldev->dmb_ht_lock);
289 	hash_init(ldev->dmb_ht);
290 	atomic_set(&ldev->dmb_cnt, 0);
291 	init_waitqueue_head(&ldev->ldev_release);
292 }
293 
294 static void dibs_lo_dev_exit(struct dibs_lo_dev *ldev)
295 {
296 	if (atomic_read(&ldev->dmb_cnt))
297 		wait_event(ldev->ldev_release, !atomic_read(&ldev->dmb_cnt));
298 }
299 
/*
 * Create the loopback device and register it with the dibs core.
 *
 * Steps:
 * - allocate the driver-private dibs_lo_dev and a dibs_dev, and link them;
 * - give the dibs_dev a random GID and the dibs_lo_ops table;
 * - name the device "lo" and register it with dibs_dev_add().
 *
 * On success the device is published in the global lo_dev.
 *
 * Return: 0, -ENOMEM if either allocation fails, or the error returned by
 * dibs_dev_add().
 */
static int dibs_lo_dev_probe(void)
{
	struct dibs_lo_dev *ldev;
	struct dibs_dev *dibs;
	int ret;

	ldev = kzalloc(sizeof(*ldev), GFP_KERNEL);
	if (!ldev)
		return -ENOMEM;

	dibs = dibs_dev_alloc();
	if (!dibs) {
		kfree(ldev);
		return -ENOMEM;
	}

	/* cross-link driver-private state and the generic dibs device */
	ldev->dibs = dibs;
	dibs->drv_priv = ldev;
	dibs_lo_dev_init(ldev);
	/* random local GID; dibs_lo_query_rgid() accepts only this GID */
	uuid_gen(&dibs->gid);
	dibs->ops = &dibs_lo_ops;

	/* the loopback device has no parent device */
	dibs->dev.parent = NULL;
	/*
	 * NOTE(review): dev_set_name() can fail with -ENOMEM and its return
	 * value is ignored; presumably dibs_dev_add() then fails for the
	 * unnamed device - confirm.
	 */
	dev_set_name(&dibs->dev, "%s", dibs_lo_dev_name);

	ret = dibs_dev_add(dibs);
	if (ret)
		goto err_reg;
	lo_dev = ldev;
	return 0;

err_reg:
	/*
	 * NOTE(review): assumes dibs_dev_add() leaves dmb_clientid_arr
	 * allocated (or NULL) when it fails; if it frees the array itself
	 * without clearing the pointer, this is a double free - confirm
	 * against dibs_dev_add().
	 */
	kfree(dibs->dmb_clientid_arr);
	/* pairs with dibs_dev_alloc() */
	put_device(&dibs->dev);
	kfree(ldev);

	return ret;
}
339 
340 static void dibs_lo_dev_remove(void)
341 {
342 	if (!lo_dev)
343 		return;
344 
345 	dibs_dev_del(lo_dev->dibs);
346 	dibs_lo_dev_exit(lo_dev);
347 	/* pairs with dibs_dev_alloc() */
348 	put_device(&lo_dev->dibs->dev);
349 	kfree(lo_dev);
350 	lo_dev = NULL;
351 }
352 
/* Create and register the loopback device; returns dibs_lo_dev_probe()'s result. */
int dibs_loopback_init(void)
{
	int rc = dibs_lo_dev_probe();

	return rc;
}

/* Remove the loopback device created by dibs_loopback_init(), if any. */
void dibs_loopback_exit(void)
{
	dibs_lo_dev_remove();
}
361 }
362