// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
/*
 * Copyright (c) 2022-2023 Fujitsu Ltd. All rights reserved.
 */

#include <linux/hmm.h>
#include <linux/libnvdimm.h>

#include <rdma/ib_umem_odp.h>

#include "rxe.h"

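/*
 * MMU interval notifier callback. Advance the notifier sequence and tear
 * down the DMA mappings for the part of the MR that overlaps the
 * invalidated range, all under umem_mutex.
 */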
static bool rxe_ib_invalidate_range(struct mmu_interval_notifier *mni,
				    const struct mmu_notifier_range *range,
				    unsigned long cur_seq)
{
	struct ib_umem_odp *umem_odp =
		container_of(mni, struct ib_umem_odp, notifier);
	unsigned long start, end;

	if (!mmu_notifier_range_blockable(range))
		return false;

	mutex_lock(&umem_odp->umem_mutex);
	mmu_interval_set_seq(mni, cur_seq);

	start = max_t(u64, ib_umem_start(umem_odp), range->start);
	end = min_t(u64, ib_umem_end(umem_odp), range->end);

	/* update umem_odp->map.pfn_list */
	ib_umem_odp_unmap_dma_pages(umem_odp, start, end);

	mutex_unlock(&umem_odp->umem_mutex);
	return true;
}

const struct mmu_interval_notifier_ops rxe_mn_ops = {
	.invalidate = rxe_ib_invalidate_range,
};

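/*
 * Page fault flags:
 *  RXE_PAGEFAULT_DEFAULT  - fault pages in, with write access if the umem
 *                           is writable
 *  RXE_PAGEFAULT_RDONLY   - fault pages in with read-only access
 *  RXE_PAGEFAULT_SNAPSHOT - map only the pages that are already present,
 *                           without faulting
 */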
#define RXE_PAGEFAULT_DEFAULT 0
#define RXE_PAGEFAULT_RDONLY BIT(0)
#define RXE_PAGEFAULT_SNAPSHOT BIT(1)
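/*
 * Make the pages backing [user_va, user_va + bcnt) present in
 * umem_odp->map.pfn_list according to @flags. Returns the number of
 * mapped pages; on success umem_mutex is left held for the caller.
 */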
static int rxe_odp_do_pagefault_and_lock(struct rxe_mr *mr, u64 user_va, int bcnt, u32 flags)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	bool fault = !(flags & RXE_PAGEFAULT_SNAPSHOT);
	u64 access_mask = 0;
	int np;

	if (umem_odp->umem.writable && !(flags & RXE_PAGEFAULT_RDONLY))
		access_mask |= HMM_PFN_WRITE;

	/*
	 * ib_umem_odp_map_dma_and_lock() locks umem_mutex on success.
	 * Callers must release the lock later so that the invalidation
	 * handler can do its work again.
	 */
	np = ib_umem_odp_map_dma_and_lock(umem_odp, user_va, bcnt,
					  access_mask, fault);
	return np;
}

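/*
 * Take an initial snapshot of the pages backing a newly registered MR.
 * Pages that are not yet present are left to be faulted in on first
 * access.
 */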
static int rxe_odp_init_pages(struct rxe_mr *mr)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	int ret;

	ret = rxe_odp_do_pagefault_and_lock(mr, mr->umem->address,
					    mr->umem->length,
					    RXE_PAGEFAULT_SNAPSHOT);

	if (ret >= 0)
		mutex_unlock(&umem_odp->umem_mutex);

	return ret >= 0 ? 0 : ret;
}

int rxe_odp_mr_init_user(struct rxe_dev *rxe, u64 start, u64 length,
			 u64 iova, int access_flags, struct rxe_mr *mr)
{
	struct ib_umem_odp *umem_odp;
	int err;

	if (!IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
		return -EOPNOTSUPP;

	rxe_mr_init(access_flags, mr);

	if (!start && length == U64_MAX) {
		if (iova != 0)
			return -EINVAL;
		if (!(rxe->attr.odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
			return -EINVAL;

		/* Never reached: implicit ODP is not implemented. */
	}

	umem_odp = ib_umem_odp_get(&rxe->ib_dev, start, length, access_flags,
				   &rxe_mn_ops);
	if (IS_ERR(umem_odp)) {
		rxe_dbg_mr(mr, "Unable to create umem_odp err = %d\n",
			   (int)PTR_ERR(umem_odp));
		return PTR_ERR(umem_odp);
	}

	umem_odp->private = mr;

	mr->umem = &umem_odp->umem;
	mr->access = access_flags;
	mr->ibmr.length = length;
	mr->ibmr.iova = iova;
	mr->page_offset = ib_umem_offset(&umem_odp->umem);

	err = rxe_odp_init_pages(mr);
	if (err) {
		ib_umem_odp_release(umem_odp);
		return err;
	}

	mr->state = RXE_MR_STATE_VALID;
	mr->ibmr.type = IB_MR_TYPE_USER;

	return err;
}

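/*
 * Return true if any page in [iova, iova + length) is not present in
 * umem_odp->map.pfn_list, i.e. a page fault is needed. Called with
 * umem_odp->umem_mutex held.
 */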
static inline bool rxe_check_pagefault(struct ib_umem_odp *umem_odp, u64 iova,
				       int length)
{
	bool need_fault = false;
	u64 addr;
	int idx;

	addr = iova & (~(BIT(umem_odp->page_shift) - 1));

	/* Skim through all pages that are to be accessed. */
	while (addr < iova + length) {
		idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;

		if (!(umem_odp->map.pfn_list[idx] & HMM_PFN_VALID)) {
			need_fault = true;
			break;
		}

		addr += BIT(umem_odp->page_shift);
	}
	return need_fault;
}

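/* Convert an iova into a pfn_list index and an offset within that page. */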
static unsigned long rxe_odp_iova_to_index(struct ib_umem_odp *umem_odp, u64 iova)
{
	return (iova - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
}

static unsigned long rxe_odp_iova_to_page_offset(struct ib_umem_odp *umem_odp, u64 iova)
{
	return iova & (BIT(umem_odp->page_shift) - 1);
}

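/*
 * Make sure the pages covering [iova, iova + length) are present, faulting
 * them in according to @flags if necessary. Returns 0 with
 * umem_odp->umem_mutex held on success.
 */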
static int rxe_odp_map_range_and_lock(struct rxe_mr *mr, u64 iova, int length, u32 flags)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	bool need_fault;
	int err;

	if (unlikely(length < 1))
		return -EINVAL;

	mutex_lock(&umem_odp->umem_mutex);

	need_fault = rxe_check_pagefault(umem_odp, iova, length);
	if (need_fault) {
		mutex_unlock(&umem_odp->umem_mutex);

		/* umem_mutex is locked on success. */
		err = rxe_odp_do_pagefault_and_lock(mr, iova, length,
						    flags);
		if (err < 0)
			return err;

		need_fault = rxe_check_pagefault(umem_odp, iova, length);
		if (need_fault)
			return -EFAULT;
	}

	return 0;
}

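/*
 * Copy data between @addr and the MR, page by page, through temporary
 * kernel mappings of the faulted-in pages. The caller must hold
 * umem_odp->umem_mutex so that the mappings cannot be invalidated.
 */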
static int __rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr,
			     int length, enum rxe_mr_copy_dir dir)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	struct page *page;
	int idx, bytes;
	size_t offset;
	u8 *user_va;

	idx = rxe_odp_iova_to_index(umem_odp, iova);
	offset = rxe_odp_iova_to_page_offset(umem_odp, iova);

	while (length > 0) {
		u8 *src, *dest;

		page = hmm_pfn_to_page(umem_odp->map.pfn_list[idx]);
		user_va = kmap_local_page(page);

		src = (dir == RXE_TO_MR_OBJ) ? addr : user_va;
		dest = (dir == RXE_TO_MR_OBJ) ? user_va : addr;

		bytes = BIT(umem_odp->page_shift) - offset;
		if (bytes > length)
			bytes = length;

		memcpy(dest, src, bytes);
		kunmap_local(user_va);

		length -= bytes;
		idx++;
		offset = 0;
	}

	return 0;
}

int rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr, int length,
		    enum rxe_mr_copy_dir dir)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	u32 flags = RXE_PAGEFAULT_DEFAULT;
	int err;

	if (length == 0)
		return 0;

	if (unlikely(!mr->umem->is_odp))
		return -EOPNOTSUPP;

	switch (dir) {
	case RXE_TO_MR_OBJ:
		break;

	case RXE_FROM_MR_OBJ:
		flags |= RXE_PAGEFAULT_RDONLY;
		break;

	default:
		return -EINVAL;
	}

	err = rxe_odp_map_range_and_lock(mr, iova, length, flags);
	if (err)
		return err;

	err = __rxe_odp_mr_copy(mr, iova, addr, length, dir);

	mutex_unlock(&umem_odp->umem_mutex);

	return err;
}

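/*
 * Emulate an 8-byte atomic compare & swap or fetch & add on MR memory.
 * Atomicity against other emulated atomics is provided by atomic_ops_lock;
 * the caller must hold umem_odp->umem_mutex.
 */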
static enum resp_states rxe_odp_do_atomic_op(struct rxe_mr *mr, u64 iova,
					     int opcode, u64 compare,
					     u64 swap_add, u64 *orig_val)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	unsigned int page_offset;
	struct page *page;
	unsigned int idx;
	u64 value;
	u64 *va;
	int err;

	if (unlikely(mr->state != RXE_MR_STATE_VALID)) {
		rxe_dbg_mr(mr, "mr not in valid state\n");
		return RESPST_ERR_RKEY_VIOLATION;
	}

	err = mr_check_range(mr, iova, sizeof(value));
	if (err) {
		rxe_dbg_mr(mr, "iova out of range\n");
		return RESPST_ERR_RKEY_VIOLATION;
	}

	page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
	if (unlikely(page_offset & 0x7)) {
		rxe_dbg_mr(mr, "iova not aligned\n");
		return RESPST_ERR_MISALIGNED_ATOMIC;
	}

	idx = rxe_odp_iova_to_index(umem_odp, iova);
	page = hmm_pfn_to_page(umem_odp->map.pfn_list[idx]);

	va = kmap_local_page(page);

	spin_lock_bh(&atomic_ops_lock);
	value = *orig_val = va[page_offset >> 3];

	if (opcode == IB_OPCODE_RC_COMPARE_SWAP) {
		if (value == compare)
			va[page_offset >> 3] = swap_add;
	} else {
		value += swap_add;
		va[page_offset >> 3] = value;
	}
	spin_unlock_bh(&atomic_ops_lock);

	kunmap_local(va);

	return RESPST_NONE;
}

enum resp_states rxe_odp_atomic_op(struct rxe_mr *mr, u64 iova, int opcode,
				   u64 compare, u64 swap_add, u64 *orig_val)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	int err;

	err = rxe_odp_map_range_and_lock(mr, iova, sizeof(char),
					 RXE_PAGEFAULT_DEFAULT);
	if (err < 0)
		return RESPST_ERR_RKEY_VIOLATION;

	err = rxe_odp_do_atomic_op(mr, iova, opcode, compare, swap_add,
				   orig_val);
	mutex_unlock(&umem_odp->umem_mutex);

	return err;
}

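/*
 * Write back the CPU caches covering [iova, iova + length) so that data
 * previously written to the MR reaches persistent memory.
 */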
int rxe_odp_flush_pmem_iova(struct rxe_mr *mr, u64 iova,
			    unsigned int length)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	unsigned int page_offset;
	unsigned long index;
	struct page *page;
	unsigned int bytes;
	int err;
	u8 *va;

	err = rxe_odp_map_range_and_lock(mr, iova, length,
					 RXE_PAGEFAULT_DEFAULT);
	if (err)
		return err;

	while (length > 0) {
		index = rxe_odp_iova_to_index(umem_odp, iova);
		page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);

		page = hmm_pfn_to_page(umem_odp->map.pfn_list[index]);

		bytes = min_t(unsigned int, length,
			      mr_page_size(mr) - page_offset);

		va = kmap_local_page(page);
		arch_wb_cache_pmem(va + page_offset, bytes);
		kunmap_local(va);

		length -= bytes;
		iova += bytes;
		page_offset = 0;
	}

	mutex_unlock(&umem_odp->umem_mutex);

	return 0;
}

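/*
 * Perform an 8-byte atomic write to MR memory (IBA Atomic Write).
 * smp_store_release() orders the write after all preceding memory
 * operations.
 */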
enum resp_states rxe_odp_do_atomic_write(struct rxe_mr *mr, u64 iova, u64 value)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	unsigned int page_offset;
	unsigned long index;
	struct page *page;
	int err;
	u64 *va;

	/* See IBA oA19-28 */
	err = mr_check_range(mr, iova, sizeof(value));
	if (unlikely(err)) {
		rxe_dbg_mr(mr, "iova out of range\n");
		return RESPST_ERR_RKEY_VIOLATION;
	}

	err = rxe_odp_map_range_and_lock(mr, iova, sizeof(value),
					 RXE_PAGEFAULT_DEFAULT);
	if (err)
		return RESPST_ERR_RKEY_VIOLATION;

	page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
	/* See IBA A19.4.2 */
	if (unlikely(page_offset & 0x7)) {
		mutex_unlock(&umem_odp->umem_mutex);
		rxe_dbg_mr(mr, "misaligned address\n");
		return RESPST_ERR_MISALIGNED_ATOMIC;
	}

	index = rxe_odp_iova_to_index(umem_odp, iova);
	page = hmm_pfn_to_page(umem_odp->map.pfn_list[index]);

	va = kmap_local_page(page);
	/* Do atomic write after all prior operations have completed */
	smp_store_release(&va[page_offset >> 3], value);
	kunmap_local(va);

	mutex_unlock(&umem_odp->umem_mutex);

	return RESPST_NONE;
}

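/*
 * Work item used for asynchronous (non-flush) prefetch requests. Each
 * frag holds an MR reference that is dropped once the work has run.
 */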
struct prefetch_mr_work {
	struct work_struct work;
	u32 pf_flags;
	u32 num_sge;
	struct {
		u64 io_virt;
		struct rxe_mr *mr;
		size_t length;
	} frags[];
};

static void rxe_ib_prefetch_mr_work(struct work_struct *w)
{
	struct prefetch_mr_work *work =
		container_of(w, struct prefetch_mr_work, work);
	int ret;
	u32 i;

	/*
	 * We rely on the IB core to queue this work only when
	 * num_sge != 0.
	 */
	WARN_ON(!work->num_sge);
	for (i = 0; i < work->num_sge; ++i) {
		struct ib_umem_odp *umem_odp;

		ret = rxe_odp_do_pagefault_and_lock(work->frags[i].mr,
						    work->frags[i].io_virt,
						    work->frags[i].length,
						    work->pf_flags);
		if (ret < 0) {
			rxe_dbg_mr(work->frags[i].mr,
				   "failed to prefetch the mr\n");
			goto deref;
		}

		umem_odp = to_ib_umem_odp(work->frags[i].mr->umem);
		mutex_unlock(&umem_odp->umem_mutex);

deref:
		rxe_put(work->frags[i].mr);
	}

	kvfree(work);
}

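/*
 * Synchronous prefetch: fault in (or snapshot) the pages behind each SGE,
 * returning an error as soon as one entry fails.
 */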
static int rxe_ib_prefetch_sg_list(struct ib_pd *ibpd,
				   enum ib_uverbs_advise_mr_advice advice,
				   u32 pf_flags, struct ib_sge *sg_list,
				   u32 num_sge)
{
	struct rxe_pd *pd = container_of(ibpd, struct rxe_pd, ibpd);
	int ret = 0;
	u32 i;

	for (i = 0; i < num_sge; ++i) {
		struct rxe_mr *mr;
		struct ib_umem_odp *umem_odp;

		mr = lookup_mr(pd, IB_ACCESS_LOCAL_WRITE,
			       sg_list[i].lkey, RXE_LOOKUP_LOCAL);

		if (!mr) {
			rxe_dbg_pd(pd, "mr with lkey %x not found\n",
				   sg_list[i].lkey);
			return -EINVAL;
		}

		if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
		    !mr->umem->writable) {
			rxe_dbg_mr(mr, "missing write permission\n");
			rxe_put(mr);
			return -EPERM;
		}

		ret = rxe_odp_do_pagefault_and_lock(
			mr, sg_list[i].addr, sg_list[i].length, pf_flags);
		if (ret < 0) {
			rxe_dbg_mr(mr, "failed to prefetch the mr\n");
			rxe_put(mr);
			return ret;
		}

		umem_odp = to_ib_umem_odp(mr->umem);
		mutex_unlock(&umem_odp->umem_mutex);

		rxe_put(mr);
	}

	return 0;
}

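/*
 * Handle a prefetch hint. With IB_UVERBS_ADVISE_MR_FLAG_FLUSH the prefetch
 * is performed synchronously; otherwise it is queued as best-effort work
 * on system_unbound_wq.
 */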
static int rxe_ib_advise_mr_prefetch(struct ib_pd *ibpd,
				     enum ib_uverbs_advise_mr_advice advice,
				     u32 flags, struct ib_sge *sg_list,
				     u32 num_sge)
{
	struct rxe_pd *pd = container_of(ibpd, struct rxe_pd, ibpd);
	u32 pf_flags = RXE_PAGEFAULT_DEFAULT;
	struct prefetch_mr_work *work;
	struct rxe_mr *mr;
	u32 i;

	if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH)
		pf_flags |= RXE_PAGEFAULT_RDONLY;

	if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
		pf_flags |= RXE_PAGEFAULT_SNAPSHOT;

	/* Synchronous call */
	if (flags & IB_UVERBS_ADVISE_MR_FLAG_FLUSH)
		return rxe_ib_prefetch_sg_list(ibpd, advice, pf_flags, sg_list,
					       num_sge);

	/* Asynchronous call is "best-effort" and allowed to fail */
	work = kvzalloc(struct_size(work, frags, num_sge), GFP_KERNEL);
	if (!work)
		return -ENOMEM;

	INIT_WORK(&work->work, rxe_ib_prefetch_mr_work);
	work->pf_flags = pf_flags;
	work->num_sge = num_sge;

	for (i = 0; i < num_sge; ++i) {
		/* Takes a reference, which will be released in the queued work */
		mr = lookup_mr(pd, IB_ACCESS_LOCAL_WRITE,
			       sg_list[i].lkey, RXE_LOOKUP_LOCAL);
		if (!mr) {
			mr = ERR_PTR(-EINVAL);
			goto err;
		}

		work->frags[i].io_virt = sg_list[i].addr;
		work->frags[i].length = sg_list[i].length;
		work->frags[i].mr = mr;
	}

	queue_work(system_unbound_wq, &work->work);

	return 0;

 err:
	/* Roll back the MR references taken for the failed request. */
	while (i > 0) {
		i--;
		rxe_put(work->frags[i].mr);
	}

	kvfree(work);

	return PTR_ERR(mr);
}

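/* Entry point for the advise_mr verb; only prefetch advice is supported. */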
int rxe_ib_advise_mr(struct ib_pd *ibpd,
		     enum ib_uverbs_advise_mr_advice advice,
		     u32 flags,
		     struct ib_sge *sg_list,
		     u32 num_sge,
		     struct uverbs_attr_bundle *attrs)
{
	if (advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH &&
	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
		return -EOPNOTSUPP;

	return rxe_ib_advise_mr_prefetch(ibpd, advice, flags,
					 sg_list, num_sge);
}