// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
/*
 * Copyright (c) 2022-2023 Fujitsu Ltd. All rights reserved.
 */

#include <linux/hmm.h>
#include <linux/libnvdimm.h>

#include <rdma/ib_umem_odp.h>

#include "rxe.h"

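/*
 * MMU interval notifier callback, invoked when the CPU mappings backing
 * part of the umem are about to change. Bump the notifier sequence and
 * drop the affected pfn_list entries under umem_mutex.
 */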
static bool rxe_ib_invalidate_range(struct mmu_interval_notifier *mni,
				    const struct mmu_notifier_range *range,
				    unsigned long cur_seq)
{
	struct ib_umem_odp *umem_odp =
		container_of(mni, struct ib_umem_odp, notifier);
	unsigned long start, end;

	if (!mmu_notifier_range_blockable(range))
		return false;

	mutex_lock(&umem_odp->umem_mutex);
	mmu_interval_set_seq(mni, cur_seq);

	start = max_t(u64, ib_umem_start(umem_odp), range->start);
	end = min_t(u64, ib_umem_end(umem_odp), range->end);

	/* update umem_odp->map.pfn_list */
	ib_umem_odp_unmap_dma_pages(umem_odp, start, end);

	mutex_unlock(&umem_odp->umem_mutex);
	return true;
}

const struct mmu_interval_notifier_ops rxe_mn_ops = {
	.invalidate = rxe_ib_invalidate_range,
};

#define RXE_PAGEFAULT_DEFAULT 0
#define RXE_PAGEFAULT_RDONLY BIT(0)
#define RXE_PAGEFAULT_SNAPSHOT BIT(1)

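/*
 * Fault in the pages backing [user_va, user_va + bcnt), or only snapshot
 * the pages already present when RXE_PAGEFAULT_SNAPSHOT is set. Returns
 * the number of mapped pages; on success umem_mutex is left held and must
 * be released by the caller.
 */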
static int rxe_odp_do_pagefault_and_lock(struct rxe_mr *mr, u64 user_va, int bcnt, u32 flags)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	bool fault = !(flags & RXE_PAGEFAULT_SNAPSHOT);
	u64 access_mask = 0;
	int np;

	if (umem_odp->umem.writable && !(flags & RXE_PAGEFAULT_RDONLY))
		access_mask |= HMM_PFN_WRITE;

	/*
	 * ib_umem_odp_map_dma_and_lock() locks umem_mutex on success.
	 * The caller must release the lock later so that the invalidation
	 * handler can do its work again.
	 */
	np = ib_umem_odp_map_dma_and_lock(umem_odp, user_va, bcnt,
					  access_mask, fault);
	return np;
}

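/*
 * Take an initial snapshot of the pages backing the umem at registration
 * time; anything not yet present is faulted in later on first access.
 */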
static int rxe_odp_init_pages(struct rxe_mr *mr)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	int ret;

	ret = rxe_odp_do_pagefault_and_lock(mr, mr->umem->address,
					    mr->umem->length,
					    RXE_PAGEFAULT_SNAPSHOT);

	if (ret >= 0)
		mutex_unlock(&umem_odp->umem_mutex);

	return ret >= 0 ? 0 : ret;
}

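/*
 * Register an ODP-enabled user MR: create the ib_umem_odp with rxe's
 * interval notifier ops, snapshot the current page mappings and mark the
 * MR valid. Implicit ODP (registering the whole address space) is not
 * implemented and is rejected via the capability check.
 */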
int rxe_odp_mr_init_user(struct rxe_dev *rxe, u64 start, u64 length,
			 u64 iova, int access_flags, struct rxe_mr *mr)
{
	struct ib_umem_odp *umem_odp;
	int err;

	if (!IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
		return -EOPNOTSUPP;

	rxe_mr_init(access_flags, mr);

	if (!start && length == U64_MAX) {
		if (iova != 0)
			return -EINVAL;
		if (!(rxe->attr.odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
			return -EINVAL;

		/* Never reached; implicit ODP is not implemented. */
	}

	umem_odp = ib_umem_odp_get(&rxe->ib_dev, start, length, access_flags,
				   &rxe_mn_ops);
	if (IS_ERR(umem_odp)) {
		rxe_dbg_mr(mr, "Unable to create umem_odp err = %d\n",
			   (int)PTR_ERR(umem_odp));
		return PTR_ERR(umem_odp);
	}

	umem_odp->private = mr;

	mr->umem = &umem_odp->umem;
	mr->access = access_flags;
	mr->ibmr.length = length;
	mr->ibmr.iova = iova;
	mr->page_offset = ib_umem_offset(&umem_odp->umem);

	err = rxe_odp_init_pages(mr);
	if (err) {
		ib_umem_odp_release(umem_odp);
		return err;
	}

	mr->state = RXE_MR_STATE_VALID;
	mr->ibmr.type = IB_MR_TYPE_USER;

	return err;
}

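/*
 * Return true if any page covering [iova, iova + length) still needs to
 * be faulted in. Must be called with umem_mutex held so pfn_list is stable.
 */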
static inline bool rxe_check_pagefault(struct ib_umem_odp *umem_odp, u64 iova,
				       int length)
{
	bool need_fault = false;
	u64 addr;
	int idx;

	addr = iova & (~(BIT(umem_odp->page_shift) - 1));

	/* Skim through all pages that are to be accessed. */
	while (addr < iova + length) {
		idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;

		if (!(umem_odp->map.pfn_list[idx] & HMM_PFN_VALID)) {
			need_fault = true;
			break;
		}

		addr += BIT(umem_odp->page_shift);
	}
	return need_fault;
}

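/* Translate an iova within the umem to a pfn_list index / in-page offset. */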
static unsigned long rxe_odp_iova_to_index(struct ib_umem_odp *umem_odp, u64 iova)
{
	return (iova - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
}

static unsigned long rxe_odp_iova_to_page_offset(struct ib_umem_odp *umem_odp, u64 iova)
{
	return iova & (BIT(umem_odp->page_shift) - 1);
}

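/*
 * Ensure that [iova, iova + length) is backed by valid pages, faulting
 * them in if necessary, and return with umem_mutex held so the mapping
 * cannot be invalidated while the caller uses it.
 */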
static int rxe_odp_map_range_and_lock(struct rxe_mr *mr, u64 iova, int length, u32 flags)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	bool need_fault;
	int err;

	if (unlikely(length < 1))
		return -EINVAL;

	mutex_lock(&umem_odp->umem_mutex);

	need_fault = rxe_check_pagefault(umem_odp, iova, length);
	if (need_fault) {
		mutex_unlock(&umem_odp->umem_mutex);

		/* umem_mutex is locked on success. */
		err = rxe_odp_do_pagefault_and_lock(mr, iova, length,
						    flags);
		if (err < 0)
			return err;

		need_fault = rxe_check_pagefault(umem_odp, iova, length);
		if (need_fault) {
			/* do not return with umem_mutex held */
			mutex_unlock(&umem_odp->umem_mutex);
			return -EFAULT;
		}
	}

	return 0;
}

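/*
 * Copy between addr and the MR one page at a time using kmap_local_page().
 * Called with umem_mutex held so the pfn_list stays stable during the copy.
 */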
static int __rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr,
			     int length, enum rxe_mr_copy_dir dir)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	struct page *page;
	int idx, bytes;
	size_t offset;
	u8 *user_va;

	idx = rxe_odp_iova_to_index(umem_odp, iova);
	offset = rxe_odp_iova_to_page_offset(umem_odp, iova);

	while (length > 0) {
		u8 *src, *dest;

		page = hmm_pfn_to_page(umem_odp->map.pfn_list[idx]);
		user_va = kmap_local_page(page);

		/* the first page is copied from/to its in-page offset */
		src = (dir == RXE_TO_MR_OBJ) ? addr : user_va + offset;
		dest = (dir == RXE_TO_MR_OBJ) ? user_va + offset : addr;

		bytes = BIT(umem_odp->page_shift) - offset;
		if (bytes > length)
			bytes = length;

		memcpy(dest, src, bytes);
		kunmap_local(user_va);

		addr += bytes;
		length -= bytes;
		idx++;
		offset = 0;
	}

	return 0;
}

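/*
 * Copy length bytes at iova between an ODP MR and addr, faulting missing
 * pages in first (read-only access when copying out of the MR) and holding
 * umem_mutex across the copy.
 */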
int rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr, int length,
		    enum rxe_mr_copy_dir dir)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	u32 flags = RXE_PAGEFAULT_DEFAULT;
	int err;

	if (length == 0)
		return 0;

	if (unlikely(!mr->umem->is_odp))
		return -EOPNOTSUPP;

	switch (dir) {
	case RXE_TO_MR_OBJ:
		break;

	case RXE_FROM_MR_OBJ:
		flags |= RXE_PAGEFAULT_RDONLY;
		break;

	default:
		return -EINVAL;
	}

	err = rxe_odp_map_range_and_lock(mr, iova, length, flags);
	if (err)
		return err;

	err = __rxe_odp_mr_copy(mr, iova, addr, length, dir);

	mutex_unlock(&umem_odp->umem_mutex);

	return err;
}

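/*
 * Emulate a 64-bit compare&swap or fetch&add on MR memory. Called with
 * umem_mutex held; atomic_ops_lock serializes the read-modify-write
 * against other software-emulated atomics.
 */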
static enum resp_states rxe_odp_do_atomic_op(struct rxe_mr *mr, u64 iova,
					     int opcode, u64 compare,
					     u64 swap_add, u64 *orig_val)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	unsigned int page_offset;
	struct page *page;
	unsigned int idx;
	u64 value;
	u64 *va;
	int err;

	if (unlikely(mr->state != RXE_MR_STATE_VALID)) {
		rxe_dbg_mr(mr, "mr not in valid state\n");
		return RESPST_ERR_RKEY_VIOLATION;
	}

	err = mr_check_range(mr, iova, sizeof(value));
	if (err) {
		rxe_dbg_mr(mr, "iova out of range\n");
		return RESPST_ERR_RKEY_VIOLATION;
	}

	page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
	if (unlikely(page_offset & 0x7)) {
		rxe_dbg_mr(mr, "iova not aligned\n");
		return RESPST_ERR_MISALIGNED_ATOMIC;
	}

	idx = rxe_odp_iova_to_index(umem_odp, iova);
	page = hmm_pfn_to_page(umem_odp->map.pfn_list[idx]);

	va = kmap_local_page(page);

	spin_lock_bh(&atomic_ops_lock);
	value = *orig_val = va[page_offset >> 3];

	if (opcode == IB_OPCODE_RC_COMPARE_SWAP) {
		if (value == compare)
			va[page_offset >> 3] = swap_add;
	} else {
		value += swap_add;
		va[page_offset >> 3] = value;
	}
	spin_unlock_bh(&atomic_ops_lock);

	kunmap_local(va);

	return RESPST_NONE;
}

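/*
 * Fault in the target of the atomic operation if needed, then execute it
 * under umem_mutex.
 */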
enum resp_states rxe_odp_atomic_op(struct rxe_mr *mr, u64 iova, int opcode,
				   u64 compare, u64 swap_add, u64 *orig_val)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	int err;

	err = rxe_odp_map_range_and_lock(mr, iova, sizeof(char),
					 RXE_PAGEFAULT_DEFAULT);
	if (err < 0)
		return RESPST_ERR_RKEY_VIOLATION;

	err = rxe_odp_do_atomic_op(mr, iova, opcode, compare, swap_add,
				   orig_val);
	mutex_unlock(&umem_odp->umem_mutex);

	return err;
}

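/*
 * Write the CPU cache back for [iova, iova + length) page by page so the
 * data reaches persistent memory, as needed for the RDMA FLUSH operation.
 */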
int rxe_odp_flush_pmem_iova(struct rxe_mr *mr, u64 iova,
			    unsigned int length)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	unsigned int page_offset;
	unsigned long index;
	struct page *page;
	unsigned int bytes;
	int err;
	u8 *va;

	err = rxe_odp_map_range_and_lock(mr, iova, length,
					 RXE_PAGEFAULT_DEFAULT);
	if (err)
		return err;

	while (length > 0) {
		index = rxe_odp_iova_to_index(umem_odp, iova);
		page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);

		page = hmm_pfn_to_page(umem_odp->map.pfn_list[index]);

		bytes = min_t(unsigned int, length,
			      mr_page_size(mr) - page_offset);

		va = kmap_local_page(page);
		arch_wb_cache_pmem(va + page_offset, bytes);
		kunmap_local(va);

		length -= bytes;
		iova += bytes;
	}

	mutex_unlock(&umem_odp->umem_mutex);

	return 0;
}

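/*
 * 8-byte atomic write to MR memory for the RDMA ATOMIC WRITE operation;
 * the target must be naturally aligned.
 */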
enum resp_states rxe_odp_do_atomic_write(struct rxe_mr *mr, u64 iova, u64 value)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	unsigned int page_offset;
	unsigned long index;
	struct page *page;
	int err;
	u64 *va;

	/* See IBA oA19-28 */
	err = mr_check_range(mr, iova, sizeof(value));
	if (unlikely(err)) {
		rxe_dbg_mr(mr, "iova out of range\n");
		return RESPST_ERR_RKEY_VIOLATION;
	}

	err = rxe_odp_map_range_and_lock(mr, iova, sizeof(value),
					 RXE_PAGEFAULT_DEFAULT);
	if (err)
		return RESPST_ERR_RKEY_VIOLATION;

	page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
	/* See IBA A19.4.2 */
	if (unlikely(page_offset & 0x7)) {
		mutex_unlock(&umem_odp->umem_mutex);
		rxe_dbg_mr(mr, "misaligned address\n");
		return RESPST_ERR_MISALIGNED_ATOMIC;
	}

	index = rxe_odp_iova_to_index(umem_odp, iova);
	page = hmm_pfn_to_page(umem_odp->map.pfn_list[index]);

	va = kmap_local_page(page);
	/* Do atomic write after all prior operations have completed */
	smp_store_release(&va[page_offset >> 3], value);
	kunmap_local(va);

	mutex_unlock(&umem_odp->umem_mutex);

	return RESPST_NONE;
}

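/*
 * Deferred prefetch request: one work item carrying, per SGE, the MR and
 * the virtual address range to prefetch asynchronously.
 */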
struct prefetch_mr_work {
	struct work_struct work;
	u32 pf_flags;
	u32 num_sge;
	struct {
		u64 io_virt;
		struct rxe_mr *mr;
		size_t length;
	} frags[];
};

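/*
 * Work handler for asynchronous prefetch: fault in each requested range
 * and drop the MR references taken when the work was queued.
 */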
static void rxe_ib_prefetch_mr_work(struct work_struct *w)
{
	struct prefetch_mr_work *work =
		container_of(w, struct prefetch_mr_work, work);
	int ret;
	u32 i;

	/*
	 * We rely on IB/core to queue this work only when num_sge != 0.
	 */
	WARN_ON(!work->num_sge);
	for (i = 0; i < work->num_sge; ++i) {
		struct ib_umem_odp *umem_odp;

		ret = rxe_odp_do_pagefault_and_lock(work->frags[i].mr,
						    work->frags[i].io_virt,
						    work->frags[i].length,
						    work->pf_flags);
		if (ret < 0) {
			rxe_dbg_mr(work->frags[i].mr,
				   "failed to prefetch the mr\n");
			goto deref;
		}

		umem_odp = to_ib_umem_odp(work->frags[i].mr->umem);
		mutex_unlock(&umem_odp->umem_mutex);

deref:
		rxe_put(work->frags[i].mr);
	}

	kvfree(work);
}

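/*
 * Synchronous prefetch: look up each MR by lkey, check write permission
 * where required, and fault the requested ranges in before returning.
 */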
static int rxe_ib_prefetch_sg_list(struct ib_pd *ibpd,
				   enum ib_uverbs_advise_mr_advice advice,
				   u32 pf_flags, struct ib_sge *sg_list,
				   u32 num_sge)
{
	struct rxe_pd *pd = container_of(ibpd, struct rxe_pd, ibpd);
	int ret = 0;
	u32 i;

	for (i = 0; i < num_sge; ++i) {
		struct rxe_mr *mr;
		struct ib_umem_odp *umem_odp;

		mr = lookup_mr(pd, IB_ACCESS_LOCAL_WRITE,
			       sg_list[i].lkey, RXE_LOOKUP_LOCAL);

		if (!mr) {
			rxe_dbg_pd(pd, "mr with lkey %x not found\n",
				   sg_list[i].lkey);
			return -EINVAL;
		}

		if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
		    !mr->umem->writable) {
			rxe_dbg_mr(mr, "missing write permission\n");
			rxe_put(mr);
			return -EPERM;
		}

		ret = rxe_odp_do_pagefault_and_lock(
			mr, sg_list[i].addr, sg_list[i].length, pf_flags);
		if (ret < 0) {
			rxe_dbg_mr(mr, "failed to prefetch the mr\n");
			rxe_put(mr);
			return ret;
		}

		umem_odp = to_ib_umem_odp(mr->umem);
		mutex_unlock(&umem_odp->umem_mutex);

		rxe_put(mr);
	}

	return 0;
}

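/*
 * Handle the prefetch advices of ibv_advise_mr(): prefetch synchronously
 * when IB_UVERBS_ADVISE_MR_FLAG_FLUSH is set, otherwise queue a
 * best-effort work item that prefetches in the background.
 */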
static int rxe_ib_advise_mr_prefetch(struct ib_pd *ibpd,
				     enum ib_uverbs_advise_mr_advice advice,
				     u32 flags, struct ib_sge *sg_list,
				     u32 num_sge)
{
	struct rxe_pd *pd = container_of(ibpd, struct rxe_pd, ibpd);
	u32 pf_flags = RXE_PAGEFAULT_DEFAULT;
	struct prefetch_mr_work *work;
	struct rxe_mr *mr;
	u32 i;

	if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH)
		pf_flags |= RXE_PAGEFAULT_RDONLY;

	if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
		pf_flags |= RXE_PAGEFAULT_SNAPSHOT;

	/* Synchronous call */
	if (flags & IB_UVERBS_ADVISE_MR_FLAG_FLUSH)
		return rxe_ib_prefetch_sg_list(ibpd, advice, pf_flags, sg_list,
					       num_sge);

	/* Asynchronous call is "best-effort" and allowed to fail */
	work = kvzalloc(struct_size(work, frags, num_sge), GFP_KERNEL);
	if (!work)
		return -ENOMEM;

	INIT_WORK(&work->work, rxe_ib_prefetch_mr_work);
	work->pf_flags = pf_flags;
	work->num_sge = num_sge;

	for (i = 0; i < num_sge; ++i) {
		/* Takes a reference, which will be released in the queued work */
		mr = lookup_mr(pd, IB_ACCESS_LOCAL_WRITE,
			       sg_list[i].lkey, RXE_LOOKUP_LOCAL);
		if (!mr) {
			mr = ERR_PTR(-EINVAL);
			goto err;
		}

		work->frags[i].io_virt = sg_list[i].addr;
		work->frags[i].length = sg_list[i].length;
		work->frags[i].mr = mr;
	}

	queue_work(system_unbound_wq, &work->work);

	return 0;

err:
	/* rollback reference counts for the invalid request */
	while (i > 0) {
		i--;
		rxe_put(work->frags[i].mr);
	}

	kvfree(work);

	return PTR_ERR(mr);
}

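/* ib_device advise_mr hook; only the prefetch advices are supported. */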
int rxe_ib_advise_mr(struct ib_pd *ibpd,
		     enum ib_uverbs_advise_mr_advice advice,
		     u32 flags,
		     struct ib_sge *sg_list,
		     u32 num_sge,
		     struct uverbs_attr_bundle *attrs)
{
	if (advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH &&
	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
		return -EOPNOTSUPP;

	return rxe_ib_advise_mr_prefetch(ibpd, advice, flags,
					 sg_list, num_sge);
}