xref: /linux/lib/test_hmm.c (revision 7203ca412fc8e8a0588e9adc0f777d3163f8dff3)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * This is a module to test the HMM (Heterogeneous Memory Management)
4  * mirror and zone device private memory migration APIs of the kernel.
5  * Userspace programs can register with the driver to mirror their own address
6  * space and can use the device to read/write any valid virtual address.
7  */
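/*
 * A minimal sketch of how a userspace program might drive this driver
 * (the real tests live in tools/testing/selftests/mm/hmm-tests.c). The
 * device node name below is an assumption; it depends on how udev names
 * the "hmm_dmirror%u" character devices created by this module:
 *
 *	struct hmm_dmirror_cmd cmd = { 0 };
 *	int fd = open("/dev/hmm_dmirror0", O_RDWR);
 *
 *	cmd.addr = (uintptr_t)buf;	page-aligned start of the mirrored range
 *	cmd.ptr = (uintptr_t)data;	buffer receiving the mirrored bytes
 *	cmd.npages = 1;
 *	ioctl(fd, HMM_DMIRROR_READ, &cmd);
 */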
8 #include <linux/init.h>
9 #include <linux/fs.h>
10 #include <linux/mm.h>
11 #include <linux/module.h>
12 #include <linux/kernel.h>
13 #include <linux/cdev.h>
14 #include <linux/device.h>
15 #include <linux/memremap.h>
16 #include <linux/mutex.h>
17 #include <linux/rwsem.h>
18 #include <linux/sched.h>
19 #include <linux/slab.h>
20 #include <linux/highmem.h>
21 #include <linux/delay.h>
22 #include <linux/pagemap.h>
23 #include <linux/hmm.h>
24 #include <linux/vmalloc.h>
25 #include <linux/swap.h>
26 #include <linux/swapops.h>
27 #include <linux/sched/mm.h>
28 #include <linux/platform_device.h>
29 #include <linux/rmap.h>
30 #include <linux/mmu_notifier.h>
31 #include <linux/migrate.h>
32 
33 #include "test_hmm_uapi.h"
34 
35 #define DMIRROR_NDEVICES		4
36 #define DMIRROR_RANGE_FAULT_TIMEOUT	1000
37 #define DEVMEM_CHUNK_SIZE		(256 * 1024 * 1024U)
38 #define DEVMEM_CHUNKS_RESERVE		16
39 
40 /*
41  * For device_private pages, dpage is just a dummy struct page
42  * representing a piece of device memory. dmirror_devmem_alloc_page
43  * allocates a real system memory page as backing storage to fake a
44  * real device. zone_device_data points to that backing page. But
45  * for device_coherent memory, the struct page represents real
46  * physical CPU-accessible memory that we can use directly.
47  */
48 #define BACKING_PAGE(page) (is_device_private_page((page)) ? \
49 			   (page)->zone_device_data : (page))
50 
51 static unsigned long spm_addr_dev0;
52 module_param(spm_addr_dev0, long, 0644);
53 MODULE_PARM_DESC(spm_addr_dev0,
54 		"Specify start address for SPM (special purpose memory) used for device 0. Setting this selects the coherent device type; make sure spm_addr_dev1 is set too. The minimum SPM size should be DEVMEM_CHUNK_SIZE.");
55 
56 static unsigned long spm_addr_dev1;
57 module_param(spm_addr_dev1, long, 0644);
58 MODULE_PARM_DESC(spm_addr_dev1,
59 		"Specify start address for SPM (special purpose memory) used for device 1. Setting this selects the coherent device type; make sure spm_addr_dev0 is set too. The minimum SPM size should be DEVMEM_CHUNK_SIZE.");
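/*
 * Example of loading the module with special purpose memory so that the
 * two coherent device instances are created (the addresses below are
 * placeholders, not real reserved ranges):
 *
 *	modprobe test_hmm spm_addr_dev0=0x100000000 spm_addr_dev1=0x110000000
 */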
60 
61 static const struct dev_pagemap_ops dmirror_devmem_ops;
62 static const struct mmu_interval_notifier_ops dmirror_min_ops;
63 static dev_t dmirror_dev;
64 
65 struct dmirror_device;
66 
67 struct dmirror_bounce {
68 	void			*ptr;
69 	unsigned long		size;
70 	unsigned long		addr;
71 	unsigned long		cpages;
72 };
73 
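/*
 * XArray pointer tags for entries in dmirror->pt: DPT_XA_TAG_WRITE marks a
 * page that is mapped writable in the mirror, DPT_XA_TAG_ATOMIC marks a
 * page that has been made device exclusive for atomic access.
 */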
74 #define DPT_XA_TAG_ATOMIC 1UL
75 #define DPT_XA_TAG_WRITE 3UL
76 
77 /*
78  * Data structure to track address ranges and register for mmu interval
79  * notifier updates.
80  */
81 struct dmirror_interval {
82 	struct mmu_interval_notifier	notifier;
83 	struct dmirror			*dmirror;
84 };
85 
86 /*
87  * Data attached to the open device file.
88  * Note that it might be shared after a fork().
89  */
90 struct dmirror {
91 	struct dmirror_device		*mdevice;
92 	struct xarray			pt;
93 	struct mmu_interval_notifier	notifier;
94 	struct mutex			mutex;
95 	__u64			flags;
96 };
97 
98 /*
99  * ZONE_DEVICE pages for migration and simulating device memory.
100  */
101 struct dmirror_chunk {
102 	struct dev_pagemap	pagemap;
103 	struct dmirror_device	*mdevice;
104 	bool remove;
105 };
106 
107 /*
108  * Per device data.
109  */
110 struct dmirror_device {
111 	struct cdev		cdevice;
112 	unsigned int            zone_device_type;
113 	struct device		device;
114 
115 	unsigned int		devmem_capacity;
116 	unsigned int		devmem_count;
117 	struct dmirror_chunk	**devmem_chunks;
118 	struct mutex		devmem_lock;	/* protects the above */
119 
120 	unsigned long		calloc;
121 	unsigned long		cfree;
122 	struct page		*free_pages;
123 	struct folio		*free_folios;
124 	spinlock_t		lock;		/* protects the above */
125 };
126 
127 static struct dmirror_device dmirror_devices[DMIRROR_NDEVICES];
128 
129 static int dmirror_bounce_init(struct dmirror_bounce *bounce,
130 			       unsigned long addr,
131 			       unsigned long size)
132 {
133 	bounce->addr = addr;
134 	bounce->size = size;
135 	bounce->cpages = 0;
136 	bounce->ptr = vmalloc(size);
137 	if (!bounce->ptr)
138 		return -ENOMEM;
139 	return 0;
140 }
141 
142 static bool dmirror_is_private_zone(struct dmirror_device *mdevice)
143 {
144 	return (mdevice->zone_device_type ==
145 		HMM_DMIRROR_MEMORY_DEVICE_PRIVATE);
146 }
147 
148 static enum migrate_vma_direction
149 dmirror_select_device(struct dmirror *dmirror)
150 {
151 	return (dmirror->mdevice->zone_device_type ==
152 		HMM_DMIRROR_MEMORY_DEVICE_PRIVATE) ?
153 		MIGRATE_VMA_SELECT_DEVICE_PRIVATE :
154 		MIGRATE_VMA_SELECT_DEVICE_COHERENT;
155 }
156 
157 static void dmirror_bounce_fini(struct dmirror_bounce *bounce)
158 {
159 	vfree(bounce->ptr);
160 }
161 
162 static int dmirror_fops_open(struct inode *inode, struct file *filp)
163 {
164 	struct cdev *cdev = inode->i_cdev;
165 	struct dmirror *dmirror;
166 	int ret;
167 
168 	/* Mirror this process address space */
169 	dmirror = kzalloc(sizeof(*dmirror), GFP_KERNEL);
170 	if (dmirror == NULL)
171 		return -ENOMEM;
172 
173 	dmirror->mdevice = container_of(cdev, struct dmirror_device, cdevice);
174 	mutex_init(&dmirror->mutex);
175 	xa_init(&dmirror->pt);
176 
177 	ret = mmu_interval_notifier_insert(&dmirror->notifier, current->mm,
178 				0, ULONG_MAX & PAGE_MASK, &dmirror_min_ops);
179 	if (ret) {
180 		kfree(dmirror);
181 		return ret;
182 	}
183 
184 	filp->private_data = dmirror;
185 	return 0;
186 }
187 
188 static int dmirror_fops_release(struct inode *inode, struct file *filp)
189 {
190 	struct dmirror *dmirror = filp->private_data;
191 
192 	mmu_interval_notifier_remove(&dmirror->notifier);
193 	xa_destroy(&dmirror->pt);
194 	kfree(dmirror);
195 	return 0;
196 }
197 
198 static struct dmirror_chunk *dmirror_page_to_chunk(struct page *page)
199 {
200 	return container_of(page_pgmap(page), struct dmirror_chunk,
201 			    pagemap);
202 }
203 
204 static struct dmirror_device *dmirror_page_to_device(struct page *page)
205 
206 {
207 	return dmirror_page_to_chunk(page)->mdevice;
208 }
209 
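/*
 * Record the pages returned by hmm_range_fault() in the mirror's page
 * table (the dmirror->pt XArray), tagging entries that are writable.
 */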
210 static int dmirror_do_fault(struct dmirror *dmirror, struct hmm_range *range)
211 {
212 	unsigned long *pfns = range->hmm_pfns;
213 	unsigned long pfn;
214 
215 	for (pfn = (range->start >> PAGE_SHIFT);
216 	     pfn < (range->end >> PAGE_SHIFT);
217 	     pfn++, pfns++) {
218 		struct page *page;
219 		void *entry;
220 
221 		/*
222 		 * Since we asked for hmm_range_fault() to populate pages,
223 		 * it shouldn't return an error entry on success.
224 		 */
225 		WARN_ON(*pfns & HMM_PFN_ERROR);
226 		WARN_ON(!(*pfns & HMM_PFN_VALID));
227 
228 		page = hmm_pfn_to_page(*pfns);
229 		WARN_ON(!page);
230 
231 		entry = page;
232 		if (*pfns & HMM_PFN_WRITE)
233 			entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
234 		else if (WARN_ON(range->default_flags & HMM_PFN_WRITE))
235 			return -EFAULT;
236 		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
237 		if (xa_is_err(entry))
238 			return xa_err(entry);
239 	}
240 
241 	return 0;
242 }
243 
244 static void dmirror_do_update(struct dmirror *dmirror, unsigned long start,
245 			      unsigned long end)
246 {
247 	unsigned long pfn;
248 	void *entry;
249 
250 	/*
251 	 * The XArray doesn't hold references to pages since it relies on
252 	 * the mmu notifier to clear page pointers when they become stale.
253 	 * Therefore, it is OK to just clear the entry.
254 	 */
255 	xa_for_each_range(&dmirror->pt, pfn, entry, start >> PAGE_SHIFT,
256 			  end >> PAGE_SHIFT)
257 		xa_erase(&dmirror->pt, pfn);
258 }
259 
260 static bool dmirror_interval_invalidate(struct mmu_interval_notifier *mni,
261 				const struct mmu_notifier_range *range,
262 				unsigned long cur_seq)
263 {
264 	struct dmirror *dmirror = container_of(mni, struct dmirror, notifier);
265 
266 	/*
267 	 * Ignore invalidation callbacks for device private pages since
268 	 * the invalidation is handled as part of the migration process.
269 	 */
270 	if (range->event == MMU_NOTIFY_MIGRATE &&
271 	    range->owner == dmirror->mdevice)
272 		return true;
273 
274 	if (mmu_notifier_range_blockable(range))
275 		mutex_lock(&dmirror->mutex);
276 	else if (!mutex_trylock(&dmirror->mutex))
277 		return false;
278 
279 	mmu_interval_set_seq(mni, cur_seq);
280 	dmirror_do_update(dmirror, range->start, range->end);
281 
282 	mutex_unlock(&dmirror->mutex);
283 	return true;
284 }
285 
286 static const struct mmu_interval_notifier_ops dmirror_min_ops = {
287 	.invalidate = dmirror_interval_invalidate,
288 };
289 
290 static int dmirror_range_fault(struct dmirror *dmirror,
291 				struct hmm_range *range)
292 {
293 	struct mm_struct *mm = dmirror->notifier.mm;
294 	unsigned long timeout =
295 		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
296 	int ret;
297 
298 	while (true) {
299 		if (time_after(jiffies, timeout)) {
300 			ret = -EBUSY;
301 			goto out;
302 		}
303 
304 		range->notifier_seq = mmu_interval_read_begin(range->notifier);
305 		mmap_read_lock(mm);
306 		ret = hmm_range_fault(range);
307 		mmap_read_unlock(mm);
308 		if (ret) {
309 			if (ret == -EBUSY)
310 				continue;
311 			goto out;
312 		}
313 
314 		mutex_lock(&dmirror->mutex);
315 		if (mmu_interval_read_retry(range->notifier,
316 					    range->notifier_seq)) {
317 			mutex_unlock(&dmirror->mutex);
318 			continue;
319 		}
320 		break;
321 	}
322 
323 	ret = dmirror_do_fault(dmirror, range);
324 
325 	mutex_unlock(&dmirror->mutex);
326 out:
327 	return ret;
328 }
329 
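/*
 * Fault pages into the mirrored process and record them in the mirror,
 * working in batches of up to 32 pages (the size of the on-stack pfns[]
 * array) per dmirror_range_fault() call.
 */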
330 static int dmirror_fault(struct dmirror *dmirror, unsigned long start,
331 			 unsigned long end, bool write)
332 {
333 	struct mm_struct *mm = dmirror->notifier.mm;
334 	unsigned long addr;
335 	unsigned long pfns[32];
336 	struct hmm_range range = {
337 		.notifier = &dmirror->notifier,
338 		.hmm_pfns = pfns,
339 		.pfn_flags_mask = 0,
340 		.default_flags =
341 			HMM_PFN_REQ_FAULT | (write ? HMM_PFN_REQ_WRITE : 0),
342 		.dev_private_owner = dmirror->mdevice,
343 	};
344 	int ret = 0;
345 
346 	/* Since the mm is for the mirrored process, get a reference first. */
347 	if (!mmget_not_zero(mm))
348 		return 0;
349 
350 	for (addr = start; addr < end; addr = range.end) {
351 		range.start = addr;
352 		range.end = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
353 
354 		ret = dmirror_range_fault(dmirror, &range);
355 		if (ret)
356 			break;
357 	}
358 
359 	mmput(mm);
360 	return ret;
361 }
362 
363 static int dmirror_do_read(struct dmirror *dmirror, unsigned long start,
364 			   unsigned long end, struct dmirror_bounce *bounce)
365 {
366 	unsigned long pfn;
367 	void *ptr;
368 
369 	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
370 
371 	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
372 		void *entry;
373 		struct page *page;
374 
375 		entry = xa_load(&dmirror->pt, pfn);
376 		page = xa_untag_pointer(entry);
377 		if (!page)
378 			return -ENOENT;
379 
380 		memcpy_from_page(ptr, page, 0, PAGE_SIZE);
381 
382 		ptr += PAGE_SIZE;
383 		bounce->cpages++;
384 	}
385 
386 	return 0;
387 }
388 
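/*
 * Read the mirrored contents of [start, end) into the user buffer,
 * read-faulting any pages not yet present in the mirror and retrying
 * until the whole range has been copied.
 */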
389 static int dmirror_read(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
390 {
391 	struct dmirror_bounce bounce;
392 	unsigned long start, end;
393 	unsigned long size = cmd->npages << PAGE_SHIFT;
394 	int ret;
395 
396 	start = cmd->addr;
397 	end = start + size;
398 	if (end < start)
399 		return -EINVAL;
400 
401 	ret = dmirror_bounce_init(&bounce, start, size);
402 	if (ret)
403 		return ret;
404 
405 	while (1) {
406 		mutex_lock(&dmirror->mutex);
407 		ret = dmirror_do_read(dmirror, start, end, &bounce);
408 		mutex_unlock(&dmirror->mutex);
409 		if (ret != -ENOENT)
410 			break;
411 
412 		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
413 		ret = dmirror_fault(dmirror, start, end, false);
414 		if (ret)
415 			break;
416 		cmd->faults++;
417 	}
418 
419 	if (ret == 0) {
420 		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
421 				 bounce.size))
422 			ret = -EFAULT;
423 	}
424 	cmd->cpages = bounce.cpages;
425 	dmirror_bounce_fini(&bounce);
426 	return ret;
427 }
428 
429 static int dmirror_do_write(struct dmirror *dmirror, unsigned long start,
430 			    unsigned long end, struct dmirror_bounce *bounce)
431 {
432 	unsigned long pfn;
433 	void *ptr;
434 
435 	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
436 
437 	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
438 		void *entry;
439 		struct page *page;
440 
441 		entry = xa_load(&dmirror->pt, pfn);
442 		page = xa_untag_pointer(entry);
443 		if (!page || xa_pointer_tag(entry) != DPT_XA_TAG_WRITE)
444 			return -ENOENT;
445 
446 		memcpy_to_page(page, 0, ptr, PAGE_SIZE);
447 
448 		ptr += PAGE_SIZE;
449 		bounce->cpages++;
450 	}
451 
452 	return 0;
453 }
454 
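/*
 * Write the user buffer into the mirrored pages, write-faulting pages
 * into the mirror whenever a writable entry is missing.
 */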
455 static int dmirror_write(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
456 {
457 	struct dmirror_bounce bounce;
458 	unsigned long start, end;
459 	unsigned long size = cmd->npages << PAGE_SHIFT;
460 	int ret;
461 
462 	start = cmd->addr;
463 	end = start + size;
464 	if (end < start)
465 		return -EINVAL;
466 
467 	ret = dmirror_bounce_init(&bounce, start, size);
468 	if (ret)
469 		return ret;
470 	if (copy_from_user(bounce.ptr, u64_to_user_ptr(cmd->ptr),
471 			   bounce.size)) {
472 		ret = -EFAULT;
473 		goto fini;
474 	}
475 
476 	while (1) {
477 		mutex_lock(&dmirror->mutex);
478 		ret = dmirror_do_write(dmirror, start, end, &bounce);
479 		mutex_unlock(&dmirror->mutex);
480 		if (ret != -ENOENT)
481 			break;
482 
483 		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
484 		ret = dmirror_fault(dmirror, start, end, true);
485 		if (ret)
486 			break;
487 		cmd->faults++;
488 	}
489 
490 fini:
491 	cmd->cpages = bounce.cpages;
492 	dmirror_bounce_fini(&bounce);
493 	return ret;
494 }
495 
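/*
 * Add a DEVMEM_CHUNK_SIZE chunk of ZONE_DEVICE memory to the device.
 * Device private memory claims a free physical address range via
 * request_free_mem_region(), while device coherent memory uses the
 * spm_addr_dev0/spm_addr_dev1 module parameters. The chunk's pages are
 * threaded onto the device's free_pages/free_folios lists.
 */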
496 static int dmirror_allocate_chunk(struct dmirror_device *mdevice,
497 				  struct page **ppage, bool is_large)
498 {
499 	struct dmirror_chunk *devmem;
500 	struct resource *res = NULL;
501 	unsigned long pfn;
502 	unsigned long pfn_first;
503 	unsigned long pfn_last;
504 	void *ptr;
505 	int ret = -ENOMEM;
506 
507 	devmem = kzalloc(sizeof(*devmem), GFP_KERNEL);
508 	if (!devmem)
509 		return ret;
510 
511 	switch (mdevice->zone_device_type) {
512 	case HMM_DMIRROR_MEMORY_DEVICE_PRIVATE:
513 		res = request_free_mem_region(&iomem_resource, DEVMEM_CHUNK_SIZE,
514 					      "hmm_dmirror");
515 		if (IS_ERR_OR_NULL(res))
516 			goto err_devmem;
517 		devmem->pagemap.range.start = res->start;
518 		devmem->pagemap.range.end = res->end;
519 		devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
520 		break;
521 	case HMM_DMIRROR_MEMORY_DEVICE_COHERENT:
522 		devmem->pagemap.range.start = (MINOR(mdevice->cdevice.dev) - 2) ?
523 							spm_addr_dev0 :
524 							spm_addr_dev1;
525 		devmem->pagemap.range.end = devmem->pagemap.range.start +
526 					    DEVMEM_CHUNK_SIZE - 1;
527 		devmem->pagemap.type = MEMORY_DEVICE_COHERENT;
528 		break;
529 	default:
530 		ret = -EINVAL;
531 		goto err_devmem;
532 	}
533 
534 	devmem->pagemap.nr_range = 1;
535 	devmem->pagemap.ops = &dmirror_devmem_ops;
536 	devmem->pagemap.owner = mdevice;
537 
538 	mutex_lock(&mdevice->devmem_lock);
539 
540 	if (mdevice->devmem_count == mdevice->devmem_capacity) {
541 		struct dmirror_chunk **new_chunks;
542 		unsigned int new_capacity;
543 
544 		new_capacity = mdevice->devmem_capacity +
545 				DEVMEM_CHUNKS_RESERVE;
546 		new_chunks = krealloc(mdevice->devmem_chunks,
547 				sizeof(new_chunks[0]) * new_capacity,
548 				GFP_KERNEL);
549 		if (!new_chunks)
550 			goto err_release;
551 		mdevice->devmem_capacity = new_capacity;
552 		mdevice->devmem_chunks = new_chunks;
553 	}
554 	ptr = memremap_pages(&devmem->pagemap, numa_node_id());
555 	if (IS_ERR_OR_NULL(ptr)) {
556 		if (ptr)
557 			ret = PTR_ERR(ptr);
558 		else
559 			ret = -EFAULT;
560 		goto err_release;
561 	}
562 
563 	devmem->mdevice = mdevice;
564 	pfn_first = devmem->pagemap.range.start >> PAGE_SHIFT;
565 	pfn_last = pfn_first + (range_len(&devmem->pagemap.range) >> PAGE_SHIFT);
566 	mdevice->devmem_chunks[mdevice->devmem_count++] = devmem;
567 
568 	mutex_unlock(&mdevice->devmem_lock);
569 
570 	pr_info("added new %u MB chunk (total %u chunks, %u MB) PFNs [0x%lx 0x%lx)\n",
571 		DEVMEM_CHUNK_SIZE / (1024 * 1024),
572 		mdevice->devmem_count,
573 		mdevice->devmem_count * (DEVMEM_CHUNK_SIZE / (1024 * 1024)),
574 		pfn_first, pfn_last);
575 
576 	spin_lock(&mdevice->lock);
577 	for (pfn = pfn_first; pfn < pfn_last; ) {
578 		struct page *page = pfn_to_page(pfn);
579 
580 		if (is_large && IS_ALIGNED(pfn, HPAGE_PMD_NR)
581 			&& (pfn + HPAGE_PMD_NR <= pfn_last)) {
582 			page->zone_device_data = mdevice->free_folios;
583 			mdevice->free_folios = page_folio(page);
584 			pfn += HPAGE_PMD_NR;
585 			continue;
586 		}
587 
588 		page->zone_device_data = mdevice->free_pages;
589 		mdevice->free_pages = page;
590 		pfn++;
591 	}
592 
593 	ret = 0;
594 	if (ppage) {
595 		if (is_large) {
596 			if (!mdevice->free_folios) {
597 				ret = -ENOMEM;
598 				goto err_unlock;
599 			}
600 			*ppage = folio_page(mdevice->free_folios, 0);
601 			mdevice->free_folios = (*ppage)->zone_device_data;
602 			mdevice->calloc += HPAGE_PMD_NR;
603 		} else if (mdevice->free_pages) {
604 			*ppage = mdevice->free_pages;
605 			mdevice->free_pages = (*ppage)->zone_device_data;
606 			mdevice->calloc++;
607 		} else {
608 			ret = -ENOMEM;
609 			goto err_unlock;
610 		}
611 	}
612 err_unlock:
613 	spin_unlock(&mdevice->lock);
614 
615 	return ret;
616 
617 err_release:
618 	mutex_unlock(&mdevice->devmem_lock);
619 	if (res && devmem->pagemap.type == MEMORY_DEVICE_PRIVATE)
620 		release_mem_region(devmem->pagemap.range.start,
621 				   range_len(&devmem->pagemap.range));
622 err_devmem:
623 	kfree(devmem);
624 
625 	return ret;
626 }
627 
628 static struct page *dmirror_devmem_alloc_page(struct dmirror *dmirror,
629 					      bool is_large)
630 {
631 	struct page *dpage = NULL;
632 	struct page *rpage = NULL;
633 	unsigned int order = is_large ? HPAGE_PMD_ORDER : 0;
634 	struct dmirror_device *mdevice = dmirror->mdevice;
635 
636 	/*
637 	 * For ZONE_DEVICE private type, this is a fake device so we allocate
638 	 * real system memory to store our device memory.
639 	 * For ZONE_DEVICE coherent type we use the actual dpage to store the
640 	 * data and ignore rpage.
641 	 */
642 	if (dmirror_is_private_zone(mdevice)) {
643 		rpage = folio_page(folio_alloc(GFP_HIGHUSER, order), 0);
644 		if (!rpage)
645 			return NULL;
646 	}
647 	spin_lock(&mdevice->lock);
648 
649 	if (is_large && mdevice->free_folios) {
650 		dpage = folio_page(mdevice->free_folios, 0);
651 		mdevice->free_folios = dpage->zone_device_data;
652 		mdevice->calloc += 1 << order;
653 		spin_unlock(&mdevice->lock);
654 	} else if (!is_large && mdevice->free_pages) {
655 		dpage = mdevice->free_pages;
656 		mdevice->free_pages = dpage->zone_device_data;
657 		mdevice->calloc++;
658 		spin_unlock(&mdevice->lock);
659 	} else {
660 		spin_unlock(&mdevice->lock);
661 		if (dmirror_allocate_chunk(mdevice, &dpage, is_large))
662 			goto error;
663 	}
664 
665 	zone_device_folio_init(page_folio(dpage), order);
666 	dpage->zone_device_data = rpage;
667 	return dpage;
668 
669 error:
670 	if (rpage)
671 		__free_pages(rpage, order);
672 	return NULL;
673 }
674 
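/*
 * Allocate device pages for the source pages selected by migrate_vma_setup()
 * and copy the data into the (simulated) device memory. If a large device
 * page cannot be allocated, fall back to migrating with individual small
 * pages.
 */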
675 static void dmirror_migrate_alloc_and_copy(struct migrate_vma *args,
676 					   struct dmirror *dmirror)
677 {
678 	const unsigned long *src = args->src;
679 	unsigned long *dst = args->dst;
680 	unsigned long addr;
681 
682 	for (addr = args->start; addr < args->end; ) {
683 		struct page *spage;
684 		struct page *dpage;
685 		struct page *rpage;
686 		bool is_large = *src & MIGRATE_PFN_COMPOUND;
687 		int write = (*src & MIGRATE_PFN_WRITE) ? MIGRATE_PFN_WRITE : 0;
688 		unsigned long nr = 1;
689 
690 		if (!(*src & MIGRATE_PFN_MIGRATE))
691 			goto next;
692 
693 		/*
694 		 * Note that spage might be NULL which is OK since it is an
695 		 * unallocated pte_none() or read-only zero page.
696 		 */
697 		spage = migrate_pfn_to_page(*src);
698 		if (WARN(spage && is_zone_device_page(spage),
699 		     "page already in device spage pfn: 0x%lx\n",
700 		     page_to_pfn(spage)))
701 			goto next;
702 
703 		if (dmirror->flags & HMM_DMIRROR_FLAG_FAIL_ALLOC) {
704 			dmirror->flags &= ~HMM_DMIRROR_FLAG_FAIL_ALLOC;
705 			dpage = NULL;
706 		} else
707 			dpage = dmirror_devmem_alloc_page(dmirror, is_large);
708 
709 		if (!dpage) {
710 			struct folio *folio;
711 			unsigned long i;
712 			unsigned long spfn = *src >> MIGRATE_PFN_SHIFT;
713 			struct page *src_page;
714 
715 			if (!is_large)
716 				goto next;
717 
718 			if (!spage && is_large) {
719 				nr = HPAGE_PMD_NR;
720 			} else {
721 				folio = page_folio(spage);
722 				nr = folio_nr_pages(folio);
723 			}
724 
725 			for (i = 0; i < nr && addr < args->end; i++) {
726 				dpage = dmirror_devmem_alloc_page(dmirror, false);
727 				rpage = BACKING_PAGE(dpage);
728 				rpage->zone_device_data = dmirror;
729 
730 				*dst = migrate_pfn(page_to_pfn(dpage)) | write;
731 				src_page = pfn_to_page(spfn + i);
732 
733 				if (spage)
734 					copy_highpage(rpage, src_page);
735 				else
736 					clear_highpage(rpage);
737 				src++;
738 				dst++;
739 				addr += PAGE_SIZE;
740 			}
741 			continue;
742 		}
743 
744 		rpage = BACKING_PAGE(dpage);
745 
746 		/*
747 		 * Normally, a device would use the page->zone_device_data to
748 		 * point to the mirror but here we use it to hold the page for
749 		 * the simulated device memory and that page holds the pointer
750 		 * to the mirror.
751 		 */
752 		rpage->zone_device_data = dmirror;
753 
754 		pr_debug("migrating from sys to dev pfn src: 0x%lx pfn dst: 0x%lx\n",
755 			 page_to_pfn(spage), page_to_pfn(dpage));
756 
757 		*dst = migrate_pfn(page_to_pfn(dpage)) | write;
758 
759 		if (is_large) {
760 			int i;
761 			struct folio *folio = page_folio(dpage);
762 			*dst |= MIGRATE_PFN_COMPOUND;
763 
764 			if (folio_test_large(folio)) {
765 				for (i = 0; i < folio_nr_pages(folio); i++) {
766 					struct page *dst_page =
767 						pfn_to_page(page_to_pfn(rpage) + i);
768 					struct page *src_page =
769 						pfn_to_page(page_to_pfn(spage) + i);
770 
771 					if (spage)
772 						copy_highpage(dst_page, src_page);
773 					else
774 						clear_highpage(dst_page);
775 					src++;
776 					dst++;
777 					addr += PAGE_SIZE;
778 				}
779 				continue;
780 			}
781 		}
782 
783 		if (spage)
784 			copy_highpage(rpage, spage);
785 		else
786 			clear_highpage(rpage);
787 
788 next:
789 		src++;
790 		dst++;
791 		addr += PAGE_SIZE;
792 	}
793 }
794 
795 static int dmirror_check_atomic(struct dmirror *dmirror, unsigned long start,
796 			     unsigned long end)
797 {
798 	unsigned long pfn;
799 
800 	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
801 		void *entry;
802 
803 		entry = xa_load(&dmirror->pt, pfn);
804 		if (xa_pointer_tag(entry) == DPT_XA_TAG_ATOMIC)
805 			return -EPERM;
806 	}
807 
808 	return 0;
809 }
810 
811 static int dmirror_atomic_map(unsigned long addr, struct page *page,
812 		struct dmirror *dmirror)
813 {
814 	void *entry;
815 
816 	/* Map the migrated pages into the device's page tables. */
817 	mutex_lock(&dmirror->mutex);
818 
819 	entry = xa_tag_pointer(page, DPT_XA_TAG_ATOMIC);
820 	entry = xa_store(&dmirror->pt, addr >> PAGE_SHIFT, entry, GFP_ATOMIC);
821 	if (xa_is_err(entry)) {
822 		mutex_unlock(&dmirror->mutex);
823 		return xa_err(entry);
824 	}
825 
826 	mutex_unlock(&dmirror->mutex);
827 	return 0;
828 }
829 
830 static int dmirror_migrate_finalize_and_map(struct migrate_vma *args,
831 					    struct dmirror *dmirror)
832 {
833 	unsigned long start = args->start;
834 	unsigned long end = args->end;
835 	const unsigned long *src = args->src;
836 	const unsigned long *dst = args->dst;
837 	unsigned long pfn;
838 	const unsigned long start_pfn = start >> PAGE_SHIFT;
839 	const unsigned long end_pfn = end >> PAGE_SHIFT;
840 
841 	/* Map the migrated pages into the device's page tables. */
842 	mutex_lock(&dmirror->mutex);
843 
844 	for (pfn = start_pfn; pfn < end_pfn; pfn++, src++, dst++) {
845 		struct page *dpage;
846 		void *entry;
847 		int nr, i;
848 		struct page *rpage;
849 
850 		if (!(*src & MIGRATE_PFN_MIGRATE))
851 			continue;
852 
853 		dpage = migrate_pfn_to_page(*dst);
854 		if (!dpage)
855 			continue;
856 
857 		if (*dst & MIGRATE_PFN_COMPOUND)
858 			nr = folio_nr_pages(page_folio(dpage));
859 		else
860 			nr = 1;
861 
862 		WARN_ON_ONCE(end_pfn < start_pfn + nr);
863 
864 		rpage = BACKING_PAGE(dpage);
865 		VM_WARN_ON(folio_nr_pages(page_folio(rpage)) != nr);
866 
867 		for (i = 0; i < nr; i++) {
868 			entry = folio_page(page_folio(rpage), i);
869 			if (*dst & MIGRATE_PFN_WRITE)
870 				entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
871 			entry = xa_store(&dmirror->pt, pfn + i, entry, GFP_ATOMIC);
872 			if (xa_is_err(entry)) {
873 				mutex_unlock(&dmirror->mutex);
874 				return xa_err(entry);
875 			}
876 		}
877 	}
878 
879 	mutex_unlock(&dmirror->mutex);
880 	return 0;
881 }
882 
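/*
 * Grant the device exclusive (atomic) access to the range via
 * make_device_exclusive(), record the pages in the mirror tagged with
 * DPT_XA_TAG_ATOMIC, and copy the data back to userspace for verification.
 */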
883 static int dmirror_exclusive(struct dmirror *dmirror,
884 			     struct hmm_dmirror_cmd *cmd)
885 {
886 	unsigned long start, end, addr;
887 	unsigned long size = cmd->npages << PAGE_SHIFT;
888 	struct mm_struct *mm = dmirror->notifier.mm;
889 	struct dmirror_bounce bounce;
890 	int ret = 0;
891 
892 	start = cmd->addr;
893 	end = start + size;
894 	if (end < start)
895 		return -EINVAL;
896 
897 	/* Since the mm is for the mirrored process, get a reference first. */
898 	if (!mmget_not_zero(mm))
899 		return -EINVAL;
900 
901 	mmap_read_lock(mm);
902 	for (addr = start; !ret && addr < end; addr += PAGE_SIZE) {
903 		struct folio *folio;
904 		struct page *page;
905 
906 		page = make_device_exclusive(mm, addr, NULL, &folio);
907 		if (IS_ERR(page)) {
908 			ret = PTR_ERR(page);
909 			break;
910 		}
911 
912 		ret = dmirror_atomic_map(addr, page, dmirror);
913 		folio_unlock(folio);
914 		folio_put(folio);
915 	}
916 	mmap_read_unlock(mm);
917 	mmput(mm);
918 
919 	if (ret)
920 		return ret;
921 
922 	/* Return the migrated data for verification. */
923 	ret = dmirror_bounce_init(&bounce, start, size);
924 	if (ret)
925 		return ret;
926 	mutex_lock(&dmirror->mutex);
927 	ret = dmirror_do_read(dmirror, start, end, &bounce);
928 	mutex_unlock(&dmirror->mutex);
929 	if (ret == 0) {
930 		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
931 				 bounce.size))
932 			ret = -EFAULT;
933 	}
934 
935 	cmd->cpages = bounce.cpages;
936 	dmirror_bounce_fini(&bounce);
937 	return ret;
938 }
939 
940 static vm_fault_t dmirror_devmem_fault_alloc_and_copy(struct migrate_vma *args,
941 						      struct dmirror *dmirror)
942 {
943 	const unsigned long *src = args->src;
944 	unsigned long *dst = args->dst;
945 	unsigned long start = args->start;
946 	unsigned long end = args->end;
947 	unsigned long addr;
948 	unsigned int order = 0;
949 	int i;
950 
951 	for (addr = start; addr < end; ) {
952 		struct page *dpage, *spage;
953 
954 		spage = migrate_pfn_to_page(*src);
955 		if (!spage || !(*src & MIGRATE_PFN_MIGRATE)) {
956 			addr += PAGE_SIZE;
957 			goto next;
958 		}
959 
960 		if (WARN_ON(!is_device_private_page(spage) &&
961 			    !is_device_coherent_page(spage))) {
962 			addr += PAGE_SIZE;
963 			goto next;
964 		}
965 
966 		spage = BACKING_PAGE(spage);
967 		order = folio_order(page_folio(spage));
968 		if (order)
969 			*dst = MIGRATE_PFN_COMPOUND;
970 		if (*src & MIGRATE_PFN_WRITE)
971 			*dst |= MIGRATE_PFN_WRITE;
972 
973 		if (dmirror->flags & HMM_DMIRROR_FLAG_FAIL_ALLOC) {
974 			dmirror->flags &= ~HMM_DMIRROR_FLAG_FAIL_ALLOC;
975 			*dst &= ~MIGRATE_PFN_COMPOUND;
976 			dpage = NULL;
977 		} else if (order) {
978 			dpage = folio_page(vma_alloc_folio(GFP_HIGHUSER_MOVABLE,
979 						order, args->vma, addr), 0);
980 		} else {
981 			dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
982 		}
983 
984 		if (!dpage && !order)
985 			return VM_FAULT_OOM;
986 
987 		pr_debug("migrating from sys to dev pfn src: 0x%lx pfn dst: 0x%lx\n",
988 				page_to_pfn(spage), page_to_pfn(dpage));
989 
990 		if (dpage) {
991 			lock_page(dpage);
992 			*dst |= migrate_pfn(page_to_pfn(dpage));
993 		}
994 
995 		for (i = 0; i < (1 << order); i++) {
996 			struct page *src_page;
997 			struct page *dst_page;
998 
999 			/* Try with smaller pages if large allocation fails */
1000 			if (!dpage && order) {
1001 				dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
1002 				lock_page(dpage);
1003 				dst[i] = migrate_pfn(page_to_pfn(dpage));
1004 				dst_page = pfn_to_page(page_to_pfn(dpage));
1005 				dpage = NULL; /* For the next iteration */
1006 			} else {
1007 				dst_page = pfn_to_page(page_to_pfn(dpage) + i);
1008 			}
1009 
1010 			src_page = pfn_to_page(page_to_pfn(spage) + i);
1011 
1012 			xa_erase(&dmirror->pt, addr >> PAGE_SHIFT);
1013 			addr += PAGE_SIZE;
1014 			copy_highpage(dst_page, src_page);
1015 		}
1016 next:
1017 		src += 1 << order;
1018 		dst += 1 << order;
1019 	}
1020 	return 0;
1021 }
1022 
1023 static unsigned long
1024 dmirror_successful_migrated_pages(struct migrate_vma *migrate)
1025 {
1026 	unsigned long cpages = 0;
1027 	unsigned long i;
1028 
1029 	for (i = 0; i < migrate->npages; i++) {
1030 		if (migrate->src[i] & MIGRATE_PFN_VALID &&
1031 		    migrate->src[i] & MIGRATE_PFN_MIGRATE)
1032 			cpages++;
1033 	}
1034 	return cpages;
1035 }
1036 
1037 static int dmirror_migrate_to_system(struct dmirror *dmirror,
1038 				     struct hmm_dmirror_cmd *cmd)
1039 {
1040 	unsigned long start, end, addr;
1041 	unsigned long size = cmd->npages << PAGE_SHIFT;
1042 	struct mm_struct *mm = dmirror->notifier.mm;
1043 	struct vm_area_struct *vma;
1044 	struct migrate_vma args = { 0 };
1045 	unsigned long next;
1046 	int ret;
1047 	unsigned long *src_pfns;
1048 	unsigned long *dst_pfns;
1049 
1050 	src_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*src_pfns), GFP_KERNEL | __GFP_NOFAIL);
1051 	dst_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*dst_pfns), GFP_KERNEL | __GFP_NOFAIL);
1052 
1053 	start = cmd->addr;
1054 	end = start + size;
1055 	if (end < start)
1056 		return -EINVAL;
1057 
1058 	/* Since the mm is for the mirrored process, get a reference first. */
1059 	if (!mmget_not_zero(mm))
1060 		return -EINVAL;
1061 
1062 	cmd->cpages = 0;
1063 	mmap_read_lock(mm);
1064 	for (addr = start; addr < end; addr = next) {
1065 		vma = vma_lookup(mm, addr);
1066 		if (!vma || !(vma->vm_flags & VM_READ)) {
1067 			ret = -EINVAL;
1068 			goto out;
1069 		}
1070 		next = min(end, addr + (PTRS_PER_PTE << PAGE_SHIFT));
1071 		if (next > vma->vm_end)
1072 			next = vma->vm_end;
1073 
1074 		args.vma = vma;
1075 		args.src = src_pfns;
1076 		args.dst = dst_pfns;
1077 		args.start = addr;
1078 		args.end = next;
1079 		args.pgmap_owner = dmirror->mdevice;
1080 		args.flags = dmirror_select_device(dmirror) | MIGRATE_VMA_SELECT_COMPOUND;
1081 
1082 		ret = migrate_vma_setup(&args);
1083 		if (ret)
1084 			goto out;
1085 
1086 		pr_debug("Migrating from device mem to sys mem\n");
1087 		dmirror_devmem_fault_alloc_and_copy(&args, dmirror);
1088 
1089 		migrate_vma_pages(&args);
1090 		cmd->cpages += dmirror_successful_migrated_pages(&args);
1091 		migrate_vma_finalize(&args);
1092 	}
1093 out:
1094 	mmap_read_unlock(mm);
1095 	mmput(mm);
1096 	kvfree(src_pfns);
1097 	kvfree(dst_pfns);
1098 
1099 	return ret;
1100 }
1101 
1102 static int dmirror_migrate_to_device(struct dmirror *dmirror,
1103 				struct hmm_dmirror_cmd *cmd)
1104 {
1105 	unsigned long start, end, addr;
1106 	unsigned long size = cmd->npages << PAGE_SHIFT;
1107 	struct mm_struct *mm = dmirror->notifier.mm;
1108 	struct vm_area_struct *vma;
1109 	struct dmirror_bounce bounce;
1110 	struct migrate_vma args = { 0 };
1111 	unsigned long next;
1112 	int ret;
1113 	unsigned long *src_pfns = NULL;
1114 	unsigned long *dst_pfns = NULL;
1115 
1116 	start = cmd->addr;
1117 	end = start + size;
1118 	if (end < start)
1119 		return -EINVAL;
1120 
1121 	/* Since the mm is for the mirrored process, get a reference first. */
1122 	if (!mmget_not_zero(mm))
1123 		return -EINVAL;
1124 
1125 	ret = -ENOMEM;
1126 	src_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*src_pfns),
1127 			  GFP_KERNEL | __GFP_NOFAIL);
1128 	if (!src_pfns)
1129 		goto free_mem;
1130 
1131 	dst_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*dst_pfns),
1132 			  GFP_KERNEL | __GFP_NOFAIL);
1133 	if (!dst_pfns)
1134 		goto free_mem;
1135 
1136 	ret = 0;
1137 	mmap_read_lock(mm);
1138 	for (addr = start; addr < end; addr = next) {
1139 		vma = vma_lookup(mm, addr);
1140 		if (!vma || !(vma->vm_flags & VM_READ)) {
1141 			ret = -EINVAL;
1142 			goto out;
1143 		}
1144 		next = min(end, addr + (PTRS_PER_PTE << PAGE_SHIFT));
1145 		if (next > vma->vm_end)
1146 			next = vma->vm_end;
1147 
1148 		args.vma = vma;
1149 		args.src = src_pfns;
1150 		args.dst = dst_pfns;
1151 		args.start = addr;
1152 		args.end = next;
1153 		args.pgmap_owner = dmirror->mdevice;
1154 		args.flags = MIGRATE_VMA_SELECT_SYSTEM |
1155 				MIGRATE_VMA_SELECT_COMPOUND;
1156 		ret = migrate_vma_setup(&args);
1157 		if (ret)
1158 			goto out;
1159 
1160 		pr_debug("Migrating from sys mem to device mem\n");
1161 		dmirror_migrate_alloc_and_copy(&args, dmirror);
1162 		migrate_vma_pages(&args);
1163 		dmirror_migrate_finalize_and_map(&args, dmirror);
1164 		migrate_vma_finalize(&args);
1165 	}
1166 	mmap_read_unlock(mm);
1167 	mmput(mm);
1168 
1169 	/*
1170 	 * Return the migrated data for verification.
1171 	 * Only for pages in device zone
1172 	 */
1173 	ret = dmirror_bounce_init(&bounce, start, size);
1174 	if (ret)
1175 		goto free_mem;
1176 	mutex_lock(&dmirror->mutex);
1177 	ret = dmirror_do_read(dmirror, start, end, &bounce);
1178 	mutex_unlock(&dmirror->mutex);
1179 	if (ret == 0) {
1180 		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
1181 				 bounce.size))
1182 			ret = -EFAULT;
1183 	}
1184 	cmd->cpages = bounce.cpages;
1185 	dmirror_bounce_fini(&bounce);
1186 	goto free_mem;
1187 
1188 out:
1189 	mmap_read_unlock(mm);
1190 	mmput(mm);
1191 free_mem:
1192 	kvfree(src_pfns);
1193 	kvfree(dst_pfns);
1194 	return ret;
1195 }
1196 
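/*
 * Convert one hmm_range_fault() pfn entry into the HMM_DMIRROR_PROT_*
 * permission byte reported by the snapshot ioctl.
 */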
1197 static void dmirror_mkentry(struct dmirror *dmirror, struct hmm_range *range,
1198 			    unsigned char *perm, unsigned long entry)
1199 {
1200 	struct page *page;
1201 
1202 	if (entry & HMM_PFN_ERROR) {
1203 		*perm = HMM_DMIRROR_PROT_ERROR;
1204 		return;
1205 	}
1206 	if (!(entry & HMM_PFN_VALID)) {
1207 		*perm = HMM_DMIRROR_PROT_NONE;
1208 		return;
1209 	}
1210 
1211 	page = hmm_pfn_to_page(entry);
1212 	if (is_device_private_page(page)) {
1213 		/* Is the page migrated to this device or some other? */
1214 		if (dmirror->mdevice == dmirror_page_to_device(page))
1215 			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL;
1216 		else
1217 			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_REMOTE;
1218 	} else if (is_device_coherent_page(page)) {
1219 		/* Is the page migrated to this device or some other? */
1220 		if (dmirror->mdevice == dmirror_page_to_device(page))
1221 			*perm = HMM_DMIRROR_PROT_DEV_COHERENT_LOCAL;
1222 		else
1223 			*perm = HMM_DMIRROR_PROT_DEV_COHERENT_REMOTE;
1224 	} else if (is_zero_pfn(page_to_pfn(page)))
1225 		*perm = HMM_DMIRROR_PROT_ZERO;
1226 	else
1227 		*perm = HMM_DMIRROR_PROT_NONE;
1228 	if (entry & HMM_PFN_WRITE)
1229 		*perm |= HMM_DMIRROR_PROT_WRITE;
1230 	else
1231 		*perm |= HMM_DMIRROR_PROT_READ;
1232 	if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PMD_SHIFT)
1233 		*perm |= HMM_DMIRROR_PROT_PMD;
1234 	else if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PUD_SHIFT)
1235 		*perm |= HMM_DMIRROR_PROT_PUD;
1236 }
1237 
1238 static bool dmirror_snapshot_invalidate(struct mmu_interval_notifier *mni,
1239 				const struct mmu_notifier_range *range,
1240 				unsigned long cur_seq)
1241 {
1242 	struct dmirror_interval *dmi =
1243 		container_of(mni, struct dmirror_interval, notifier);
1244 	struct dmirror *dmirror = dmi->dmirror;
1245 
1246 	if (mmu_notifier_range_blockable(range))
1247 		mutex_lock(&dmirror->mutex);
1248 	else if (!mutex_trylock(&dmirror->mutex))
1249 		return false;
1250 
1251 	/*
1252 	 * Snapshots only need to set the sequence number since any
1253 	 * invalidation in the interval invalidates the whole snapshot.
1254 	 */
1255 	mmu_interval_set_seq(mni, cur_seq);
1256 
1257 	mutex_unlock(&dmirror->mutex);
1258 	return true;
1259 }
1260 
1261 static const struct mmu_interval_notifier_ops dmirror_mrn_ops = {
1262 	.invalidate = dmirror_snapshot_invalidate,
1263 };
1264 
1265 static int dmirror_range_snapshot(struct dmirror *dmirror,
1266 				  struct hmm_range *range,
1267 				  unsigned char *perm)
1268 {
1269 	struct mm_struct *mm = dmirror->notifier.mm;
1270 	struct dmirror_interval notifier;
1271 	unsigned long timeout =
1272 		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
1273 	unsigned long i;
1274 	unsigned long n;
1275 	int ret = 0;
1276 
1277 	notifier.dmirror = dmirror;
1278 	range->notifier = &notifier.notifier;
1279 
1280 	ret = mmu_interval_notifier_insert(range->notifier, mm,
1281 			range->start, range->end - range->start,
1282 			&dmirror_mrn_ops);
1283 	if (ret)
1284 		return ret;
1285 
1286 	while (true) {
1287 		if (time_after(jiffies, timeout)) {
1288 			ret = -EBUSY;
1289 			goto out;
1290 		}
1291 
1292 		range->notifier_seq = mmu_interval_read_begin(range->notifier);
1293 
1294 		mmap_read_lock(mm);
1295 		ret = hmm_range_fault(range);
1296 		mmap_read_unlock(mm);
1297 		if (ret) {
1298 			if (ret == -EBUSY)
1299 				continue;
1300 			goto out;
1301 		}
1302 
1303 		mutex_lock(&dmirror->mutex);
1304 		if (mmu_interval_read_retry(range->notifier,
1305 					    range->notifier_seq)) {
1306 			mutex_unlock(&dmirror->mutex);
1307 			continue;
1308 		}
1309 		break;
1310 	}
1311 
1312 	n = (range->end - range->start) >> PAGE_SHIFT;
1313 	for (i = 0; i < n; i++)
1314 		dmirror_mkentry(dmirror, range, perm + i, range->hmm_pfns[i]);
1315 
1316 	mutex_unlock(&dmirror->mutex);
1317 out:
1318 	mmu_interval_notifier_remove(range->notifier);
1319 	return ret;
1320 }
1321 
1322 static int dmirror_snapshot(struct dmirror *dmirror,
1323 			    struct hmm_dmirror_cmd *cmd)
1324 {
1325 	struct mm_struct *mm = dmirror->notifier.mm;
1326 	unsigned long start, end;
1327 	unsigned long size = cmd->npages << PAGE_SHIFT;
1328 	unsigned long addr;
1329 	unsigned long next;
1330 	unsigned long pfns[32];
1331 	unsigned char perm[32];
1332 	char __user *uptr;
1333 	struct hmm_range range = {
1334 		.hmm_pfns = pfns,
1335 		.dev_private_owner = dmirror->mdevice,
1336 	};
1337 	int ret = 0;
1338 
1339 	start = cmd->addr;
1340 	end = start + size;
1341 	if (end < start)
1342 		return -EINVAL;
1343 
1344 	/* Since the mm is for the mirrored process, get a reference first. */
1345 	if (!mmget_not_zero(mm))
1346 		return -EINVAL;
1347 
1348 	/*
1349 	 * Register a temporary notifier to detect invalidations even if it
1350 	 * overlaps with other mmu_interval_notifiers.
1351 	 */
1352 	uptr = u64_to_user_ptr(cmd->ptr);
1353 	for (addr = start; addr < end; addr = next) {
1354 		unsigned long n;
1355 
1356 		next = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
1357 		range.start = addr;
1358 		range.end = next;
1359 
1360 		ret = dmirror_range_snapshot(dmirror, &range, perm);
1361 		if (ret)
1362 			break;
1363 
1364 		n = (range.end - range.start) >> PAGE_SHIFT;
1365 		if (copy_to_user(uptr, perm, n)) {
1366 			ret = -EFAULT;
1367 			break;
1368 		}
1369 
1370 		cmd->cpages += n;
1371 		uptr += n;
1372 	}
1373 	mmput(mm);
1374 
1375 	return ret;
1376 }
1377 
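/*
 * Migrate any data still resident in the chunk's device pages back to
 * newly allocated system memory so the chunk can be removed.
 */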
1378 static void dmirror_device_evict_chunk(struct dmirror_chunk *chunk)
1379 {
1380 	unsigned long start_pfn = chunk->pagemap.range.start >> PAGE_SHIFT;
1381 	unsigned long end_pfn = chunk->pagemap.range.end >> PAGE_SHIFT;
1382 	unsigned long npages = end_pfn - start_pfn + 1;
1383 	unsigned long i;
1384 	unsigned long *src_pfns;
1385 	unsigned long *dst_pfns;
1386 	unsigned int order = 0;
1387 
1388 	src_pfns = kvcalloc(npages, sizeof(*src_pfns), GFP_KERNEL | __GFP_NOFAIL);
1389 	dst_pfns = kvcalloc(npages, sizeof(*dst_pfns), GFP_KERNEL | __GFP_NOFAIL);
1390 
1391 	migrate_device_range(src_pfns, start_pfn, npages);
1392 	for (i = 0; i < npages; i++) {
1393 		struct page *dpage, *spage;
1394 
1395 		spage = migrate_pfn_to_page(src_pfns[i]);
1396 		if (!spage || !(src_pfns[i] & MIGRATE_PFN_MIGRATE))
1397 			continue;
1398 
1399 		if (WARN_ON(!is_device_private_page(spage) &&
1400 			    !is_device_coherent_page(spage)))
1401 			continue;
1402 
1403 		order = folio_order(page_folio(spage));
1404 		spage = BACKING_PAGE(spage);
1405 		if (src_pfns[i] & MIGRATE_PFN_COMPOUND) {
1406 			dpage = folio_page(folio_alloc(GFP_HIGHUSER_MOVABLE,
1407 					      order), 0);
1408 		} else {
1409 			dpage = alloc_page(GFP_HIGHUSER_MOVABLE | __GFP_NOFAIL);
1410 			order = 0;
1411 		}
1412 
1413 		/* TODO Support splitting here */
1414 		lock_page(dpage);
1415 		dst_pfns[i] = migrate_pfn(page_to_pfn(dpage));
1416 		if (src_pfns[i] & MIGRATE_PFN_WRITE)
1417 			dst_pfns[i] |= MIGRATE_PFN_WRITE;
1418 		if (order)
1419 			dst_pfns[i] |= MIGRATE_PFN_COMPOUND;
1420 		folio_copy(page_folio(dpage), page_folio(spage));
1421 	}
1422 	migrate_device_pages(src_pfns, dst_pfns, npages);
1423 	migrate_device_finalize(src_pfns, dst_pfns, npages);
1424 	kvfree(src_pfns);
1425 	kvfree(dst_pfns);
1426 }
1427 
1428 /* Removes free pages from the free list so they can't be re-allocated */
1429 static void dmirror_remove_free_pages(struct dmirror_chunk *devmem)
1430 {
1431 	struct dmirror_device *mdevice = devmem->mdevice;
1432 	struct page *page;
1433 	struct folio *folio;
1434 
1435 
1436 	for (folio = mdevice->free_folios; folio; folio = folio_zone_device_data(folio))
1437 		if (dmirror_page_to_chunk(folio_page(folio, 0)) == devmem)
1438 			mdevice->free_folios = folio_zone_device_data(folio);
1439 	for (page = mdevice->free_pages; page; page = page->zone_device_data)
1440 		if (dmirror_page_to_chunk(page) == devmem)
1441 			mdevice->free_pages = page->zone_device_data;
1442 }
1443 
1444 static void dmirror_device_remove_chunks(struct dmirror_device *mdevice)
1445 {
1446 	unsigned int i;
1447 
1448 	mutex_lock(&mdevice->devmem_lock);
1449 	if (mdevice->devmem_chunks) {
1450 		for (i = 0; i < mdevice->devmem_count; i++) {
1451 			struct dmirror_chunk *devmem =
1452 				mdevice->devmem_chunks[i];
1453 
1454 			spin_lock(&mdevice->lock);
1455 			devmem->remove = true;
1456 			dmirror_remove_free_pages(devmem);
1457 			spin_unlock(&mdevice->lock);
1458 
1459 			dmirror_device_evict_chunk(devmem);
1460 			memunmap_pages(&devmem->pagemap);
1461 			if (devmem->pagemap.type == MEMORY_DEVICE_PRIVATE)
1462 				release_mem_region(devmem->pagemap.range.start,
1463 						   range_len(&devmem->pagemap.range));
1464 			kfree(devmem);
1465 		}
1466 		mdevice->devmem_count = 0;
1467 		mdevice->devmem_capacity = 0;
1468 		mdevice->free_pages = NULL;
1469 		mdevice->free_folios = NULL;
1470 		kfree(mdevice->devmem_chunks);
1471 		mdevice->devmem_chunks = NULL;
1472 	}
1473 	mutex_unlock(&mdevice->devmem_lock);
1474 }
1475 
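/*
 * Dispatch the HMM_DMIRROR_* ioctl commands defined in test_hmm_uapi.h.
 */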
1476 static long dmirror_fops_unlocked_ioctl(struct file *filp,
1477 					unsigned int command,
1478 					unsigned long arg)
1479 {
1480 	void __user *uarg = (void __user *)arg;
1481 	struct hmm_dmirror_cmd cmd;
1482 	struct dmirror *dmirror;
1483 	int ret;
1484 
1485 	dmirror = filp->private_data;
1486 	if (!dmirror)
1487 		return -EINVAL;
1488 
1489 	if (copy_from_user(&cmd, uarg, sizeof(cmd)))
1490 		return -EFAULT;
1491 
1492 	if (cmd.addr & ~PAGE_MASK)
1493 		return -EINVAL;
1494 	if (cmd.addr >= (cmd.addr + (cmd.npages << PAGE_SHIFT)))
1495 		return -EINVAL;
1496 
1497 	cmd.cpages = 0;
1498 	cmd.faults = 0;
1499 
1500 	switch (command) {
1501 	case HMM_DMIRROR_READ:
1502 		ret = dmirror_read(dmirror, &cmd);
1503 		break;
1504 
1505 	case HMM_DMIRROR_WRITE:
1506 		ret = dmirror_write(dmirror, &cmd);
1507 		break;
1508 
1509 	case HMM_DMIRROR_MIGRATE_TO_DEV:
1510 		ret = dmirror_migrate_to_device(dmirror, &cmd);
1511 		break;
1512 
1513 	case HMM_DMIRROR_MIGRATE_TO_SYS:
1514 		ret = dmirror_migrate_to_system(dmirror, &cmd);
1515 		break;
1516 
1517 	case HMM_DMIRROR_EXCLUSIVE:
1518 		ret = dmirror_exclusive(dmirror, &cmd);
1519 		break;
1520 
1521 	case HMM_DMIRROR_CHECK_EXCLUSIVE:
1522 		ret = dmirror_check_atomic(dmirror, cmd.addr,
1523 					cmd.addr + (cmd.npages << PAGE_SHIFT));
1524 		break;
1525 
1526 	case HMM_DMIRROR_SNAPSHOT:
1527 		ret = dmirror_snapshot(dmirror, &cmd);
1528 		break;
1529 
1530 	case HMM_DMIRROR_RELEASE:
1531 		dmirror_device_remove_chunks(dmirror->mdevice);
1532 		ret = 0;
1533 		break;
1534 	case HMM_DMIRROR_FLAGS:
1535 		dmirror->flags = cmd.npages;
1536 		ret = 0;
1537 		break;
1538 
1539 	default:
1540 		return -EINVAL;
1541 	}
1542 	if (ret)
1543 		return ret;
1544 
1545 	if (copy_to_user(uarg, &cmd, sizeof(cmd)))
1546 		return -EFAULT;
1547 
1548 	return 0;
1549 }
1550 
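/*
 * Back an mmap() of the device file with freshly allocated, zeroed kernel
 * pages inserted via vm_insert_page().
 */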
1551 static int dmirror_fops_mmap(struct file *file, struct vm_area_struct *vma)
1552 {
1553 	unsigned long addr;
1554 
1555 	for (addr = vma->vm_start; addr < vma->vm_end; addr += PAGE_SIZE) {
1556 		struct page *page;
1557 		int ret;
1558 
1559 		page = alloc_page(GFP_KERNEL | __GFP_ZERO);
1560 		if (!page)
1561 			return -ENOMEM;
1562 
1563 		ret = vm_insert_page(vma, addr, page);
1564 		if (ret) {
1565 			__free_page(page);
1566 			return ret;
1567 		}
1568 		put_page(page);
1569 	}
1570 
1571 	return 0;
1572 }
1573 
1574 static const struct file_operations dmirror_fops = {
1575 	.open		= dmirror_fops_open,
1576 	.release	= dmirror_fops_release,
1577 	.mmap		= dmirror_fops_mmap,
1578 	.unlocked_ioctl = dmirror_fops_unlocked_ioctl,
1579 	.llseek		= default_llseek,
1580 	.owner		= THIS_MODULE,
1581 };
1582 
1583 static void dmirror_devmem_free(struct folio *folio)
1584 {
1585 	struct page *page = &folio->page;
1586 	struct page *rpage = BACKING_PAGE(page);
1587 	struct dmirror_device *mdevice;
1588 	struct folio *rfolio = page_folio(rpage);
1589 	unsigned int order = folio_order(rfolio);
1590 
1591 	if (rpage != page) {
1592 		if (order)
1593 			__free_pages(rpage, order);
1594 		else
1595 			__free_page(rpage);
1596 		rpage = NULL;
1597 	}
1598 
1599 	mdevice = dmirror_page_to_device(page);
1600 	spin_lock(&mdevice->lock);
1601 
1602 	/* Return page to our allocator if not freeing the chunk */
1603 	if (!dmirror_page_to_chunk(page)->remove) {
1604 		mdevice->cfree += 1 << order;
1605 		if (order) {
1606 			page->zone_device_data = mdevice->free_folios;
1607 			mdevice->free_folios = page_folio(page);
1608 		} else {
1609 			page->zone_device_data = mdevice->free_pages;
1610 			mdevice->free_pages = page;
1611 		}
1612 	}
1613 	spin_unlock(&mdevice->lock);
1614 }
1615 
1616 static vm_fault_t dmirror_devmem_fault(struct vm_fault *vmf)
1617 {
1618 	struct migrate_vma args = { 0 };
1619 	struct page *rpage;
1620 	struct dmirror *dmirror;
1621 	vm_fault_t ret = 0;
1622 	unsigned int order, nr;
1623 
1624 	/*
1625 	 * Normally, a device would use the page->zone_device_data to point to
1626 	 * the mirror but here we use it to hold the page for the simulated
1627 	 * device memory and that page holds the pointer to the mirror.
1628 	 */
1629 	rpage = folio_zone_device_data(page_folio(vmf->page));
1630 	dmirror = rpage->zone_device_data;
1631 
1632 	/* FIXME demonstrate how we can adjust migrate range */
1633 	order = folio_order(page_folio(vmf->page));
1634 	nr = 1 << order;
1635 
1636 	/*
1637 	 * When folios are partially mapped, we can't rely on the folio
1638 	 * order of vmf->page as the folio might not be fully split yet
1639 	 */
1640 	if (vmf->pte) {
1641 		order = 0;
1642 		nr = 1;
1643 	}
1644 
1645 	/*
1646 	 * Consider a per-cpu cache of src and dst pfns, but with
1647 	 * large number of cpus that might not scale well.
1648 	 */
1649 	args.start = ALIGN_DOWN(vmf->address, (PAGE_SIZE << order));
1650 	args.vma = vmf->vma;
1651 	args.end = args.start + (PAGE_SIZE << order);
1652 
1653 	nr = (args.end - args.start) >> PAGE_SHIFT;
1654 	args.src = kcalloc(nr, sizeof(unsigned long), GFP_KERNEL);
1655 	args.dst = kcalloc(nr, sizeof(unsigned long), GFP_KERNEL);
1656 	args.pgmap_owner = dmirror->mdevice;
1657 	args.flags = dmirror_select_device(dmirror);
1658 	args.fault_page = vmf->page;
1659 
1660 	if (!args.src || !args.dst) {
1661 		ret = VM_FAULT_OOM;
1662 		goto err;
1663 	}
1664 
1665 	if (order)
1666 		args.flags |= MIGRATE_VMA_SELECT_COMPOUND;
1667 
1668 	if (migrate_vma_setup(&args))
1669 		return VM_FAULT_SIGBUS;
1670 
1671 	ret = dmirror_devmem_fault_alloc_and_copy(&args, dmirror);
1672 	if (ret)
1673 		goto err;
1674 	migrate_vma_pages(&args);
1675 	/*
1676 	 * No device finalize step is needed since
1677 	 * dmirror_devmem_fault_alloc_and_copy() will have already
1678 	 * invalidated the device page table.
1679 	 */
1680 	migrate_vma_finalize(&args);
1681 err:
1682 	kfree(args.src);
1683 	kfree(args.dst);
1684 	return ret;
1685 }
1686 
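/*
 * Callback for splitting a device private folio: point each new tail
 * folio's zone_device_data at the matching page of the backing system
 * memory folio, and reset the backing folio's order and refcount when
 * called for the head (tail == NULL).
 */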
1687 static void dmirror_devmem_folio_split(struct folio *head, struct folio *tail)
1688 {
1689 	struct page *rpage = BACKING_PAGE(folio_page(head, 0));
1690 	struct page *rpage_tail;
1691 	struct folio *rfolio;
1692 	unsigned long offset = 0;
1693 
1694 	if (!rpage) {
1695 		tail->page.zone_device_data = NULL;
1696 		return;
1697 	}
1698 
1699 	rfolio = page_folio(rpage);
1700 
1701 	if (tail == NULL) {
1702 		folio_reset_order(rfolio);
1703 		rfolio->mapping = NULL;
1704 		folio_set_count(rfolio, 1);
1705 		return;
1706 	}
1707 
1708 	offset = folio_pfn(tail) - folio_pfn(head);
1709 
1710 	rpage_tail = folio_page(rfolio, offset);
1711 	tail->page.zone_device_data = rpage_tail;
1712 	rpage_tail->zone_device_data = rpage->zone_device_data;
1713 	clear_compound_head(rpage_tail);
1714 	rpage_tail->mapping = NULL;
1715 
1716 	folio_page(tail, 0)->mapping = folio_page(head, 0)->mapping;
1717 	tail->pgmap = head->pgmap;
1718 	folio_set_count(page_folio(rpage_tail), 1);
1719 }
1720 
1721 static const struct dev_pagemap_ops dmirror_devmem_ops = {
1722 	.folio_free	= dmirror_devmem_free,
1723 	.migrate_to_ram	= dmirror_devmem_fault,
1724 	.folio_split	= dmirror_devmem_folio_split,
1725 };
1726 
1727 static int dmirror_device_init(struct dmirror_device *mdevice, int id)
1728 {
1729 	dev_t dev;
1730 	int ret;
1731 
1732 	dev = MKDEV(MAJOR(dmirror_dev), id);
1733 	mutex_init(&mdevice->devmem_lock);
1734 	spin_lock_init(&mdevice->lock);
1735 
1736 	cdev_init(&mdevice->cdevice, &dmirror_fops);
1737 	mdevice->cdevice.owner = THIS_MODULE;
1738 	device_initialize(&mdevice->device);
1739 	mdevice->device.devt = dev;
1740 
1741 	ret = dev_set_name(&mdevice->device, "hmm_dmirror%u", id);
1742 	if (ret)
1743 		goto put_device;
1744 
1745 	ret = cdev_device_add(&mdevice->cdevice, &mdevice->device);
1746 	if (ret)
1747 		goto put_device;
1748 
1749 	/* Build a list of free ZONE_DEVICE struct pages */
1750 	return dmirror_allocate_chunk(mdevice, NULL, false);
1751 
1752 put_device:
1753 	put_device(&mdevice->device);
1754 	return ret;
1755 }
1756 
1757 static void dmirror_device_remove(struct dmirror_device *mdevice)
1758 {
1759 	dmirror_device_remove_chunks(mdevice);
1760 	cdev_device_del(&mdevice->cdevice, &mdevice->device);
1761 	put_device(&mdevice->device);
1762 }
1763 
1764 static int __init hmm_dmirror_init(void)
1765 {
1766 	int ret;
1767 	int id = 0;
1768 	int ndevices = 0;
1769 
1770 	ret = alloc_chrdev_region(&dmirror_dev, 0, DMIRROR_NDEVICES,
1771 				  "HMM_DMIRROR");
1772 	if (ret)
1773 		goto err_unreg;
1774 
1775 	memset(dmirror_devices, 0, DMIRROR_NDEVICES * sizeof(dmirror_devices[0]));
1776 	dmirror_devices[ndevices++].zone_device_type =
1777 				HMM_DMIRROR_MEMORY_DEVICE_PRIVATE;
1778 	dmirror_devices[ndevices++].zone_device_type =
1779 				HMM_DMIRROR_MEMORY_DEVICE_PRIVATE;
1780 	if (spm_addr_dev0 && spm_addr_dev1) {
1781 		dmirror_devices[ndevices++].zone_device_type =
1782 					HMM_DMIRROR_MEMORY_DEVICE_COHERENT;
1783 		dmirror_devices[ndevices++].zone_device_type =
1784 					HMM_DMIRROR_MEMORY_DEVICE_COHERENT;
1785 	}
1786 	for (id = 0; id < ndevices; id++) {
1787 		ret = dmirror_device_init(dmirror_devices + id, id);
1788 		if (ret)
1789 			goto err_chrdev;
1790 	}
1791 
1792 	pr_info("HMM test module loaded. This is only for testing HMM.\n");
1793 	return 0;
1794 
1795 err_chrdev:
1796 	while (--id >= 0)
1797 		dmirror_device_remove(dmirror_devices + id);
1798 	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
1799 err_unreg:
1800 	return ret;
1801 }
1802 
1803 static void __exit hmm_dmirror_exit(void)
1804 {
1805 	int id;
1806 
1807 	for (id = 0; id < DMIRROR_NDEVICES; id++)
1808 		if (dmirror_devices[id].zone_device_type)
1809 			dmirror_device_remove(dmirror_devices + id);
1810 	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
1811 }
1812 
1813 module_init(hmm_dmirror_init);
1814 module_exit(hmm_dmirror_exit);
1815 MODULE_DESCRIPTION("HMM (Heterogeneous Memory Management) test module");
1816 MODULE_LICENSE("GPL");
1817