// SPDX-License-Identifier: GPL-2.0
/*
 * This is a module to test the HMM (Heterogeneous Memory Management)
 * mirror and zone device private memory migration APIs of the kernel.
 * Userspace programs can register with the driver to mirror their own address
 * space and can use the device to read/write any valid virtual address.
 */
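/*
 * Rough sketch of how a userspace test drives this module (the real
 * consumer is tools/testing/selftests/mm/hmm-tests.c; the command
 * structure and ioctl names below come from test_hmm_uapi.h):
 *
 *	struct hmm_dmirror_cmd cmd = { 0 };
 *	int fd = open("/dev/hmm_dmirror0", O_RDWR);
 *
 *	cmd.addr = (uintptr_t)buf;		// page-aligned range to mirror
 *	cmd.ptr = (uintptr_t)mirror;		// buffer to receive the data
 *	cmd.npages = npages;
 *	ioctl(fd, HMM_DMIRROR_READ, &cmd);	// fault and read via the mirror
 */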
#include <linux/init.h>
#include <linux/fs.h>
#include <linux/mm.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/cdev.h>
#include <linux/device.h>
#include <linux/memremap.h>
#include <linux/mutex.h>
#include <linux/rwsem.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/highmem.h>
#include <linux/delay.h>
#include <linux/pagemap.h>
#include <linux/hmm.h>
#include <linux/vmalloc.h>
#include <linux/swap.h>
#include <linux/swapops.h>
#include <linux/sched/mm.h>
#include <linux/platform_device.h>
#include <linux/rmap.h>
#include <linux/mmu_notifier.h>
#include <linux/migrate.h>

#include "test_hmm_uapi.h"

#define DMIRROR_NDEVICES		4
#define DMIRROR_RANGE_FAULT_TIMEOUT	1000
#define DEVMEM_CHUNK_SIZE		(256 * 1024 * 1024U)
#define DEVMEM_CHUNKS_RESERVE		16

/*
 * For device_private pages, dpage is just a dummy struct page
 * representing a piece of device memory. dmirror_devmem_alloc_page
 * allocates a real system memory page as backing storage to fake a
 * real device. zone_device_data points to that backing page. But
 * for device_coherent memory, the struct page represents real
 * physical CPU-accessible memory that we can use directly.
 */
#define BACKING_PAGE(page) (is_device_private_page((page)) ? \
			   (page)->zone_device_data : (page))

static unsigned long spm_addr_dev0;
module_param(spm_addr_dev0, long, 0644);
MODULE_PARM_DESC(spm_addr_dev0,
		"Specify the start address for SPM (special purpose memory) used for device 0. Setting this selects the coherent device type; make sure spm_addr_dev1 is set too. The minimum SPM size should be DEVMEM_CHUNK_SIZE.");

static unsigned long spm_addr_dev1;
module_param(spm_addr_dev1, long, 0644);
MODULE_PARM_DESC(spm_addr_dev1,
		"Specify the start address for SPM (special purpose memory) used for device 1. Setting this selects the coherent device type; make sure spm_addr_dev0 is set too. The minimum SPM size should be DEVMEM_CHUNK_SIZE.");

static const struct dev_pagemap_ops dmirror_devmem_ops;
static const struct mmu_interval_notifier_ops dmirror_min_ops;
static dev_t dmirror_dev;

struct dmirror_device;

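/*
 * Bounce buffer for staging data between the device mirror and userspace:
 * a vmalloc()ed copy of @size bytes mirroring the process range starting
 * at @addr, with @cpages counting the pages actually copied so far.
 */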
struct dmirror_bounce {
	void			*ptr;
	unsigned long		size;
	unsigned long		addr;
	unsigned long		cpages;
};

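/*
 * These values are stored in the low bits of the XArray entries via
 * xa_tag_pointer(), so they must fit in the XArray pointer tag space
 * (tags 0-3). DPT_XA_TAG_WRITE marks a page the device may write;
 * DPT_XA_TAG_ATOMIC marks a page mapped for exclusive (atomic) access.
 */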
#define DPT_XA_TAG_ATOMIC 1UL
#define DPT_XA_TAG_WRITE 3UL

/*
 * Data structure to track address ranges and register for mmu interval
 * notifier updates.
 */
struct dmirror_interval {
	struct mmu_interval_notifier	notifier;
	struct dmirror			*dmirror;
};

/*
 * Data attached to the open device file.
 * Note that it might be shared after a fork().
 */
struct dmirror {
	struct dmirror_device		*mdevice;
	struct xarray			pt;
	struct mmu_interval_notifier	notifier;
	struct mutex			mutex;
	__u64			flags;
};

/*
 * ZONE_DEVICE pages for migration and simulating device memory.
 */
struct dmirror_chunk {
	struct dev_pagemap	pagemap;
	struct dmirror_device	*mdevice;
	bool remove;
};

/*
 * Per device data.
 */
struct dmirror_device {
	struct cdev		cdevice;
	unsigned int            zone_device_type;
	struct device		device;

	unsigned int		devmem_capacity;
	unsigned int		devmem_count;
	struct dmirror_chunk	**devmem_chunks;
	struct mutex		devmem_lock;	/* protects the above */

	unsigned long		calloc;
	unsigned long		cfree;
	struct page		*free_pages;
	struct folio		*free_folios;
	spinlock_t		lock;		/* protects the above */
};

static struct dmirror_device dmirror_devices[DMIRROR_NDEVICES];

static int dmirror_bounce_init(struct dmirror_bounce *bounce,
			       unsigned long addr,
			       unsigned long size)
{
	bounce->addr = addr;
	bounce->size = size;
	bounce->cpages = 0;
	bounce->ptr = vmalloc(size);
	if (!bounce->ptr)
		return -ENOMEM;
	return 0;
}

static bool dmirror_is_private_zone(struct dmirror_device *mdevice)
{
	return (mdevice->zone_device_type ==
		HMM_DMIRROR_MEMORY_DEVICE_PRIVATE);
}

static enum migrate_vma_direction
dmirror_select_device(struct dmirror *dmirror)
{
	return (dmirror->mdevice->zone_device_type ==
		HMM_DMIRROR_MEMORY_DEVICE_PRIVATE) ?
		MIGRATE_VMA_SELECT_DEVICE_PRIVATE :
		MIGRATE_VMA_SELECT_DEVICE_COHERENT;
}

static void dmirror_bounce_fini(struct dmirror_bounce *bounce)
{
	vfree(bounce->ptr);
}

static int dmirror_fops_open(struct inode *inode, struct file *filp)
{
	struct cdev *cdev = inode->i_cdev;
	struct dmirror *dmirror;
	int ret;
	/* Mirror this process's address space */
	dmirror = kzalloc_obj(*dmirror);
	if (dmirror == NULL)
		return -ENOMEM;

	dmirror->mdevice = container_of(cdev, struct dmirror_device, cdevice);
	mutex_init(&dmirror->mutex);
	xa_init(&dmirror->pt);

	ret = mmu_interval_notifier_insert(&dmirror->notifier, current->mm,
				0, ULONG_MAX & PAGE_MASK, &dmirror_min_ops);
	if (ret) {
		kfree(dmirror);
		return ret;
	}

	filp->private_data = dmirror;
	return 0;
}

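/*
 * Migrate every page of a chunk back to freshly allocated system memory so
 * that the chunk's device pages are no longer referenced and the chunk can
 * be torn down.
 */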
static void dmirror_device_evict_chunk(struct dmirror_chunk *chunk)
{
	unsigned long start_pfn = chunk->pagemap.range.start >> PAGE_SHIFT;
	unsigned long end_pfn = chunk->pagemap.range.end >> PAGE_SHIFT;
	unsigned long npages = end_pfn - start_pfn + 1;
	unsigned long i;
	unsigned long *src_pfns;
	unsigned long *dst_pfns;
	unsigned int order = 0;

	src_pfns = kvcalloc(npages, sizeof(*src_pfns), GFP_KERNEL | __GFP_NOFAIL);
	dst_pfns = kvcalloc(npages, sizeof(*dst_pfns), GFP_KERNEL | __GFP_NOFAIL);

	migrate_device_range(src_pfns, start_pfn, npages);
	for (i = 0; i < npages; i++) {
		struct page *dpage, *spage;

		spage = migrate_pfn_to_page(src_pfns[i]);
		if (!spage || !(src_pfns[i] & MIGRATE_PFN_MIGRATE))
			continue;

		if (WARN_ON(!is_device_private_page(spage) &&
			    !is_device_coherent_page(spage)))
			continue;

		order = folio_order(page_folio(spage));
		spage = BACKING_PAGE(spage);
		if (src_pfns[i] & MIGRATE_PFN_COMPOUND) {
			dpage = folio_page(folio_alloc(GFP_HIGHUSER_MOVABLE,
					      order), 0);
		} else {
			dpage = alloc_page(GFP_HIGHUSER_MOVABLE | __GFP_NOFAIL);
			order = 0;
		}

		/* TODO Support splitting here */
		lock_page(dpage);
		dst_pfns[i] = migrate_pfn(page_to_pfn(dpage));
		if (src_pfns[i] & MIGRATE_PFN_WRITE)
			dst_pfns[i] |= MIGRATE_PFN_WRITE;
		if (order)
			dst_pfns[i] |= MIGRATE_PFN_COMPOUND;
		folio_copy(page_folio(dpage), page_folio(spage));
	}
	migrate_device_pages(src_pfns, dst_pfns, npages);
	migrate_device_finalize(src_pfns, dst_pfns, npages);
	kvfree(src_pfns);
	kvfree(dst_pfns);
}

static int dmirror_fops_release(struct inode *inode, struct file *filp)
{
	struct dmirror *dmirror = filp->private_data;
	struct dmirror_device *mdevice = dmirror->mdevice;
	int i;

	mmu_interval_notifier_remove(&dmirror->notifier);

	if (mdevice->devmem_chunks) {
		for (i = 0; i < mdevice->devmem_count; i++) {
			struct dmirror_chunk *devmem =
				mdevice->devmem_chunks[i];

			dmirror_device_evict_chunk(devmem);
		}
	}

	xa_destroy(&dmirror->pt);
	kfree(dmirror);
	return 0;
}

static struct dmirror_chunk *dmirror_page_to_chunk(struct page *page)
{
	return container_of(page_pgmap(page), struct dmirror_chunk,
			    pagemap);
}

static struct dmirror_device *dmirror_page_to_device(struct page *page)
{
	return dmirror_page_to_chunk(page)->mdevice;
}

static int dmirror_do_fault(struct dmirror *dmirror, struct hmm_range *range)
{
	unsigned long *pfns = range->hmm_pfns;
	unsigned long pfn;

	for (pfn = (range->start >> PAGE_SHIFT);
	     pfn < (range->end >> PAGE_SHIFT);
	     pfn++, pfns++) {
		struct page *page;
		void *entry;

		/*
		 * Since we asked for hmm_range_fault() to populate pages,
		 * it shouldn't return an error entry on success.
		 */
		WARN_ON(*pfns & HMM_PFN_ERROR);
		WARN_ON(!(*pfns & HMM_PFN_VALID));

		page = hmm_pfn_to_page(*pfns);
		WARN_ON(!page);

		entry = page;
		if (*pfns & HMM_PFN_WRITE)
			entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
		else if (WARN_ON(range->default_flags & HMM_PFN_WRITE))
			return -EFAULT;
		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
		if (xa_is_err(entry))
			return xa_err(entry);
	}

	return 0;
}

static void dmirror_do_update(struct dmirror *dmirror, unsigned long start,
			      unsigned long end)
{
	unsigned long pfn;
	void *entry;

	/*
	 * The XArray doesn't hold references to pages since it relies on
	 * the mmu notifier to clear page pointers when they become stale.
	 * Therefore, it is OK to just clear the entry.
	 */
	xa_for_each_range(&dmirror->pt, pfn, entry, start >> PAGE_SHIFT,
			  end >> PAGE_SHIFT)
		xa_erase(&dmirror->pt, pfn);
}

static bool dmirror_interval_invalidate(struct mmu_interval_notifier *mni,
				const struct mmu_notifier_range *range,
				unsigned long cur_seq)
{
	struct dmirror *dmirror = container_of(mni, struct dmirror, notifier);

	/*
	 * Ignore invalidation callbacks for device private pages since
	 * the invalidation is handled as part of the migration process.
	 */
	if (range->event == MMU_NOTIFY_MIGRATE &&
	    range->owner == dmirror->mdevice)
		return true;

	if (mmu_notifier_range_blockable(range))
		mutex_lock(&dmirror->mutex);
	else if (!mutex_trylock(&dmirror->mutex))
		return false;

	mmu_interval_set_seq(mni, cur_seq);
	dmirror_do_update(dmirror, range->start, range->end);

	mutex_unlock(&dmirror->mutex);
	return true;
}

static const struct mmu_interval_notifier_ops dmirror_min_ops = {
	.invalidate = dmirror_interval_invalidate,
};

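/*
 * Fault in a range and mirror it, using the usual mmu_interval_notifier
 * retry pattern: sample the sequence count with mmu_interval_read_begin(),
 * call hmm_range_fault(), then take the device lock and check
 * mmu_interval_read_retry(); if an invalidation raced with the fault, loop
 * and try again.
 */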
static int dmirror_range_fault(struct dmirror *dmirror,
				struct hmm_range *range)
{
	struct mm_struct *mm = dmirror->notifier.mm;
	unsigned long timeout =
		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
	int ret;

	while (true) {
		if (time_after(jiffies, timeout)) {
			ret = -EBUSY;
			goto out;
		}

		range->notifier_seq = mmu_interval_read_begin(range->notifier);
		mmap_read_lock(mm);
		ret = hmm_range_fault(range);
		mmap_read_unlock(mm);
		if (ret) {
			if (ret == -EBUSY)
				continue;
			goto out;
		}

		mutex_lock(&dmirror->mutex);
		if (mmu_interval_read_retry(range->notifier,
					    range->notifier_seq)) {
			mutex_unlock(&dmirror->mutex);
			continue;
		}
		break;
	}

	ret = dmirror_do_fault(dmirror, range);

	mutex_unlock(&dmirror->mutex);
out:
	return ret;
}

static int dmirror_fault(struct dmirror *dmirror, unsigned long start,
			 unsigned long end, bool write)
{
	struct mm_struct *mm = dmirror->notifier.mm;
	unsigned long addr;
	unsigned long pfns[32];
	struct hmm_range range = {
		.notifier = &dmirror->notifier,
		.hmm_pfns = pfns,
		.pfn_flags_mask = 0,
		.default_flags =
			HMM_PFN_REQ_FAULT | (write ? HMM_PFN_REQ_WRITE : 0),
		.dev_private_owner = dmirror->mdevice,
	};
	int ret = 0;

	/* Since the mm is for the mirrored process, get a reference first. */
	if (!mmget_not_zero(mm))
		return 0;

	for (addr = start; addr < end; addr = range.end) {
		range.start = addr;
		range.end = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);

		ret = dmirror_range_fault(dmirror, &range);
		if (ret)
			break;
	}

	mmput(mm);
	return ret;
}

static int dmirror_do_read(struct dmirror *dmirror, unsigned long start,
			   unsigned long end, struct dmirror_bounce *bounce)
{
	unsigned long pfn;
	void *ptr;

	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);

	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
		void *entry;
		struct page *page;

		entry = xa_load(&dmirror->pt, pfn);
		page = xa_untag_pointer(entry);
		if (!page)
			return -ENOENT;

		memcpy_from_page(ptr, page, 0, PAGE_SIZE);

		ptr += PAGE_SIZE;
		bounce->cpages++;
	}

	return 0;
}

static int dmirror_read(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
{
	struct dmirror_bounce bounce;
	unsigned long start, end;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	int ret;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	ret = dmirror_bounce_init(&bounce, start, size);
	if (ret)
		return ret;

	while (1) {
		mutex_lock(&dmirror->mutex);
		ret = dmirror_do_read(dmirror, start, end, &bounce);
		mutex_unlock(&dmirror->mutex);
		if (ret != -ENOENT)
			break;

		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
		ret = dmirror_fault(dmirror, start, end, false);
		if (ret)
			break;
		cmd->faults++;
	}

	if (ret == 0) {
		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
				 bounce.size))
			ret = -EFAULT;
	}
	cmd->cpages = bounce.cpages;
	dmirror_bounce_fini(&bounce);
	return ret;
}

static int dmirror_do_write(struct dmirror *dmirror, unsigned long start,
			    unsigned long end, struct dmirror_bounce *bounce)
{
	unsigned long pfn;
	void *ptr;

	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);

	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
		void *entry;
		struct page *page;

		entry = xa_load(&dmirror->pt, pfn);
		page = xa_untag_pointer(entry);
		if (!page || xa_pointer_tag(entry) != DPT_XA_TAG_WRITE)
			return -ENOENT;

		memcpy_to_page(page, 0, ptr, PAGE_SIZE);

		ptr += PAGE_SIZE;
		bounce->cpages++;
	}

	return 0;
}

static int dmirror_write(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
{
	struct dmirror_bounce bounce;
	unsigned long start, end;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	int ret;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	ret = dmirror_bounce_init(&bounce, start, size);
	if (ret)
		return ret;
	if (copy_from_user(bounce.ptr, u64_to_user_ptr(cmd->ptr),
			   bounce.size)) {
		ret = -EFAULT;
		goto fini;
	}

	while (1) {
		mutex_lock(&dmirror->mutex);
		ret = dmirror_do_write(dmirror, start, end, &bounce);
		mutex_unlock(&dmirror->mutex);
		if (ret != -ENOENT)
			break;

		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
		ret = dmirror_fault(dmirror, start, end, true);
		if (ret)
			break;
		cmd->faults++;
	}

fini:
	cmd->cpages = bounce.cpages;
	dmirror_bounce_fini(&bounce);
	return ret;
}

static int dmirror_allocate_chunk(struct dmirror_device *mdevice,
				  struct page **ppage, bool is_large)
{
	struct dmirror_chunk *devmem;
	struct resource *res = NULL;
	unsigned long pfn;
	unsigned long pfn_first;
	unsigned long pfn_last;
	void *ptr;
	int ret = -ENOMEM;

	devmem = kzalloc_obj(*devmem);
	if (!devmem)
		return ret;

	switch (mdevice->zone_device_type) {
	case HMM_DMIRROR_MEMORY_DEVICE_PRIVATE:
		res = request_free_mem_region(&iomem_resource, DEVMEM_CHUNK_SIZE,
					      "hmm_dmirror");
		if (IS_ERR_OR_NULL(res))
			goto err_devmem;
		devmem->pagemap.range.start = res->start;
		devmem->pagemap.range.end = res->end;
		devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
		break;
	case HMM_DMIRROR_MEMORY_DEVICE_COHERENT:
		devmem->pagemap.range.start = (MINOR(mdevice->cdevice.dev) - 2) ?
							spm_addr_dev0 :
							spm_addr_dev1;
		devmem->pagemap.range.end = devmem->pagemap.range.start +
					    DEVMEM_CHUNK_SIZE - 1;
		devmem->pagemap.type = MEMORY_DEVICE_COHERENT;
		break;
	default:
		ret = -EINVAL;
		goto err_devmem;
	}

	devmem->pagemap.nr_range = 1;
	devmem->pagemap.ops = &dmirror_devmem_ops;
	devmem->pagemap.owner = mdevice;

	mutex_lock(&mdevice->devmem_lock);

	if (mdevice->devmem_count == mdevice->devmem_capacity) {
		struct dmirror_chunk **new_chunks;
		unsigned int new_capacity;

		new_capacity = mdevice->devmem_capacity +
				DEVMEM_CHUNKS_RESERVE;
		new_chunks = krealloc(mdevice->devmem_chunks,
				sizeof(new_chunks[0]) * new_capacity,
				GFP_KERNEL);
		if (!new_chunks)
			goto err_release;
		mdevice->devmem_capacity = new_capacity;
		mdevice->devmem_chunks = new_chunks;
	}
	ptr = memremap_pages(&devmem->pagemap, numa_node_id());
	if (IS_ERR_OR_NULL(ptr)) {
		if (ptr)
			ret = PTR_ERR(ptr);
		else
			ret = -EFAULT;
		goto err_release;
	}

	devmem->mdevice = mdevice;
	pfn_first = devmem->pagemap.range.start >> PAGE_SHIFT;
	pfn_last = pfn_first + (range_len(&devmem->pagemap.range) >> PAGE_SHIFT);
	mdevice->devmem_chunks[mdevice->devmem_count++] = devmem;

	mutex_unlock(&mdevice->devmem_lock);

	pr_info("added new %u MB chunk (total %u chunks, %u MB) PFNs [0x%lx 0x%lx)\n",
		DEVMEM_CHUNK_SIZE / (1024 * 1024),
		mdevice->devmem_count,
		mdevice->devmem_count * (DEVMEM_CHUNK_SIZE / (1024 * 1024)),
		pfn_first, pfn_last);

	spin_lock(&mdevice->lock);
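	/*
	 * Thread the new chunk's pages onto the device's free lists, using
	 * zone_device_data as the "next" link: PMD-aligned runs go onto
	 * free_folios for large allocations, the rest onto free_pages.
	 */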
	for (pfn = pfn_first; pfn < pfn_last; ) {
		struct page *page = pfn_to_page(pfn);

		if (is_large && IS_ALIGNED(pfn, HPAGE_PMD_NR)
			&& (pfn + HPAGE_PMD_NR <= pfn_last)) {
			page->zone_device_data = mdevice->free_folios;
			mdevice->free_folios = page_folio(page);
			pfn += HPAGE_PMD_NR;
			continue;
		}

		page->zone_device_data = mdevice->free_pages;
		mdevice->free_pages = page;
		pfn++;
	}

	ret = 0;
	if (ppage) {
		if (is_large) {
			if (!mdevice->free_folios) {
				ret = -ENOMEM;
				goto err_unlock;
			}
			*ppage = folio_page(mdevice->free_folios, 0);
			mdevice->free_folios = (*ppage)->zone_device_data;
			mdevice->calloc += HPAGE_PMD_NR;
		} else if (mdevice->free_pages) {
			*ppage = mdevice->free_pages;
			mdevice->free_pages = (*ppage)->zone_device_data;
			mdevice->calloc++;
		} else {
			ret = -ENOMEM;
			goto err_unlock;
		}
	}
err_unlock:
	spin_unlock(&mdevice->lock);

	return ret;

err_release:
	mutex_unlock(&mdevice->devmem_lock);
	if (res && devmem->pagemap.type == MEMORY_DEVICE_PRIVATE)
		release_mem_region(devmem->pagemap.range.start,
				   range_len(&devmem->pagemap.range));
err_devmem:
	kfree(devmem);

	return ret;
}

static struct page *dmirror_devmem_alloc_page(struct dmirror *dmirror,
					      bool is_large)
{
	struct page *dpage = NULL;
	struct page *rpage = NULL;
	unsigned int order = is_large ? HPAGE_PMD_ORDER : 0;
	struct dmirror_device *mdevice = dmirror->mdevice;

	/*
	 * For ZONE_DEVICE private type, this is a fake device so we allocate
	 * real system memory to store our device memory.
	 * For ZONE_DEVICE coherent type we use the actual dpage to store the
	 * data and ignore rpage.
	 */
	if (dmirror_is_private_zone(mdevice)) {
		rpage = folio_page(folio_alloc(GFP_HIGHUSER, order), 0);
		if (!rpage)
			return NULL;
	}
	spin_lock(&mdevice->lock);

	if (is_large && mdevice->free_folios) {
		dpage = folio_page(mdevice->free_folios, 0);
		mdevice->free_folios = dpage->zone_device_data;
		mdevice->calloc += 1 << order;
		spin_unlock(&mdevice->lock);
	} else if (!is_large && mdevice->free_pages) {
		dpage = mdevice->free_pages;
		mdevice->free_pages = dpage->zone_device_data;
		mdevice->calloc++;
		spin_unlock(&mdevice->lock);
	} else {
		spin_unlock(&mdevice->lock);
		if (dmirror_allocate_chunk(mdevice, &dpage, is_large))
			goto error;
	}

	zone_device_folio_init(page_folio(dpage),
			       page_pgmap(folio_page(page_folio(dpage), 0)),
			       order);
	dpage->zone_device_data = rpage;
	return dpage;

error:
	if (rpage)
		__free_pages(rpage, order);
	return NULL;
}

static void dmirror_migrate_alloc_and_copy(struct migrate_vma *args,
					   struct dmirror *dmirror)
{
	const unsigned long *src = args->src;
	unsigned long *dst = args->dst;
	unsigned long addr;

	for (addr = args->start; addr < args->end; ) {
		struct page *spage;
		struct page *dpage;
		struct page *rpage;
		bool is_large = *src & MIGRATE_PFN_COMPOUND;
		int write = (*src & MIGRATE_PFN_WRITE) ? MIGRATE_PFN_WRITE : 0;
		unsigned long nr = 1;

		if (!(*src & MIGRATE_PFN_MIGRATE))
			goto next;

		/*
		 * Note that spage might be NULL which is OK since it is an
		 * unallocated pte_none() or read-only zero page.
		 */
		spage = migrate_pfn_to_page(*src);
		if (WARN(spage && is_zone_device_page(spage),
		     "page already in device spage pfn: 0x%lx\n",
		     page_to_pfn(spage)))
			goto next;

		if (dmirror->flags & HMM_DMIRROR_FLAG_FAIL_ALLOC) {
			dmirror->flags &= ~HMM_DMIRROR_FLAG_FAIL_ALLOC;
			dpage = NULL;
		} else
			dpage = dmirror_devmem_alloc_page(dmirror, is_large);

		if (!dpage) {
			struct folio *folio;
			unsigned long i;
			unsigned long spfn = *src >> MIGRATE_PFN_SHIFT;
			struct page *src_page;

			if (!is_large)
				goto next;

			if (!spage && is_large) {
				nr = HPAGE_PMD_NR;
			} else {
				folio = page_folio(spage);
				nr = folio_nr_pages(folio);
			}

			for (i = 0; i < nr && addr < args->end; i++) {
				dpage = dmirror_devmem_alloc_page(dmirror, false);
				if (!dpage)
					break;
				rpage = BACKING_PAGE(dpage);
				rpage->zone_device_data = dmirror;

				*dst = migrate_pfn(page_to_pfn(dpage)) | write;
				src_page = pfn_to_page(spfn + i);

				if (spage)
					copy_highpage(rpage, src_page);
				else
					clear_highpage(rpage);
				src++;
				dst++;
				addr += PAGE_SIZE;
			}
			continue;
		}

		rpage = BACKING_PAGE(dpage);

		/*
		 * Normally, a device would use the page->zone_device_data to
		 * point to the mirror but here we use it to hold the page for
		 * the simulated device memory and that page holds the pointer
		 * to the mirror.
		 */
		rpage->zone_device_data = dmirror;

		pr_debug("migrating from sys to dev pfn src: 0x%lx pfn dst: 0x%lx\n",
			 page_to_pfn(spage), page_to_pfn(dpage));

		*dst = migrate_pfn(page_to_pfn(dpage)) | write;

		if (is_large) {
			int i;
			struct folio *folio = page_folio(dpage);
			*dst |= MIGRATE_PFN_COMPOUND;

			if (folio_test_large(folio)) {
				for (i = 0; i < folio_nr_pages(folio); i++) {
					struct page *dst_page =
						pfn_to_page(page_to_pfn(rpage) + i);
					struct page *src_page =
						pfn_to_page(page_to_pfn(spage) + i);

					if (spage)
						copy_highpage(dst_page, src_page);
					else
						clear_highpage(dst_page);
					src++;
					dst++;
					addr += PAGE_SIZE;
				}
				continue;
			}
		}

		if (spage)
			copy_highpage(rpage, spage);
		else
			clear_highpage(rpage);

next:
		src++;
		dst++;
		addr += PAGE_SIZE;
	}
}

static int dmirror_check_atomic(struct dmirror *dmirror, unsigned long start,
			     unsigned long end)
{
	unsigned long pfn;

	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
		void *entry;

		entry = xa_load(&dmirror->pt, pfn);
		if (xa_pointer_tag(entry) == DPT_XA_TAG_ATOMIC)
			return -EPERM;
	}

	return 0;
}

static int dmirror_atomic_map(unsigned long addr, struct page *page,
		struct dmirror *dmirror)
{
	void *entry;

	/* Map the migrated pages into the device's page tables. */
	mutex_lock(&dmirror->mutex);

	entry = xa_tag_pointer(page, DPT_XA_TAG_ATOMIC);
	entry = xa_store(&dmirror->pt, addr >> PAGE_SHIFT, entry, GFP_ATOMIC);
	if (xa_is_err(entry)) {
		mutex_unlock(&dmirror->mutex);
		return xa_err(entry);
	}

	mutex_unlock(&dmirror->mutex);
	return 0;
}

static int dmirror_migrate_finalize_and_map(struct migrate_vma *args,
					    struct dmirror *dmirror)
{
	unsigned long start = args->start;
	unsigned long end = args->end;
	const unsigned long *src = args->src;
	const unsigned long *dst = args->dst;
	unsigned long pfn;
	const unsigned long start_pfn = start >> PAGE_SHIFT;
	const unsigned long end_pfn = end >> PAGE_SHIFT;

	/* Map the migrated pages into the device's page tables. */
	mutex_lock(&dmirror->mutex);

	for (pfn = start_pfn; pfn < end_pfn; pfn++, src++, dst++) {
		struct page *dpage;
		void *entry;
		int nr, i;
		struct page *rpage;

		if (!(*src & MIGRATE_PFN_MIGRATE))
			continue;

		dpage = migrate_pfn_to_page(*dst);
		if (!dpage)
			continue;

		if (*dst & MIGRATE_PFN_COMPOUND)
			nr = folio_nr_pages(page_folio(dpage));
		else
			nr = 1;

		WARN_ON_ONCE(end_pfn < start_pfn + nr);

		rpage = BACKING_PAGE(dpage);
		VM_WARN_ON(folio_nr_pages(page_folio(rpage)) != nr);

		for (i = 0; i < nr; i++) {
			entry = folio_page(page_folio(rpage), i);
			if (*dst & MIGRATE_PFN_WRITE)
				entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
			entry = xa_store(&dmirror->pt, pfn + i, entry, GFP_ATOMIC);
			if (xa_is_err(entry)) {
				mutex_unlock(&dmirror->mutex);
				return xa_err(entry);
			}
		}
	}

	mutex_unlock(&dmirror->mutex);
	return 0;
}

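/*
 * Mark pages for exclusive device access: make_device_exclusive() replaces
 * the CPU mapping with a device-exclusive entry so that any CPU access
 * faults, which is how a real device would emulate atomic access to a page.
 */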
static int dmirror_exclusive(struct dmirror *dmirror,
			     struct hmm_dmirror_cmd *cmd)
{
	unsigned long start, end, addr;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	struct mm_struct *mm = dmirror->notifier.mm;
	struct dmirror_bounce bounce;
	int ret = 0;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	/* Since the mm is for the mirrored process, get a reference first. */
	if (!mmget_not_zero(mm))
		return -EINVAL;

	mmap_read_lock(mm);
	for (addr = start; !ret && addr < end; addr += PAGE_SIZE) {
		struct folio *folio;
		struct page *page;

		page = make_device_exclusive(mm, addr, NULL, &folio);
		if (IS_ERR(page)) {
			ret = PTR_ERR(page);
			break;
		}

		ret = dmirror_atomic_map(addr, page, dmirror);
		folio_unlock(folio);
		folio_put(folio);
	}
	mmap_read_unlock(mm);
	mmput(mm);

	if (ret)
		return ret;

	/* Return the migrated data for verification. */
	ret = dmirror_bounce_init(&bounce, start, size);
	if (ret)
		return ret;
	mutex_lock(&dmirror->mutex);
	ret = dmirror_do_read(dmirror, start, end, &bounce);
	mutex_unlock(&dmirror->mutex);
	if (ret == 0) {
		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
				 bounce.size))
			ret = -EFAULT;
	}

	cmd->cpages = bounce.cpages;
	dmirror_bounce_fini(&bounce);
	return ret;
}

static vm_fault_t dmirror_devmem_fault_alloc_and_copy(struct migrate_vma *args,
						      struct dmirror *dmirror)
{
	const unsigned long *src = args->src;
	unsigned long *dst = args->dst;
	unsigned long start = args->start;
	unsigned long end = args->end;
	unsigned long addr;
	unsigned int order = 0;
	int i;

	for (addr = start; addr < end; ) {
		struct page *dpage, *spage;

		spage = migrate_pfn_to_page(*src);
		if (!spage || !(*src & MIGRATE_PFN_MIGRATE)) {
			addr += PAGE_SIZE;
			goto next;
		}

		if (WARN_ON(!is_device_private_page(spage) &&
			    !is_device_coherent_page(spage))) {
			addr += PAGE_SIZE;
			goto next;
		}

		spage = BACKING_PAGE(spage);
		order = folio_order(page_folio(spage));
		if (order)
			*dst = MIGRATE_PFN_COMPOUND;
		if (*src & MIGRATE_PFN_WRITE)
			*dst |= MIGRATE_PFN_WRITE;

		if (dmirror->flags & HMM_DMIRROR_FLAG_FAIL_ALLOC) {
			dmirror->flags &= ~HMM_DMIRROR_FLAG_FAIL_ALLOC;
			*dst &= ~MIGRATE_PFN_COMPOUND;
			dpage = NULL;
		} else if (order) {
			dpage = folio_page(vma_alloc_folio(GFP_HIGHUSER_MOVABLE,
						order, args->vma, addr), 0);
		} else {
			dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
		}

		if (!dpage && !order)
			return VM_FAULT_OOM;

		if (dpage) {
			pr_debug("migrating from dev to sys pfn src: 0x%lx pfn dst: 0x%lx\n",
				 page_to_pfn(spage), page_to_pfn(dpage));
			lock_page(dpage);
			*dst |= migrate_pfn(page_to_pfn(dpage));
		}

		for (i = 0; i < (1 << order); i++) {
			struct page *src_page;
			struct page *dst_page;

			/* Try with smaller pages if large allocation fails */
			if (!dpage && order) {
				dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
				lock_page(dpage);
				dst[i] = migrate_pfn(page_to_pfn(dpage));
				dst_page = pfn_to_page(page_to_pfn(dpage));
				dpage = NULL; /* For the next iteration */
			} else {
				dst_page = pfn_to_page(page_to_pfn(dpage) + i);
			}

			src_page = pfn_to_page(page_to_pfn(spage) + i);

			xa_erase(&dmirror->pt, addr >> PAGE_SHIFT);
			addr += PAGE_SIZE;
			copy_highpage(dst_page, src_page);
		}
next:
		src += 1 << order;
		dst += 1 << order;
	}
	return 0;
}

static unsigned long
dmirror_successful_migrated_pages(struct migrate_vma *migrate)
{
	unsigned long cpages = 0;
	unsigned long i;

	for (i = 0; i < migrate->npages; i++) {
		if (migrate->src[i] & MIGRATE_PFN_VALID &&
		    migrate->src[i] & MIGRATE_PFN_MIGRATE)
			cpages++;
	}
	return cpages;
}

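/*
 * Migrations are done in batches of at most PTRS_PER_PTE pages (one page
 * table's worth) per migrate_vma_setup() call, clamped to the containing
 * VMA, which is why the pfn arrays here are PTRS_PER_PTE entries long.
 */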
static int dmirror_migrate_to_system(struct dmirror *dmirror,
				     struct hmm_dmirror_cmd *cmd)
{
	unsigned long start, end, addr;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	struct mm_struct *mm = dmirror->notifier.mm;
	struct vm_area_struct *vma;
	struct migrate_vma args = { 0 };
	unsigned long next;
	int ret;
	unsigned long *src_pfns;
	unsigned long *dst_pfns;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	/* Since the mm is for the mirrored process, get a reference first. */
	if (!mmget_not_zero(mm))
		return -EINVAL;

	/* Allocate only after the checks above so early returns don't leak. */
	src_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*src_pfns), GFP_KERNEL | __GFP_NOFAIL);
	dst_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*dst_pfns), GFP_KERNEL | __GFP_NOFAIL);

	cmd->cpages = 0;
	mmap_read_lock(mm);
	for (addr = start; addr < end; addr = next) {
		vma = vma_lookup(mm, addr);
		if (!vma || !(vma->vm_flags & VM_READ)) {
			ret = -EINVAL;
			goto out;
		}
		next = min(end, addr + (PTRS_PER_PTE << PAGE_SHIFT));
		if (next > vma->vm_end)
			next = vma->vm_end;

		args.vma = vma;
		args.src = src_pfns;
		args.dst = dst_pfns;
		args.start = addr;
		args.end = next;
		args.pgmap_owner = dmirror->mdevice;
		args.flags = dmirror_select_device(dmirror) | MIGRATE_VMA_SELECT_COMPOUND;

		ret = migrate_vma_setup(&args);
		if (ret)
			goto out;

		pr_debug("Migrating from device mem to sys mem\n");
		dmirror_devmem_fault_alloc_and_copy(&args, dmirror);

		migrate_vma_pages(&args);
		cmd->cpages += dmirror_successful_migrated_pages(&args);
		migrate_vma_finalize(&args);
	}
out:
	mmap_read_unlock(mm);
	mmput(mm);
	kvfree(src_pfns);
	kvfree(dst_pfns);

	return ret;
}

static int dmirror_migrate_to_device(struct dmirror *dmirror,
				struct hmm_dmirror_cmd *cmd)
{
	unsigned long start, end, addr;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	struct mm_struct *mm = dmirror->notifier.mm;
	struct vm_area_struct *vma;
	struct dmirror_bounce bounce;
	struct migrate_vma args = { 0 };
	unsigned long next;
	int ret;
	unsigned long *src_pfns = NULL;
	unsigned long *dst_pfns = NULL;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	/* Since the mm is for the mirrored process, get a reference first. */
	if (!mmget_not_zero(mm))
		return -EINVAL;

	ret = -ENOMEM;
	src_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*src_pfns),
			  GFP_KERNEL | __GFP_NOFAIL);
	if (!src_pfns)
		goto free_mem;

	dst_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*dst_pfns),
			  GFP_KERNEL | __GFP_NOFAIL);
	if (!dst_pfns)
		goto free_mem;

	ret = 0;
	mmap_read_lock(mm);
	for (addr = start; addr < end; addr = next) {
		vma = vma_lookup(mm, addr);
		if (!vma || !(vma->vm_flags & VM_READ)) {
			ret = -EINVAL;
			goto out;
		}
		next = min(end, addr + (PTRS_PER_PTE << PAGE_SHIFT));
		if (next > vma->vm_end)
			next = vma->vm_end;

		args.vma = vma;
		args.src = src_pfns;
		args.dst = dst_pfns;
		args.start = addr;
		args.end = next;
		args.pgmap_owner = dmirror->mdevice;
		args.flags = MIGRATE_VMA_SELECT_SYSTEM |
				MIGRATE_VMA_SELECT_COMPOUND;
		ret = migrate_vma_setup(&args);
		if (ret)
			goto out;

		pr_debug("Migrating from sys mem to device mem\n");
		dmirror_migrate_alloc_and_copy(&args, dmirror);
		migrate_vma_pages(&args);
		dmirror_migrate_finalize_and_map(&args, dmirror);
		migrate_vma_finalize(&args);
	}
	mmap_read_unlock(mm);
	mmput(mm);

	/*
	 * Return the migrated data for verification.
	 * Only for pages in device zone
	 */
	ret = dmirror_bounce_init(&bounce, start, size);
	if (ret)
		goto free_mem;
	mutex_lock(&dmirror->mutex);
	ret = dmirror_do_read(dmirror, start, end, &bounce);
	mutex_unlock(&dmirror->mutex);
	if (ret == 0) {
		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
				 bounce.size))
			ret = -EFAULT;
	}
	cmd->cpages = bounce.cpages;
	dmirror_bounce_fini(&bounce);
	goto free_mem;

out:
	mmap_read_unlock(mm);
	mmput(mm);
free_mem:
	kvfree(src_pfns);
	kvfree(dst_pfns);
	return ret;
}

static void dmirror_mkentry(struct dmirror *dmirror, struct hmm_range *range,
			    unsigned char *perm, unsigned long entry)
{
	struct page *page;

	if (entry & HMM_PFN_ERROR) {
		*perm = HMM_DMIRROR_PROT_ERROR;
		return;
	}
	if (!(entry & HMM_PFN_VALID)) {
		*perm = HMM_DMIRROR_PROT_NONE;
		return;
	}

	page = hmm_pfn_to_page(entry);
	if (is_device_private_page(page)) {
		/* Is the page migrated to this device or some other? */
		if (dmirror->mdevice == dmirror_page_to_device(page))
			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL;
		else
			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_REMOTE;
	} else if (is_device_coherent_page(page)) {
		/* Is the page migrated to this device or some other? */
		if (dmirror->mdevice == dmirror_page_to_device(page))
			*perm = HMM_DMIRROR_PROT_DEV_COHERENT_LOCAL;
		else
			*perm = HMM_DMIRROR_PROT_DEV_COHERENT_REMOTE;
	} else if (is_zero_pfn(page_to_pfn(page)))
		*perm = HMM_DMIRROR_PROT_ZERO;
	else
		*perm = HMM_DMIRROR_PROT_NONE;
	if (entry & HMM_PFN_WRITE)
		*perm |= HMM_DMIRROR_PROT_WRITE;
	else
		*perm |= HMM_DMIRROR_PROT_READ;
	if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PMD_SHIFT)
		*perm |= HMM_DMIRROR_PROT_PMD;
	else if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PUD_SHIFT)
		*perm |= HMM_DMIRROR_PROT_PUD;
}

static bool dmirror_snapshot_invalidate(struct mmu_interval_notifier *mni,
				const struct mmu_notifier_range *range,
				unsigned long cur_seq)
{
	struct dmirror_interval *dmi =
		container_of(mni, struct dmirror_interval, notifier);
	struct dmirror *dmirror = dmi->dmirror;

	if (mmu_notifier_range_blockable(range))
		mutex_lock(&dmirror->mutex);
	else if (!mutex_trylock(&dmirror->mutex))
		return false;

	/*
	 * Snapshots only need to set the sequence number since any
	 * invalidation in the interval invalidates the whole snapshot.
	 */
	mmu_interval_set_seq(mni, cur_seq);

	mutex_unlock(&dmirror->mutex);
	return true;
}

static const struct mmu_interval_notifier_ops dmirror_mrn_ops = {
	.invalidate = dmirror_snapshot_invalidate,
};

static int dmirror_range_snapshot(struct dmirror *dmirror,
				  struct hmm_range *range,
				  unsigned char *perm)
{
	struct mm_struct *mm = dmirror->notifier.mm;
	struct dmirror_interval notifier;
	unsigned long timeout =
		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
	unsigned long i;
	unsigned long n;
	int ret = 0;

	notifier.dmirror = dmirror;
	range->notifier = &notifier.notifier;

	ret = mmu_interval_notifier_insert(range->notifier, mm,
			range->start, range->end - range->start,
			&dmirror_mrn_ops);
	if (ret)
		return ret;

	while (true) {
		if (time_after(jiffies, timeout)) {
			ret = -EBUSY;
			goto out;
		}

		range->notifier_seq = mmu_interval_read_begin(range->notifier);

		mmap_read_lock(mm);
		ret = hmm_range_fault(range);
		mmap_read_unlock(mm);
		if (ret) {
			if (ret == -EBUSY)
				continue;
			goto out;
		}

		mutex_lock(&dmirror->mutex);
		if (mmu_interval_read_retry(range->notifier,
					    range->notifier_seq)) {
			mutex_unlock(&dmirror->mutex);
			continue;
		}
		break;
	}

	n = (range->end - range->start) >> PAGE_SHIFT;
	for (i = 0; i < n; i++)
		dmirror_mkentry(dmirror, range, perm + i, range->hmm_pfns[i]);

	mutex_unlock(&dmirror->mutex);
out:
	mmu_interval_notifier_remove(range->notifier);
	return ret;
}

static int dmirror_snapshot(struct dmirror *dmirror,
			    struct hmm_dmirror_cmd *cmd)
{
	struct mm_struct *mm = dmirror->notifier.mm;
	unsigned long start, end;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	unsigned long addr;
	unsigned long next;
	unsigned long pfns[32];
	unsigned char perm[32];
	char __user *uptr;
	struct hmm_range range = {
		.hmm_pfns = pfns,
		.dev_private_owner = dmirror->mdevice,
	};
	int ret = 0;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	/* Since the mm is for the mirrored process, get a reference first. */
	if (!mmget_not_zero(mm))
		return -EINVAL;

	/*
	 * Register a temporary notifier to detect invalidations even if it
	 * overlaps with other mmu_interval_notifiers.
	 */
	uptr = u64_to_user_ptr(cmd->ptr);
	for (addr = start; addr < end; addr = next) {
		unsigned long n;

		next = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
		range.start = addr;
		range.end = next;

		ret = dmirror_range_snapshot(dmirror, &range, perm);
		if (ret)
			break;

		n = (range.end - range.start) >> PAGE_SHIFT;
		if (copy_to_user(uptr, perm, n)) {
			ret = -EFAULT;
			break;
		}

		cmd->cpages += n;
		uptr += n;
	}
	mmput(mm);

	return ret;
}

/* Removes free pages from the free list so they can't be re-allocated */
static void dmirror_remove_free_pages(struct dmirror_chunk *devmem)
{
	struct dmirror_device *mdevice = devmem->mdevice;
	struct page *page;
	struct folio *folio;

	for (folio = mdevice->free_folios; folio; folio = folio_zone_device_data(folio))
		if (dmirror_page_to_chunk(folio_page(folio, 0)) == devmem)
			mdevice->free_folios = folio_zone_device_data(folio);
	for (page = mdevice->free_pages; page; page = page->zone_device_data)
		if (dmirror_page_to_chunk(page) == devmem)
			mdevice->free_pages = page->zone_device_data;
}

static void dmirror_device_remove_chunks(struct dmirror_device *mdevice)
{
	unsigned int i;

	mutex_lock(&mdevice->devmem_lock);
	if (mdevice->devmem_chunks) {
		for (i = 0; i < mdevice->devmem_count; i++) {
			struct dmirror_chunk *devmem =
				mdevice->devmem_chunks[i];

			spin_lock(&mdevice->lock);
			devmem->remove = true;
			dmirror_remove_free_pages(devmem);
			spin_unlock(&mdevice->lock);

			dmirror_device_evict_chunk(devmem);
			memunmap_pages(&devmem->pagemap);
			if (devmem->pagemap.type == MEMORY_DEVICE_PRIVATE)
				release_mem_region(devmem->pagemap.range.start,
						   range_len(&devmem->pagemap.range));
			kfree(devmem);
		}
		mdevice->devmem_count = 0;
		mdevice->devmem_capacity = 0;
		mdevice->free_pages = NULL;
		mdevice->free_folios = NULL;
		kfree(mdevice->devmem_chunks);
		mdevice->devmem_chunks = NULL;
	}
	mutex_unlock(&mdevice->devmem_lock);
}

static long dmirror_fops_unlocked_ioctl(struct file *filp,
					unsigned int command,
					unsigned long arg)
{
	void __user *uarg = (void __user *)arg;
	struct hmm_dmirror_cmd cmd;
	struct dmirror *dmirror;
	int ret;

	dmirror = filp->private_data;
	if (!dmirror)
		return -EINVAL;

	if (copy_from_user(&cmd, uarg, sizeof(cmd)))
		return -EFAULT;

	if (cmd.addr & ~PAGE_MASK)
		return -EINVAL;
	if (cmd.addr >= (cmd.addr + (cmd.npages << PAGE_SHIFT)))
		return -EINVAL;

	cmd.cpages = 0;
	cmd.faults = 0;

	switch (command) {
	case HMM_DMIRROR_READ:
		ret = dmirror_read(dmirror, &cmd);
		break;

	case HMM_DMIRROR_WRITE:
		ret = dmirror_write(dmirror, &cmd);
		break;

	case HMM_DMIRROR_MIGRATE_TO_DEV:
		ret = dmirror_migrate_to_device(dmirror, &cmd);
		break;

	case HMM_DMIRROR_MIGRATE_TO_SYS:
		ret = dmirror_migrate_to_system(dmirror, &cmd);
		break;

	case HMM_DMIRROR_EXCLUSIVE:
		ret = dmirror_exclusive(dmirror, &cmd);
		break;

	case HMM_DMIRROR_CHECK_EXCLUSIVE:
		ret = dmirror_check_atomic(dmirror, cmd.addr,
					cmd.addr + (cmd.npages << PAGE_SHIFT));
		break;

	case HMM_DMIRROR_SNAPSHOT:
		ret = dmirror_snapshot(dmirror, &cmd);
		break;

	case HMM_DMIRROR_RELEASE:
		dmirror_device_remove_chunks(dmirror->mdevice);
		ret = 0;
		break;

	case HMM_DMIRROR_FLAGS:
		dmirror->flags = cmd.npages;
		ret = 0;
		break;

	default:
		return -EINVAL;
	}
	if (ret)
		return ret;

	if (copy_to_user(uarg, &cmd, sizeof(cmd)))
		return -EFAULT;

	return 0;
}

static int dmirror_fops_mmap(struct file *file, struct vm_area_struct *vma)
{
	unsigned long addr;

	for (addr = vma->vm_start; addr < vma->vm_end; addr += PAGE_SIZE) {
		struct page *page;
		int ret;

		page = alloc_page(GFP_KERNEL | __GFP_ZERO);
		if (!page)
			return -ENOMEM;

		ret = vm_insert_page(vma, addr, page);
		if (ret) {
			__free_page(page);
			return ret;
		}
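		/*
		 * vm_insert_page() took its own reference on the page, so
		 * drop ours; the mapping now holds the page.
		 */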
		put_page(page);
	}

	return 0;
}

static const struct file_operations dmirror_fops = {
	.open		= dmirror_fops_open,
	.release	= dmirror_fops_release,
	.mmap		= dmirror_fops_mmap,
	.unlocked_ioctl = dmirror_fops_unlocked_ioctl,
	.llseek		= default_llseek,
	.owner		= THIS_MODULE,
};

static void dmirror_devmem_free(struct folio *folio)
{
	struct page *page = &folio->page;
	struct page *rpage = BACKING_PAGE(page);
	struct dmirror_device *mdevice;
	struct folio *rfolio = page_folio(rpage);
	unsigned int order = folio_order(rfolio);

	if (rpage != page) {
		if (order)
			__free_pages(rpage, order);
		else
			__free_page(rpage);
		rpage = NULL;
	}

	mdevice = dmirror_page_to_device(page);
	spin_lock(&mdevice->lock);

	/* Return page to our allocator if not freeing the chunk */
	if (!dmirror_page_to_chunk(page)->remove) {
		mdevice->cfree += 1 << order;
		if (order) {
			page->zone_device_data = mdevice->free_folios;
			mdevice->free_folios = page_folio(page);
		} else {
			page->zone_device_data = mdevice->free_pages;
			mdevice->free_pages = page;
		}
	}
	spin_unlock(&mdevice->lock);
}

static vm_fault_t dmirror_devmem_fault(struct vm_fault *vmf)
{
	struct migrate_vma args = { 0 };
	struct page *rpage;
	struct dmirror *dmirror;
	vm_fault_t ret = 0;
	unsigned int order, nr;

	/*
	 * Normally, a device would use the page->zone_device_data to point to
	 * the mirror but here we use it to hold the page for the simulated
	 * device memory and that page holds the pointer to the mirror.
	 */
	rpage = folio_zone_device_data(page_folio(vmf->page));
	dmirror = rpage->zone_device_data;

	/* FIXME demonstrate how we can adjust migrate range */
	order = folio_order(page_folio(vmf->page));
	nr = 1 << order;

	/*
	 * When folios are partially mapped, we can't rely on the folio
	 * order of vmf->page as the folio might not be fully split yet
	 */
	if (vmf->pte) {
		order = 0;
		nr = 1;
	}

	/*
	 * A per-cpu cache of src and dst pfns would avoid these allocations,
	 * but with a large number of cpus it might not scale well.
	 */
	args.start = ALIGN_DOWN(vmf->address, (PAGE_SIZE << order));
	args.vma = vmf->vma;
	args.end = args.start + (PAGE_SIZE << order);

	nr = (args.end - args.start) >> PAGE_SHIFT;
	args.src = kcalloc(nr, sizeof(unsigned long), GFP_KERNEL);
	args.dst = kcalloc(nr, sizeof(unsigned long), GFP_KERNEL);
	args.pgmap_owner = dmirror->mdevice;
	args.flags = dmirror_select_device(dmirror);
	args.fault_page = vmf->page;

	if (!args.src || !args.dst) {
		ret = VM_FAULT_OOM;
		goto err;
	}

	if (order)
		args.flags |= MIGRATE_VMA_SELECT_COMPOUND;

	if (migrate_vma_setup(&args)) {
		ret = VM_FAULT_SIGBUS;
		goto err;
	}

	ret = dmirror_devmem_fault_alloc_and_copy(&args, dmirror);
	if (ret)
		goto err;
	migrate_vma_pages(&args);
	/*
	 * No device finalize step is needed since
	 * dmirror_devmem_fault_alloc_and_copy() will have already
	 * invalidated the device page table.
	 */
	migrate_vma_finalize(&args);
err:
	kfree(args.src);
	kfree(args.dst);
	return ret;
}

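/*
 * Mirror a folio split onto the backing system-memory folio: a NULL tail
 * means the split is done and only the head remains, so reset the backing
 * folio's order; otherwise wire @tail to the matching backing tail page.
 */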
static void dmirror_devmem_folio_split(struct folio *head, struct folio *tail)
{
	struct page *rpage = BACKING_PAGE(folio_page(head, 0));
	struct page *rpage_tail;
	struct folio *rfolio;
	unsigned long offset = 0;

	if (!rpage) {
		tail->page.zone_device_data = NULL;
		return;
	}

	rfolio = page_folio(rpage);

	if (tail == NULL) {
		folio_reset_order(rfolio);
		rfolio->mapping = NULL;
		folio_set_count(rfolio, 1);
		return;
	}

	offset = folio_pfn(tail) - folio_pfn(head);

	rpage_tail = folio_page(rfolio, offset);
	tail->page.zone_device_data = rpage_tail;
	rpage_tail->zone_device_data = rpage->zone_device_data;
	clear_compound_head(rpage_tail);
	rpage_tail->mapping = NULL;

	folio_page(tail, 0)->mapping = folio_page(head, 0)->mapping;
	tail->pgmap = head->pgmap;
	folio_set_count(page_folio(rpage_tail), 1);
}

static const struct dev_pagemap_ops dmirror_devmem_ops = {
	.folio_free	= dmirror_devmem_free,
	.migrate_to_ram	= dmirror_devmem_fault,
	.folio_split	= dmirror_devmem_folio_split,
};

static void dmirror_device_release(struct device *dev)
{
	struct dmirror_device *mdevice = container_of(dev, struct dmirror_device, device);

	dmirror_device_remove_chunks(mdevice);
}

static int dmirror_device_init(struct dmirror_device *mdevice, int id)
{
	dev_t dev;
	int ret;

	dev = MKDEV(MAJOR(dmirror_dev), id);
	mutex_init(&mdevice->devmem_lock);
	spin_lock_init(&mdevice->lock);

	cdev_init(&mdevice->cdevice, &dmirror_fops);
	mdevice->cdevice.owner = THIS_MODULE;
	mdevice->device.release = dmirror_device_release;

	device_initialize(&mdevice->device);
	mdevice->device.devt = dev;

	ret = dev_set_name(&mdevice->device, "hmm_dmirror%u", id);
	if (ret)
		goto put_device;

	/* Build a list of free ZONE_DEVICE struct pages */
	ret = dmirror_allocate_chunk(mdevice, NULL, false);
	if (ret)
		goto put_device;

	ret = cdev_device_add(&mdevice->cdevice, &mdevice->device);
	if (ret)
		goto put_device;

	return 0;

put_device:
	put_device(&mdevice->device);
	return ret;
}

static void dmirror_device_remove(struct dmirror_device *mdevice)
{
	cdev_device_del(&mdevice->cdevice, &mdevice->device);
	put_device(&mdevice->device);
}

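/*
 * Devices 0 and 1 always use device-private memory; devices 2 and 3 are
 * created only when both spm_addr_dev0 and spm_addr_dev1 are given, and
 * they use device-coherent memory backed by that special purpose memory.
 */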
static int __init hmm_dmirror_init(void)
{
	int ret;
	int id = 0;
	int ndevices = 0;

	ret = alloc_chrdev_region(&dmirror_dev, 0, DMIRROR_NDEVICES,
				  "HMM_DMIRROR");
	if (ret)
		goto err_unreg;

	memset(dmirror_devices, 0, DMIRROR_NDEVICES * sizeof(dmirror_devices[0]));
	dmirror_devices[ndevices++].zone_device_type =
				HMM_DMIRROR_MEMORY_DEVICE_PRIVATE;
	dmirror_devices[ndevices++].zone_device_type =
				HMM_DMIRROR_MEMORY_DEVICE_PRIVATE;
	if (spm_addr_dev0 && spm_addr_dev1) {
		dmirror_devices[ndevices++].zone_device_type =
					HMM_DMIRROR_MEMORY_DEVICE_COHERENT;
		dmirror_devices[ndevices++].zone_device_type =
					HMM_DMIRROR_MEMORY_DEVICE_COHERENT;
	}
	for (id = 0; id < ndevices; id++) {
		ret = dmirror_device_init(dmirror_devices + id, id);
		if (ret)
			goto err_chrdev;
	}

	pr_info("HMM test module loaded. This is only for testing HMM.\n");
	return 0;

err_chrdev:
	while (--id >= 0)
		dmirror_device_remove(dmirror_devices + id);
	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
err_unreg:
	return ret;
}

static void __exit hmm_dmirror_exit(void)
{
	int id;

	for (id = 0; id < DMIRROR_NDEVICES; id++)
		if (dmirror_devices[id].zone_device_type)
			dmirror_device_remove(dmirror_devices + id);
	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
}

module_init(hmm_dmirror_init);
module_exit(hmm_dmirror_exit);
MODULE_DESCRIPTION("HMM (Heterogeneous Memory Management) test module");
MODULE_LICENSE("GPL");