xref: /linux/lib/test_hmm.c (revision bba2c3615bd6cfee7456d1130f2e6b01b3f4e9ba)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * This is a module to test the HMM (Heterogeneous Memory Management)
4  * mirror and zone device private memory migration APIs of the kernel.
5  * Userspace programs can register with the driver to mirror their own address
6  * space and can use the device to read/write any valid virtual address.
7  */
8 #include <linux/init.h>
9 #include <linux/fs.h>
10 #include <linux/mm.h>
11 #include <linux/module.h>
12 #include <linux/kernel.h>
13 #include <linux/cdev.h>
14 #include <linux/device.h>
15 #include <linux/memremap.h>
16 #include <linux/mutex.h>
17 #include <linux/rwsem.h>
18 #include <linux/sched.h>
19 #include <linux/slab.h>
20 #include <linux/highmem.h>
21 #include <linux/delay.h>
22 #include <linux/pagemap.h>
23 #include <linux/hmm.h>
24 #include <linux/vmalloc.h>
25 #include <linux/swap.h>
26 #include <linux/swapops.h>
27 #include <linux/sched/mm.h>
28 #include <linux/platform_device.h>
29 #include <linux/rmap.h>
30 #include <linux/mmu_notifier.h>
31 #include <linux/migrate.h>
32 
33 #include "test_hmm_uapi.h"
34 
/* Number of entries in the dmirror_devices[] array. */
#define DMIRROR_NDEVICES		4
/* NOTE(review): not referenced in the visible part of this file - confirm it is still needed. */
#define DMIRROR_RANGE_FAULT_TIMEOUT	1000
/* Size in bytes (256 MiB) of each simulated device memory chunk. */
#define DEVMEM_CHUNK_SIZE		(256 * 1024 * 1024U)
/* Number of slots added each time a device's chunk pointer array is grown. */
#define DEVMEM_CHUNKS_RESERVE		16
39 
/*
 * For device_private pages, dpage is just a dummy struct page
 * representing a piece of device memory. dmirror_devmem_alloc_page
 * allocates a real system memory page as backing storage to fake a
 * real device. zone_device_data points to that backing page. But
 * for device_coherent memory, the struct page represents real
 * physical CPU-accessible memory that we can use directly.
 *
 * Note that the macro argument is expanded twice, so it must not have
 * side effects.
 */
#define BACKING_PAGE(page) (is_device_private_page((page)) ? \
			   (page)->zone_device_data : (page))
50 
/*
 * Start addresses of the special purpose memory ranges used by the
 * coherent device type. dmirror_allocate_chunk() uses one of them as the
 * start of a DEVMEM_CHUNK_SIZE range for MEMORY_DEVICE_COHERENT chunks.
 */
static unsigned long spm_addr_dev0;
module_param(spm_addr_dev0, long, 0644);
MODULE_PARM_DESC(spm_addr_dev0,
		"Specify start address for SPM (special purpose memory) used for device 0. By setting this Coherent device type will be used. Make sure spm_addr_dev1 is set too. Minimum SPM size should be DEVMEM_CHUNK_SIZE.");

static unsigned long spm_addr_dev1;
module_param(spm_addr_dev1, long, 0644);
MODULE_PARM_DESC(spm_addr_dev1,
		"Specify start address for SPM (special purpose memory) used for device 1. By setting this Coherent device type will be used. Make sure spm_addr_dev0 is set too. Minimum SPM size should be DEVMEM_CHUNK_SIZE.");
60 
/*
 * Forward declarations: the ops tables are referenced by code that appears
 * before their definitions, and struct dmirror_device is used by pointer in
 * the structures below before it is fully defined.
 */
static const struct dev_pagemap_ops dmirror_devmem_ops;
static const struct mmu_interval_notifier_ops dmirror_min_ops;
static dev_t dmirror_dev;

struct dmirror_device;
66 
/*
 * Kernel-side staging buffer used to move data between user space and the
 * mirrored pages.
 */
struct dmirror_bounce {
	void			*ptr;	/* vmalloc()ed buffer of @size bytes */
	unsigned long		size;	/* buffer size in bytes */
	unsigned long		addr;	/* virtual address that offset 0 of @ptr maps to */
	unsigned long		cpages;	/* pages copied so far by the do_read/do_write helpers */
};
73 
/*
 * Tags attached with xa_tag_pointer() to the page pointers stored in
 * dmirror->pt: ATOMIC marks entries installed by dmirror_atomic_map(),
 * WRITE marks entries that dmirror_do_write() is allowed to write through.
 */
#define DPT_XA_TAG_ATOMIC 1UL
#define DPT_XA_TAG_WRITE 3UL
76 
/*
 * Data structure to track address ranges and register for mmu interval
 * notifier updates.
 *
 * NOTE(review): no user of this structure is visible in this part of the
 * file - confirm whether it is still needed.
 */
struct dmirror_interval {
	struct mmu_interval_notifier	notifier;
	struct dmirror			*dmirror;
};
85 
/*
 * Data attached to the open device file.
 * Note that it might be shared after a fork().
 */
struct dmirror {
	/* Device this mirror was opened on. */
	struct dmirror_device		*mdevice;
	/* Mirror "page table": tagged page pointers indexed by virtual pfn. */
	struct xarray			pt;
	/* Interval notifier registered on the opener's mm in dmirror_fops_open(). */
	struct mmu_interval_notifier	notifier;
	/* Held around the accesses to @pt in this file. */
	struct mutex			mutex;
	/* HMM_DMIRROR_FLAG_* bits, e.g. HMM_DMIRROR_FLAG_FAIL_ALLOC. */
	__u64			flags;
};
97 
/*
 * ZONE_DEVICE pages for migration and simulating device memory.
 */
struct dmirror_chunk {
	/* Passed to memremap_pages(); dmirror_page_to_chunk() maps back from it. */
	struct dev_pagemap	pagemap;
	/* Owning device, set once the chunk has been remapped. */
	struct dmirror_device	*mdevice;
	/* NOTE(review): not used in the visible part of this file - meaning to be confirmed. */
	bool remove;
};
106 
/*
 * Per device data.
 */
struct dmirror_device {
	struct cdev		cdevice;
	/* One of the HMM_DMIRROR_MEMORY_DEVICE_* values. */
	unsigned int            zone_device_type;
	struct device		device;

	/* Allocated slots in @devmem_chunks. */
	unsigned int		devmem_capacity;
	/* Slots of @devmem_chunks currently in use. */
	unsigned int		devmem_count;
	struct dmirror_chunk	**devmem_chunks;
	struct mutex		devmem_lock;	/* protects the above */

	/* Allocation counter, in pages; incremented when pages are handed out. */
	unsigned long		calloc;
	/* NOTE(review): presumably the matching free counter - not updated in the visible code. */
	unsigned long		cfree;
	/* Singly linked free lists, chained through page->zone_device_data. */
	struct page		*free_pages;
	struct folio		*free_folios;
	spinlock_t		lock;		/* protects the above */
};
126 
/* Statically allocated per-device state, one entry per dmirror device. */
static struct dmirror_device dmirror_devices[DMIRROR_NDEVICES];
128 
129 static int dmirror_bounce_init(struct dmirror_bounce *bounce,
130 			       unsigned long addr,
131 			       unsigned long size)
132 {
133 	bounce->addr = addr;
134 	bounce->size = size;
135 	bounce->cpages = 0;
136 	bounce->ptr = vmalloc(size);
137 	if (!bounce->ptr)
138 		return -ENOMEM;
139 	return 0;
140 }
141 
142 static bool dmirror_is_private_zone(struct dmirror_device *mdevice)
143 {
144 	return (mdevice->zone_device_type ==
145 		HMM_DMIRROR_MEMORY_DEVICE_PRIVATE);
146 }
147 
148 static enum migrate_vma_direction
149 dmirror_select_device(struct dmirror *dmirror)
150 {
151 	return (dmirror->mdevice->zone_device_type ==
152 		HMM_DMIRROR_MEMORY_DEVICE_PRIVATE) ?
153 		MIGRATE_VMA_SELECT_DEVICE_PRIVATE :
154 		MIGRATE_VMA_SELECT_DEVICE_COHERENT;
155 }
156 
157 static void dmirror_bounce_fini(struct dmirror_bounce *bounce)
158 {
159 	vfree(bounce->ptr);
160 }
161 
162 static int dmirror_fops_open(struct inode *inode, struct file *filp)
163 {
164 	struct cdev *cdev = inode->i_cdev;
165 	struct dmirror *dmirror;
166 	int ret;
167 
168 	/* Mirror this process address space */
169 	dmirror = kzalloc_obj(*dmirror);
170 	if (dmirror == NULL)
171 		return -ENOMEM;
172 
173 	dmirror->mdevice = container_of(cdev, struct dmirror_device, cdevice);
174 	mutex_init(&dmirror->mutex);
175 	xa_init(&dmirror->pt);
176 
177 	ret = mmu_interval_notifier_insert(&dmirror->notifier, current->mm,
178 				0, ULONG_MAX & PAGE_MASK, &dmirror_min_ops);
179 	if (ret) {
180 		kfree(dmirror);
181 		return ret;
182 	}
183 
184 	filp->private_data = dmirror;
185 	return 0;
186 }
187 
/*
 * Move every migratable device page of @chunk back to freshly allocated
 * system memory.
 *
 * The pfn range of the chunk is handed to migrate_device_range(); for each
 * entry flagged MIGRATE_PFN_MIGRATE a destination page (or folio, for
 * compound entries) is allocated, locked and filled with a copy of the
 * backing data, and the migration is then completed with
 * migrate_device_pages() and migrate_device_finalize().
 */
static void dmirror_device_evict_chunk(struct dmirror_chunk *chunk)
{
	unsigned long start_pfn = chunk->pagemap.range.start >> PAGE_SHIFT;
	unsigned long end_pfn = chunk->pagemap.range.end >> PAGE_SHIFT;
	/* range.end is treated as inclusive here, hence the + 1. */
	unsigned long npages = end_pfn - start_pfn + 1;
	unsigned long i;
	unsigned long *src_pfns;
	unsigned long *dst_pfns;
	unsigned int order = 0;

	/* __GFP_NOFAIL: the results are used without a NULL check. */
	src_pfns = kvcalloc(npages, sizeof(*src_pfns), GFP_KERNEL | __GFP_NOFAIL);
	dst_pfns = kvcalloc(npages, sizeof(*dst_pfns), GFP_KERNEL | __GFP_NOFAIL);

	migrate_device_range(src_pfns, start_pfn, npages);
	for (i = 0; i < npages; i++) {
		struct page *dpage, *spage;

		spage = migrate_pfn_to_page(src_pfns[i]);
		if (!spage || !(src_pfns[i] & MIGRATE_PFN_MIGRATE))
			continue;

		if (WARN_ON(!is_device_private_page(spage) &&
			    !is_device_coherent_page(spage)))
			continue;

		/* Take the order from the device folio before switching to the backing page. */
		order = folio_order(page_folio(spage));
		spage = BACKING_PAGE(spage);
		if (src_pfns[i] & MIGRATE_PFN_COMPOUND) {
			/*
			 * NOTE(review): unlike the order-0 branch this
			 * allocation is not __GFP_NOFAIL and its result is
			 * not checked before lock_page() - confirm a failure
			 * cannot happen here.
			 */
			dpage = folio_page(folio_alloc(GFP_HIGHUSER_MOVABLE,
					      order), 0);
		} else {
			dpage = alloc_page(GFP_HIGHUSER_MOVABLE | __GFP_NOFAIL);
			order = 0;
		}

		/* TODO Support splitting here */
		lock_page(dpage);
		dst_pfns[i] = migrate_pfn(page_to_pfn(dpage));
		if (src_pfns[i] & MIGRATE_PFN_WRITE)
			dst_pfns[i] |= MIGRATE_PFN_WRITE;
		if (order)
			dst_pfns[i] |= MIGRATE_PFN_COMPOUND;
		folio_copy(page_folio(dpage), page_folio(spage));
	}
	migrate_device_pages(src_pfns, dst_pfns, npages);
	migrate_device_finalize(src_pfns, dst_pfns, npages);
	kvfree(src_pfns);
	kvfree(dst_pfns);
}
237 
238 static int dmirror_fops_release(struct inode *inode, struct file *filp)
239 {
240 	struct dmirror *dmirror = filp->private_data;
241 	struct dmirror_device *mdevice = dmirror->mdevice;
242 	int i;
243 
244 	mmu_interval_notifier_remove(&dmirror->notifier);
245 
246 	if (mdevice->devmem_chunks) {
247 		for (i = 0; i < mdevice->devmem_count; i++) {
248 			struct dmirror_chunk *devmem =
249 				mdevice->devmem_chunks[i];
250 
251 			dmirror_device_evict_chunk(devmem);
252 		}
253 	}
254 
255 	xa_destroy(&dmirror->pt);
256 	kfree(dmirror);
257 	return 0;
258 }
259 
260 static struct dmirror_chunk *dmirror_page_to_chunk(struct page *page)
261 {
262 	return container_of(page_pgmap(page), struct dmirror_chunk,
263 			    pagemap);
264 }
265 
266 static struct dmirror_device *dmirror_page_to_device(struct page *page)
267 
268 {
269 	return dmirror_page_to_chunk(page)->mdevice;
270 }
271 
272 static int dmirror_do_fault(struct dmirror *dmirror, struct hmm_range *range)
273 {
274 	unsigned long *pfns = range->hmm_pfns;
275 	unsigned long pfn;
276 
277 	for (pfn = (range->start >> PAGE_SHIFT);
278 	     pfn < (range->end >> PAGE_SHIFT);
279 	     pfn++, pfns++) {
280 		struct page *page;
281 		void *entry;
282 
283 		/*
284 		 * Since we asked for hmm_range_fault() to populate pages,
285 		 * it shouldn't return an error entry on success.
286 		 */
287 		WARN_ON(*pfns & HMM_PFN_ERROR);
288 		WARN_ON(!(*pfns & HMM_PFN_VALID));
289 
290 		page = hmm_pfn_to_page(*pfns);
291 		WARN_ON(!page);
292 
293 		entry = page;
294 		if (*pfns & HMM_PFN_WRITE)
295 			entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
296 		else if (WARN_ON(range->default_flags & HMM_PFN_WRITE))
297 			return -EFAULT;
298 		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
299 		if (xa_is_err(entry))
300 			return xa_err(entry);
301 	}
302 
303 	return 0;
304 }
305 
306 static void dmirror_do_update(struct dmirror *dmirror, unsigned long start,
307 			      unsigned long end)
308 {
309 	unsigned long pfn;
310 	void *entry;
311 
312 	/*
313 	 * The XArray doesn't hold references to pages since it relies on
314 	 * the mmu notifier to clear page pointers when they become stale.
315 	 * Therefore, it is OK to just clear the entry.
316 	 */
317 	xa_for_each_range(&dmirror->pt, pfn, entry, start >> PAGE_SHIFT,
318 			  end >> PAGE_SHIFT)
319 		xa_erase(&dmirror->pt, pfn);
320 }
321 
322 static bool dmirror_interval_invalidate(struct mmu_interval_notifier *mni,
323 				const struct mmu_notifier_range *range,
324 				unsigned long cur_seq)
325 {
326 	struct dmirror *dmirror = container_of(mni, struct dmirror, notifier);
327 
328 	/*
329 	 * Ignore invalidation callbacks for device private pages since
330 	 * the invalidation is handled as part of the migration process.
331 	 */
332 	if (range->event == MMU_NOTIFY_MIGRATE &&
333 	    range->owner == dmirror->mdevice)
334 		return true;
335 
336 	if (mmu_notifier_range_blockable(range))
337 		mutex_lock(&dmirror->mutex);
338 	else if (!mutex_trylock(&dmirror->mutex))
339 		return false;
340 
341 	mmu_interval_set_seq(mni, cur_seq);
342 	dmirror_do_update(dmirror, range->start, range->end);
343 
344 	mutex_unlock(&dmirror->mutex);
345 	return true;
346 }
347 
/* Notifier ops passed to mmu_interval_notifier_insert() in dmirror_fops_open(). */
static const struct mmu_interval_notifier_ops dmirror_min_ops = {
	.invalidate = dmirror_interval_invalidate,
};
351 
352 static int dmirror_range_fault(struct dmirror *dmirror,
353 				struct hmm_range *range)
354 {
355 	struct mm_struct *mm = dmirror->notifier.mm;
356 	unsigned long timeout =
357 		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
358 	int ret;
359 
360 	while (true) {
361 		if (time_after(jiffies, timeout)) {
362 			ret = -EBUSY;
363 			goto out;
364 		}
365 
366 		range->notifier_seq = mmu_interval_read_begin(range->notifier);
367 		mmap_read_lock(mm);
368 		ret = hmm_range_fault(range);
369 		mmap_read_unlock(mm);
370 		if (ret) {
371 			if (ret == -EBUSY)
372 				continue;
373 			goto out;
374 		}
375 
376 		mutex_lock(&dmirror->mutex);
377 		if (mmu_interval_read_retry(range->notifier,
378 					    range->notifier_seq)) {
379 			mutex_unlock(&dmirror->mutex);
380 			continue;
381 		}
382 		break;
383 	}
384 
385 	ret = dmirror_do_fault(dmirror, range);
386 
387 	mutex_unlock(&dmirror->mutex);
388 out:
389 	return ret;
390 }
391 
392 static int dmirror_fault(struct dmirror *dmirror, unsigned long start,
393 			 unsigned long end, bool write)
394 {
395 	struct mm_struct *mm = dmirror->notifier.mm;
396 	unsigned long addr;
397 	unsigned long pfns[32];
398 	struct hmm_range range = {
399 		.notifier = &dmirror->notifier,
400 		.hmm_pfns = pfns,
401 		.pfn_flags_mask = 0,
402 		.default_flags =
403 			HMM_PFN_REQ_FAULT | (write ? HMM_PFN_REQ_WRITE : 0),
404 		.dev_private_owner = dmirror->mdevice,
405 	};
406 	int ret = 0;
407 
408 	/* Since the mm is for the mirrored process, get a reference first. */
409 	if (!mmget_not_zero(mm))
410 		return 0;
411 
412 	for (addr = start; addr < end; addr = range.end) {
413 		range.start = addr;
414 		range.end = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
415 
416 		ret = dmirror_range_fault(dmirror, &range);
417 		if (ret)
418 			break;
419 	}
420 
421 	mmput(mm);
422 	return ret;
423 }
424 
425 static int dmirror_do_read(struct dmirror *dmirror, unsigned long start,
426 			   unsigned long end, struct dmirror_bounce *bounce)
427 {
428 	unsigned long pfn;
429 	void *ptr;
430 
431 	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
432 
433 	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
434 		void *entry;
435 		struct page *page;
436 
437 		entry = xa_load(&dmirror->pt, pfn);
438 		page = xa_untag_pointer(entry);
439 		if (!page)
440 			return -ENOENT;
441 
442 		memcpy_from_page(ptr, page, 0, PAGE_SIZE);
443 
444 		ptr += PAGE_SIZE;
445 		bounce->cpages++;
446 	}
447 
448 	return 0;
449 }
450 
451 static int dmirror_read(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
452 {
453 	struct dmirror_bounce bounce;
454 	unsigned long start, end;
455 	unsigned long size = cmd->npages << PAGE_SHIFT;
456 	int ret;
457 
458 	start = cmd->addr;
459 	end = start + size;
460 	if (end < start)
461 		return -EINVAL;
462 
463 	ret = dmirror_bounce_init(&bounce, start, size);
464 	if (ret)
465 		return ret;
466 
467 	while (1) {
468 		mutex_lock(&dmirror->mutex);
469 		ret = dmirror_do_read(dmirror, start, end, &bounce);
470 		mutex_unlock(&dmirror->mutex);
471 		if (ret != -ENOENT)
472 			break;
473 
474 		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
475 		ret = dmirror_fault(dmirror, start, end, false);
476 		if (ret)
477 			break;
478 		cmd->faults++;
479 	}
480 
481 	if (ret == 0) {
482 		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
483 				 bounce.size))
484 			ret = -EFAULT;
485 	}
486 	cmd->cpages = bounce.cpages;
487 	dmirror_bounce_fini(&bounce);
488 	return ret;
489 }
490 
491 static int dmirror_do_write(struct dmirror *dmirror, unsigned long start,
492 			    unsigned long end, struct dmirror_bounce *bounce)
493 {
494 	unsigned long pfn;
495 	void *ptr;
496 
497 	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
498 
499 	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
500 		void *entry;
501 		struct page *page;
502 
503 		entry = xa_load(&dmirror->pt, pfn);
504 		page = xa_untag_pointer(entry);
505 		if (!page || xa_pointer_tag(entry) != DPT_XA_TAG_WRITE)
506 			return -ENOENT;
507 
508 		memcpy_to_page(page, 0, ptr, PAGE_SIZE);
509 
510 		ptr += PAGE_SIZE;
511 		bounce->cpages++;
512 	}
513 
514 	return 0;
515 }
516 
517 static int dmirror_write(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
518 {
519 	struct dmirror_bounce bounce;
520 	unsigned long start, end;
521 	unsigned long size = cmd->npages << PAGE_SHIFT;
522 	int ret;
523 
524 	start = cmd->addr;
525 	end = start + size;
526 	if (end < start)
527 		return -EINVAL;
528 
529 	ret = dmirror_bounce_init(&bounce, start, size);
530 	if (ret)
531 		return ret;
532 	if (copy_from_user(bounce.ptr, u64_to_user_ptr(cmd->ptr),
533 			   bounce.size)) {
534 		ret = -EFAULT;
535 		goto fini;
536 	}
537 
538 	while (1) {
539 		mutex_lock(&dmirror->mutex);
540 		ret = dmirror_do_write(dmirror, start, end, &bounce);
541 		mutex_unlock(&dmirror->mutex);
542 		if (ret != -ENOENT)
543 			break;
544 
545 		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
546 		ret = dmirror_fault(dmirror, start, end, true);
547 		if (ret)
548 			break;
549 		cmd->faults++;
550 	}
551 
552 fini:
553 	cmd->cpages = bounce.cpages;
554 	dmirror_bounce_fini(&bounce);
555 	return ret;
556 }
557 
/*
 * Add one DEVMEM_CHUNK_SIZE chunk of ZONE_DEVICE memory to @mdevice.
 *
 * The chunk is described by a new struct dmirror_chunk, remapped with
 * memremap_pages(), recorded in mdevice->devmem_chunks and its pages are
 * pushed onto the device free lists. With @is_large set, PMD-aligned runs
 * of HPAGE_PMD_NR pages go to the free_folios list, everything else goes
 * to free_pages.
 *
 * @ppage: if not NULL, receives one page (or the head page of one large
 *         folio when @is_large) taken straight from the new free lists.
 *
 * Returns 0 on success or a negative errno.
 */
static int dmirror_allocate_chunk(struct dmirror_device *mdevice,
				  struct page **ppage, bool is_large)
{
	struct dmirror_chunk *devmem;
	struct resource *res = NULL;
	unsigned long pfn;
	unsigned long pfn_first;
	unsigned long pfn_last;
	void *ptr;
	int ret = -ENOMEM;

	devmem = kzalloc_obj(*devmem);
	if (!devmem)
		return ret;

	switch (mdevice->zone_device_type) {
	case HMM_DMIRROR_MEMORY_DEVICE_PRIVATE:
		/* Private memory has no real backing: reserve a free physical range. */
		res = request_free_mem_region(&iomem_resource, DEVMEM_CHUNK_SIZE,
					      "hmm_dmirror");
		if (IS_ERR_OR_NULL(res))
			goto err_devmem;
		devmem->pagemap.range.start = res->start;
		devmem->pagemap.range.end = res->end;
		devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
		break;
	case HMM_DMIRROR_MEMORY_DEVICE_COHERENT:
		/*
		 * A device whose minor number is 2 uses spm_addr_dev1, any
		 * other minor uses spm_addr_dev0.
		 */
		devmem->pagemap.range.start = (MINOR(mdevice->cdevice.dev) - 2) ?
							spm_addr_dev0 :
							spm_addr_dev1;
		devmem->pagemap.range.end = devmem->pagemap.range.start +
					    DEVMEM_CHUNK_SIZE - 1;
		devmem->pagemap.type = MEMORY_DEVICE_COHERENT;
		break;
	default:
		ret = -EINVAL;
		goto err_devmem;
	}

	devmem->pagemap.nr_range = 1;
	devmem->pagemap.ops = &dmirror_devmem_ops;
	devmem->pagemap.owner = mdevice;

	mutex_lock(&mdevice->devmem_lock);

	/* Grow the chunk pointer array by DEVMEM_CHUNKS_RESERVE slots when full. */
	if (mdevice->devmem_count == mdevice->devmem_capacity) {
		struct dmirror_chunk **new_chunks;
		unsigned int new_capacity;

		new_capacity = mdevice->devmem_capacity +
				DEVMEM_CHUNKS_RESERVE;
		new_chunks = krealloc(mdevice->devmem_chunks,
				sizeof(new_chunks[0]) * new_capacity,
				GFP_KERNEL);
		if (!new_chunks)
			goto err_release;
		mdevice->devmem_capacity = new_capacity;
		mdevice->devmem_chunks = new_chunks;
	}
	ptr = memremap_pages(&devmem->pagemap, numa_node_id());
	if (IS_ERR_OR_NULL(ptr)) {
		if (ptr)
			ret = PTR_ERR(ptr);
		else
			ret = -EFAULT;
		goto err_release;
	}

	devmem->mdevice = mdevice;
	pfn_first = devmem->pagemap.range.start >> PAGE_SHIFT;
	/* pfn_last is exclusive: one past the last pfn of the chunk. */
	pfn_last = pfn_first + (range_len(&devmem->pagemap.range) >> PAGE_SHIFT);
	mdevice->devmem_chunks[mdevice->devmem_count++] = devmem;

	mutex_unlock(&mdevice->devmem_lock);

	pr_info("added new %u MB chunk (total %u chunks, %u MB) PFNs [0x%lx 0x%lx)\n",
		DEVMEM_CHUNK_SIZE / (1024 * 1024),
		mdevice->devmem_count,
		mdevice->devmem_count * (DEVMEM_CHUNK_SIZE / (1024 * 1024)),
		pfn_first, pfn_last);

	spin_lock(&mdevice->lock);
	/* Push every page of the chunk onto a free list linked via zone_device_data. */
	for (pfn = pfn_first; pfn < pfn_last; ) {
		struct page *page = pfn_to_page(pfn);

		if (is_large && IS_ALIGNED(pfn, HPAGE_PMD_NR)
			&& (pfn + HPAGE_PMD_NR <= pfn_last)) {
			page->zone_device_data = mdevice->free_folios;
			mdevice->free_folios = page_folio(page);
			pfn += HPAGE_PMD_NR;
			continue;
		}

		page->zone_device_data = mdevice->free_pages;
		mdevice->free_pages = page;
		pfn++;
	}

	ret = 0;
	/* Hand one entry to the caller while the free list lock is still held. */
	if (ppage) {
		if (is_large) {
			if (!mdevice->free_folios) {
				ret = -ENOMEM;
				goto err_unlock;
			}
			*ppage = folio_page(mdevice->free_folios, 0);
			mdevice->free_folios = (*ppage)->zone_device_data;
			mdevice->calloc += HPAGE_PMD_NR;
		} else if (mdevice->free_pages) {
			*ppage = mdevice->free_pages;
			mdevice->free_pages = (*ppage)->zone_device_data;
			mdevice->calloc++;
		} else {
			ret = -ENOMEM;
			goto err_unlock;
		}
	}
	/* Also reached on success: the label only drops the spinlock. */
err_unlock:
	spin_unlock(&mdevice->lock);

	return ret;

err_release:
	mutex_unlock(&mdevice->devmem_lock);
	/* Only the private type reserved a memory region above. */
	if (res && devmem->pagemap.type == MEMORY_DEVICE_PRIVATE)
		release_mem_region(devmem->pagemap.range.start,
				   range_len(&devmem->pagemap.range));
err_devmem:
	kfree(devmem);

	return ret;
}
689 
/*
 * Allocate one device page, or the head page of a HPAGE_PMD_ORDER device
 * folio when @is_large is set, for the device behind @dmirror.
 *
 * The page is taken from the matching free list; if that list is empty a
 * new chunk is added with dmirror_allocate_chunk(), which also hands back
 * a page. For the private device type a system memory folio of the same
 * order is allocated as backing storage and stored in the device page's
 * zone_device_data.
 *
 * Returns the device page, or NULL if either allocation fails.
 */
static struct page *dmirror_devmem_alloc_page(struct dmirror *dmirror,
					      bool is_large)
{
	struct page *dpage = NULL;
	struct page *rpage = NULL;
	unsigned int order = is_large ? HPAGE_PMD_ORDER : 0;
	struct dmirror_device *mdevice = dmirror->mdevice;

	/*
	 * For ZONE_DEVICE private type, this is a fake device so we allocate
	 * real system memory to store our device memory.
	 * For ZONE_DEVICE coherent type we use the actual dpage to store the
	 * data and ignore rpage.
	 */
	if (dmirror_is_private_zone(mdevice)) {
		rpage = folio_page(folio_alloc(GFP_HIGHUSER, order), 0);
		if (!rpage)
			return NULL;
	}
	spin_lock(&mdevice->lock);

	if (is_large && mdevice->free_folios) {
		/* Pop the head of the large folio free list. */
		dpage = folio_page(mdevice->free_folios, 0);
		mdevice->free_folios = dpage->zone_device_data;
		mdevice->calloc += 1 << order;
		spin_unlock(&mdevice->lock);
	} else if (!is_large && mdevice->free_pages) {
		/* Pop the head of the single page free list. */
		dpage = mdevice->free_pages;
		mdevice->free_pages = dpage->zone_device_data;
		mdevice->calloc++;
		spin_unlock(&mdevice->lock);
	} else {
		/* Free list empty: drop the lock before adding a new chunk. */
		spin_unlock(&mdevice->lock);
		if (dmirror_allocate_chunk(mdevice, &dpage, is_large))
			goto error;
	}

	zone_device_folio_init(page_folio(dpage),
			       page_pgmap(folio_page(page_folio(dpage), 0)),
			       order);
	/* NULL for the coherent type, where no backing page was allocated. */
	dpage->zone_device_data = rpage;
	return dpage;

error:
	if (rpage)
		__free_pages(rpage, order);
	return NULL;
}
738 
/*
 * Allocate device pages for the entries of @args selected for migration
 * and copy (or clear, when there is no source page) their contents into
 * the backing pages.
 *
 * For every src entry flagged MIGRATE_PFN_MIGRATE a destination device
 * page is allocated and its migrate pfn is stored in the matching dst
 * entry. Compound entries get a large device folio when one can be
 * allocated; otherwise the code falls back to one small device page per
 * subpage. src, dst and addr always advance together.
 */
static void dmirror_migrate_alloc_and_copy(struct migrate_vma *args,
					   struct dmirror *dmirror)
{
	const unsigned long *src = args->src;
	unsigned long *dst = args->dst;
	unsigned long addr;

	for (addr = args->start; addr < args->end; ) {
		struct page *spage;
		struct page *dpage;
		struct page *rpage;
		bool is_large = *src & MIGRATE_PFN_COMPOUND;
		int write = (*src & MIGRATE_PFN_WRITE) ? MIGRATE_PFN_WRITE : 0;
		unsigned long nr = 1;

		if (!(*src & MIGRATE_PFN_MIGRATE))
			goto next;

		/*
		 * Note that spage might be NULL which is OK since it is an
		 * unallocated pte_none() or read-only zero page.
		 */
		spage = migrate_pfn_to_page(*src);
		if (WARN(spage && is_zone_device_page(spage),
		     "page already in device spage pfn: 0x%lx\n",
		     page_to_pfn(spage)))
			goto next;

		/* One-shot failure injection: the flag is cleared once consumed. */
		if (dmirror->flags & HMM_DMIRROR_FLAG_FAIL_ALLOC) {
			dmirror->flags &= ~HMM_DMIRROR_FLAG_FAIL_ALLOC;
			dpage = NULL;
		} else
			dpage = dmirror_devmem_alloc_page(dmirror, is_large);

		if (!dpage) {
			struct folio *folio;
			unsigned long i;
			unsigned long spfn = *src >> MIGRATE_PFN_SHIFT;
			struct page *src_page;

			/* A failed small allocation just skips this entry. */
			if (!is_large)
				goto next;

			if (!spage && is_large) {
				nr = HPAGE_PMD_NR;
			} else {
				folio = page_folio(spage);
				nr = folio_nr_pages(folio);
			}

			/* Large allocation failed: migrate the range page by page. */
			for (i = 0; i < nr && addr < args->end; i++) {
				/*
				 * NOTE(review): dmirror_devmem_alloc_page()
				 * can return NULL but the result is used
				 * unchecked here - confirm whether a failure
				 * needs handling.
				 */
				dpage = dmirror_devmem_alloc_page(dmirror, false);
				rpage = BACKING_PAGE(dpage);
				rpage->zone_device_data = dmirror;

				*dst = migrate_pfn(page_to_pfn(dpage)) | write;
				src_page = pfn_to_page(spfn + i);

				if (spage)
					copy_highpage(rpage, src_page);
				else
					clear_highpage(rpage);
				src++;
				dst++;
				addr += PAGE_SIZE;
			}
			continue;
		}

		rpage = BACKING_PAGE(dpage);

		/*
		 * Normally, a device would use the page->zone_device_data to
		 * point to the mirror but here we use it to hold the page for
		 * the simulated device memory and that page holds the pointer
		 * to the mirror.
		 */
		rpage->zone_device_data = dmirror;

		/*
		 * NOTE(review): spage may be NULL at this point (see above)
		 * and is still passed to page_to_pfn() - confirm this is
		 * harmless on all memory models.
		 */
		pr_debug("migrating from sys to dev pfn src: 0x%lx pfn dst: 0x%lx\n",
			 page_to_pfn(spage), page_to_pfn(dpage));

		*dst = migrate_pfn(page_to_pfn(dpage)) | write;

		if (is_large) {
			int i;
			struct folio *folio = page_folio(dpage);
			*dst |= MIGRATE_PFN_COMPOUND;

			if (folio_test_large(folio)) {
				/* Copy or clear every subpage, stepping src/dst/addr with it. */
				for (i = 0; i < folio_nr_pages(folio); i++) {
					struct page *dst_page =
						pfn_to_page(page_to_pfn(rpage) + i);
					struct page *src_page =
						pfn_to_page(page_to_pfn(spage) + i);

					if (spage)
						copy_highpage(dst_page, src_page);
					else
						clear_highpage(dst_page);
					src++;
					dst++;
					addr += PAGE_SIZE;
				}
				continue;
			}
		}

		if (spage)
			copy_highpage(rpage, spage);
		else
			clear_highpage(rpage);

next:
		src++;
		dst++;
		addr += PAGE_SIZE;
	}
}
858 
859 static int dmirror_check_atomic(struct dmirror *dmirror, unsigned long start,
860 			     unsigned long end)
861 {
862 	unsigned long pfn;
863 
864 	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
865 		void *entry;
866 
867 		entry = xa_load(&dmirror->pt, pfn);
868 		if (xa_pointer_tag(entry) == DPT_XA_TAG_ATOMIC)
869 			return -EPERM;
870 	}
871 
872 	return 0;
873 }
874 
875 static int dmirror_atomic_map(unsigned long addr, struct page *page,
876 		struct dmirror *dmirror)
877 {
878 	void *entry;
879 
880 	/* Map the migrated pages into the device's page tables. */
881 	mutex_lock(&dmirror->mutex);
882 
883 	entry = xa_tag_pointer(page, DPT_XA_TAG_ATOMIC);
884 	entry = xa_store(&dmirror->pt, addr >> PAGE_SHIFT, entry, GFP_ATOMIC);
885 	if (xa_is_err(entry)) {
886 		mutex_unlock(&dmirror->mutex);
887 		return xa_err(entry);
888 	}
889 
890 	mutex_unlock(&dmirror->mutex);
891 	return 0;
892 }
893 
894 static int dmirror_migrate_finalize_and_map(struct migrate_vma *args,
895 					    struct dmirror *dmirror)
896 {
897 	unsigned long start = args->start;
898 	unsigned long end = args->end;
899 	const unsigned long *src = args->src;
900 	const unsigned long *dst = args->dst;
901 	unsigned long pfn;
902 	const unsigned long start_pfn = start >> PAGE_SHIFT;
903 	const unsigned long end_pfn = end >> PAGE_SHIFT;
904 
905 	/* Map the migrated pages into the device's page tables. */
906 	mutex_lock(&dmirror->mutex);
907 
908 	for (pfn = start_pfn; pfn < end_pfn; pfn++, src++, dst++) {
909 		struct page *dpage;
910 		void *entry;
911 		int nr, i;
912 		struct page *rpage;
913 
914 		if (!(*src & MIGRATE_PFN_MIGRATE))
915 			continue;
916 
917 		dpage = migrate_pfn_to_page(*dst);
918 		if (!dpage)
919 			continue;
920 
921 		if (*dst & MIGRATE_PFN_COMPOUND)
922 			nr = folio_nr_pages(page_folio(dpage));
923 		else
924 			nr = 1;
925 
926 		WARN_ON_ONCE(end_pfn < start_pfn + nr);
927 
928 		rpage = BACKING_PAGE(dpage);
929 		VM_WARN_ON(folio_nr_pages(page_folio(rpage)) != nr);
930 
931 		for (i = 0; i < nr; i++) {
932 			entry = folio_page(page_folio(rpage), i);
933 			if (*dst & MIGRATE_PFN_WRITE)
934 				entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
935 			entry = xa_store(&dmirror->pt, pfn + i, entry, GFP_ATOMIC);
936 			if (xa_is_err(entry)) {
937 				mutex_unlock(&dmirror->mutex);
938 				return xa_err(entry);
939 			}
940 		}
941 	}
942 
943 	mutex_unlock(&dmirror->mutex);
944 	return 0;
945 }
946 
947 static int dmirror_exclusive(struct dmirror *dmirror,
948 			     struct hmm_dmirror_cmd *cmd)
949 {
950 	unsigned long start, end, addr;
951 	unsigned long size = cmd->npages << PAGE_SHIFT;
952 	struct mm_struct *mm = dmirror->notifier.mm;
953 	struct dmirror_bounce bounce;
954 	int ret = 0;
955 
956 	start = cmd->addr;
957 	end = start + size;
958 	if (end < start)
959 		return -EINVAL;
960 
961 	/* Since the mm is for the mirrored process, get a reference first. */
962 	if (!mmget_not_zero(mm))
963 		return -EINVAL;
964 
965 	mmap_read_lock(mm);
966 	for (addr = start; !ret && addr < end; addr += PAGE_SIZE) {
967 		struct folio *folio;
968 		struct page *page;
969 
970 		page = make_device_exclusive(mm, addr, NULL, &folio);
971 		if (IS_ERR(page)) {
972 			ret = PTR_ERR(page);
973 			break;
974 		}
975 
976 		ret = dmirror_atomic_map(addr, page, dmirror);
977 		folio_unlock(folio);
978 		folio_put(folio);
979 	}
980 	mmap_read_unlock(mm);
981 	mmput(mm);
982 
983 	if (ret)
984 		return ret;
985 
986 	/* Return the migrated data for verification. */
987 	ret = dmirror_bounce_init(&bounce, start, size);
988 	if (ret)
989 		return ret;
990 	mutex_lock(&dmirror->mutex);
991 	ret = dmirror_do_read(dmirror, start, end, &bounce);
992 	mutex_unlock(&dmirror->mutex);
993 	if (ret == 0) {
994 		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
995 				 bounce.size))
996 			ret = -EFAULT;
997 	}
998 
999 	cmd->cpages = bounce.cpages;
1000 	dmirror_bounce_fini(&bounce);
1001 	return ret;
1002 }
1003 
1004 static vm_fault_t dmirror_devmem_fault_alloc_and_copy(struct migrate_vma *args,
1005 						      struct dmirror *dmirror)
1006 {
1007 	const unsigned long *src = args->src;
1008 	unsigned long *dst = args->dst;
1009 	unsigned long start = args->start;
1010 	unsigned long end = args->end;
1011 	unsigned long addr;
1012 	unsigned int order = 0;
1013 	int i;
1014 
1015 	for (addr = start; addr < end; ) {
1016 		struct page *dpage, *spage;
1017 
1018 		spage = migrate_pfn_to_page(*src);
1019 		if (!spage || !(*src & MIGRATE_PFN_MIGRATE)) {
1020 			addr += PAGE_SIZE;
1021 			goto next;
1022 		}
1023 
1024 		if (WARN_ON(!is_device_private_page(spage) &&
1025 			    !is_device_coherent_page(spage))) {
1026 			addr += PAGE_SIZE;
1027 			goto next;
1028 		}
1029 
1030 		spage = BACKING_PAGE(spage);
1031 		order = folio_order(page_folio(spage));
1032 		if (order)
1033 			*dst = MIGRATE_PFN_COMPOUND;
1034 		if (*src & MIGRATE_PFN_WRITE)
1035 			*dst |= MIGRATE_PFN_WRITE;
1036 
1037 		if (dmirror->flags & HMM_DMIRROR_FLAG_FAIL_ALLOC) {
1038 			dmirror->flags &= ~HMM_DMIRROR_FLAG_FAIL_ALLOC;
1039 			*dst &= ~MIGRATE_PFN_COMPOUND;
1040 			dpage = NULL;
1041 		} else if (order) {
1042 			dpage = folio_page(vma_alloc_folio(GFP_HIGHUSER_MOVABLE,
1043 						order, args->vma, addr), 0);
1044 		} else {
1045 			dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
1046 		}
1047 
1048 		if (!dpage && !order)
1049 			return VM_FAULT_OOM;
1050 
1051 		pr_debug("migrating from sys to dev pfn src: 0x%lx pfn dst: 0x%lx\n",
1052 				page_to_pfn(spage), page_to_pfn(dpage));
1053 
1054 		if (dpage) {
1055 			lock_page(dpage);
1056 			*dst |= migrate_pfn(page_to_pfn(dpage));
1057 		}
1058 
1059 		for (i = 0; i < (1 << order); i++) {
1060 			struct page *src_page;
1061 			struct page *dst_page;
1062 
1063 			/* Try with smaller pages if large allocation fails */
1064 			if (!dpage && order) {
1065 				dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
1066 				if (!dpage) {
1067 					/* Unlock and free pages already allocated. */
1068 					while (i > 0) {
1069 						struct page *fpage;
1070 
1071 						fpage = migrate_pfn_to_page(dst[--i]);
1072 						unlock_page(fpage);
1073 						__free_page(fpage);
1074 					}
1075 					/* Clear remaining dst entries to avoid
1076 					 * migrate_vma_pages/finalize() using
1077 					 * uninitialized values.
1078 					 */
1079 					while (i < (1 << order)) {
1080 						dst[i] = 0;
1081 						i++;
1082 					}
1083 					return VM_FAULT_OOM;
1084 				}
1085 				lock_page(dpage);
1086 				dst[i] = migrate_pfn(page_to_pfn(dpage));
1087 				dst_page = pfn_to_page(page_to_pfn(dpage));
1088 				dpage = NULL; /* For the next iteration */
1089 			} else {
1090 				dst_page = pfn_to_page(page_to_pfn(dpage) + i);
1091 			}
1092 
1093 			src_page = pfn_to_page(page_to_pfn(spage) + i);
1094 
1095 			xa_erase(&dmirror->pt, addr >> PAGE_SHIFT);
1096 			addr += PAGE_SIZE;
1097 			copy_highpage(dst_page, src_page);
1098 		}
1099 next:
1100 		src += 1 << order;
1101 		dst += 1 << order;
1102 	}
1103 	return 0;
1104 }
1105 
1106 static unsigned long
1107 dmirror_successful_migrated_pages(struct migrate_vma *migrate)
1108 {
1109 	unsigned long cpages = 0;
1110 	unsigned long i;
1111 
1112 	for (i = 0; i < migrate->npages; i++) {
1113 		if (migrate->src[i] & MIGRATE_PFN_VALID &&
1114 		    migrate->src[i] & MIGRATE_PFN_MIGRATE)
1115 			cpages++;
1116 	}
1117 	return cpages;
1118 }
1119 
1120 static int dmirror_migrate_to_system(struct dmirror *dmirror,
1121 				     struct hmm_dmirror_cmd *cmd)
1122 {
1123 	unsigned long start, end, addr;
1124 	unsigned long size = cmd->npages << PAGE_SHIFT;
1125 	struct mm_struct *mm = dmirror->notifier.mm;
1126 	struct vm_area_struct *vma;
1127 	struct migrate_vma args = { 0 };
1128 	unsigned long next;
1129 	int ret;
1130 	unsigned long *src_pfns;
1131 	unsigned long *dst_pfns;
1132 
1133 	start = cmd->addr;
1134 	end = start + size;
1135 	if (end < start)
1136 		return -EINVAL;
1137 
1138 	/* Since the mm is for the mirrored process, get a reference first. */
1139 	if (!mmget_not_zero(mm))
1140 		return -EINVAL;
1141 
1142 	src_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*src_pfns), GFP_KERNEL | __GFP_NOFAIL);
1143 	dst_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*dst_pfns), GFP_KERNEL | __GFP_NOFAIL);
1144 
1145 	cmd->cpages = 0;
1146 	mmap_read_lock(mm);
1147 	for (addr = start; addr < end; addr = next) {
1148 		vma = vma_lookup(mm, addr);
1149 		if (!vma || !(vma->vm_flags & VM_READ)) {
1150 			ret = -EINVAL;
1151 			goto out;
1152 		}
1153 		next = min(end, addr + (PTRS_PER_PTE << PAGE_SHIFT));
1154 		if (next > vma->vm_end)
1155 			next = vma->vm_end;
1156 
1157 		args.vma = vma;
1158 		args.src = src_pfns;
1159 		args.dst = dst_pfns;
1160 		args.start = addr;
1161 		args.end = next;
1162 		args.pgmap_owner = dmirror->mdevice;
1163 		args.flags = dmirror_select_device(dmirror) | MIGRATE_VMA_SELECT_COMPOUND;
1164 
1165 		ret = migrate_vma_setup(&args);
1166 		if (ret)
1167 			goto out;
1168 
1169 		pr_debug("Migrating from device mem to sys mem\n");
1170 		if (dmirror_devmem_fault_alloc_and_copy(&args, dmirror)) {
1171 			migrate_vma_finalize(&args);
1172 			ret = -ENOMEM;
1173 			goto out;
1174 		}
1175 
1176 		migrate_vma_pages(&args);
1177 		cmd->cpages += dmirror_successful_migrated_pages(&args);
1178 		migrate_vma_finalize(&args);
1179 	}
1180 out:
1181 	mmap_read_unlock(mm);
1182 	mmput(mm);
1183 	kvfree(src_pfns);
1184 	kvfree(dst_pfns);
1185 
1186 	return ret;
1187 }
1188 
1189 static int dmirror_migrate_to_device(struct dmirror *dmirror,
1190 				struct hmm_dmirror_cmd *cmd)
1191 {
1192 	unsigned long start, end, addr;
1193 	unsigned long size = cmd->npages << PAGE_SHIFT;
1194 	struct mm_struct *mm = dmirror->notifier.mm;
1195 	struct vm_area_struct *vma;
1196 	struct dmirror_bounce bounce;
1197 	struct migrate_vma args = { 0 };
1198 	unsigned long next;
1199 	int ret;
1200 	unsigned long *src_pfns = NULL;
1201 	unsigned long *dst_pfns = NULL;
1202 
1203 	start = cmd->addr;
1204 	end = start + size;
1205 	if (end < start)
1206 		return -EINVAL;
1207 
1208 	/* Since the mm is for the mirrored process, get a reference first. */
1209 	if (!mmget_not_zero(mm))
1210 		return -EINVAL;
1211 
1212 	ret = -ENOMEM;
1213 	src_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*src_pfns),
1214 			  GFP_KERNEL | __GFP_NOFAIL);
1215 	if (!src_pfns)
1216 		goto free_mem;
1217 
1218 	dst_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*dst_pfns),
1219 			  GFP_KERNEL | __GFP_NOFAIL);
1220 	if (!dst_pfns)
1221 		goto free_mem;
1222 
1223 	ret = 0;
1224 	mmap_read_lock(mm);
1225 	for (addr = start; addr < end; addr = next) {
1226 		vma = vma_lookup(mm, addr);
1227 		if (!vma || !(vma->vm_flags & VM_READ)) {
1228 			ret = -EINVAL;
1229 			goto out;
1230 		}
1231 		next = min(end, addr + (PTRS_PER_PTE << PAGE_SHIFT));
1232 		if (next > vma->vm_end)
1233 			next = vma->vm_end;
1234 
1235 		args.vma = vma;
1236 		args.src = src_pfns;
1237 		args.dst = dst_pfns;
1238 		args.start = addr;
1239 		args.end = next;
1240 		args.pgmap_owner = dmirror->mdevice;
1241 		args.flags = MIGRATE_VMA_SELECT_SYSTEM |
1242 				MIGRATE_VMA_SELECT_COMPOUND;
1243 		ret = migrate_vma_setup(&args);
1244 		if (ret)
1245 			goto out;
1246 
1247 		pr_debug("Migrating from sys mem to device mem\n");
1248 		dmirror_migrate_alloc_and_copy(&args, dmirror);
1249 		migrate_vma_pages(&args);
1250 		dmirror_migrate_finalize_and_map(&args, dmirror);
1251 		migrate_vma_finalize(&args);
1252 	}
1253 	mmap_read_unlock(mm);
1254 	mmput(mm);
1255 
1256 	/*
1257 	 * Return the migrated data for verification.
1258 	 * Only for pages in device zone
1259 	 */
1260 	ret = dmirror_bounce_init(&bounce, start, size);
1261 	if (ret)
1262 		goto free_mem;
1263 	mutex_lock(&dmirror->mutex);
1264 	ret = dmirror_do_read(dmirror, start, end, &bounce);
1265 	mutex_unlock(&dmirror->mutex);
1266 	if (ret == 0) {
1267 		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
1268 				 bounce.size))
1269 			ret = -EFAULT;
1270 	}
1271 	cmd->cpages = bounce.cpages;
1272 	dmirror_bounce_fini(&bounce);
1273 	goto free_mem;
1274 
1275 out:
1276 	mmap_read_unlock(mm);
1277 	mmput(mm);
1278 free_mem:
1279 	kvfree(src_pfns);
1280 	kvfree(dst_pfns);
1281 	return ret;
1282 }
1283 
1284 static void dmirror_mkentry(struct dmirror *dmirror, struct hmm_range *range,
1285 			    unsigned char *perm, unsigned long entry)
1286 {
1287 	struct page *page;
1288 
1289 	if (entry & HMM_PFN_ERROR) {
1290 		*perm = HMM_DMIRROR_PROT_ERROR;
1291 		return;
1292 	}
1293 	if (!(entry & HMM_PFN_VALID)) {
1294 		*perm = HMM_DMIRROR_PROT_NONE;
1295 		return;
1296 	}
1297 
1298 	page = hmm_pfn_to_page(entry);
1299 	if (is_device_private_page(page)) {
1300 		/* Is the page migrated to this device or some other? */
1301 		if (dmirror->mdevice == dmirror_page_to_device(page))
1302 			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL;
1303 		else
1304 			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_REMOTE;
1305 	} else if (is_device_coherent_page(page)) {
1306 		/* Is the page migrated to this device or some other? */
1307 		if (dmirror->mdevice == dmirror_page_to_device(page))
1308 			*perm = HMM_DMIRROR_PROT_DEV_COHERENT_LOCAL;
1309 		else
1310 			*perm = HMM_DMIRROR_PROT_DEV_COHERENT_REMOTE;
1311 	} else if (is_zero_pfn(page_to_pfn(page)))
1312 		*perm = HMM_DMIRROR_PROT_ZERO;
1313 	else
1314 		*perm = HMM_DMIRROR_PROT_NONE;
1315 	if (entry & HMM_PFN_WRITE)
1316 		*perm |= HMM_DMIRROR_PROT_WRITE;
1317 	else
1318 		*perm |= HMM_DMIRROR_PROT_READ;
1319 	if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PMD_SHIFT)
1320 		*perm |= HMM_DMIRROR_PROT_PMD;
1321 	else if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PUD_SHIFT)
1322 		*perm |= HMM_DMIRROR_PROT_PUD;
1323 }
1324 
1325 static bool dmirror_snapshot_invalidate(struct mmu_interval_notifier *mni,
1326 				const struct mmu_notifier_range *range,
1327 				unsigned long cur_seq)
1328 {
1329 	struct dmirror_interval *dmi =
1330 		container_of(mni, struct dmirror_interval, notifier);
1331 	struct dmirror *dmirror = dmi->dmirror;
1332 
1333 	if (mmu_notifier_range_blockable(range))
1334 		mutex_lock(&dmirror->mutex);
1335 	else if (!mutex_trylock(&dmirror->mutex))
1336 		return false;
1337 
1338 	/*
1339 	 * Snapshots only need to set the sequence number since any
1340 	 * invalidation in the interval invalidates the whole snapshot.
1341 	 */
1342 	mmu_interval_set_seq(mni, cur_seq);
1343 
1344 	mutex_unlock(&dmirror->mutex);
1345 	return true;
1346 }
1347 
1348 static const struct mmu_interval_notifier_ops dmirror_mrn_ops = {
1349 	.invalidate = dmirror_snapshot_invalidate,
1350 };
1351 
1352 static int dmirror_range_snapshot(struct dmirror *dmirror,
1353 				  struct hmm_range *range,
1354 				  unsigned char *perm)
1355 {
1356 	struct mm_struct *mm = dmirror->notifier.mm;
1357 	struct dmirror_interval notifier;
1358 	unsigned long timeout =
1359 		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
1360 	unsigned long i;
1361 	unsigned long n;
1362 	int ret = 0;
1363 
1364 	notifier.dmirror = dmirror;
1365 	range->notifier = &notifier.notifier;
1366 
1367 	ret = mmu_interval_notifier_insert(range->notifier, mm,
1368 			range->start, range->end - range->start,
1369 			&dmirror_mrn_ops);
1370 	if (ret)
1371 		return ret;
1372 
1373 	while (true) {
1374 		if (time_after(jiffies, timeout)) {
1375 			ret = -EBUSY;
1376 			goto out;
1377 		}
1378 
1379 		range->notifier_seq = mmu_interval_read_begin(range->notifier);
1380 
1381 		mmap_read_lock(mm);
1382 		ret = hmm_range_fault(range);
1383 		mmap_read_unlock(mm);
1384 		if (ret) {
1385 			if (ret == -EBUSY)
1386 				continue;
1387 			goto out;
1388 		}
1389 
1390 		mutex_lock(&dmirror->mutex);
1391 		if (mmu_interval_read_retry(range->notifier,
1392 					    range->notifier_seq)) {
1393 			mutex_unlock(&dmirror->mutex);
1394 			continue;
1395 		}
1396 		break;
1397 	}
1398 
1399 	n = (range->end - range->start) >> PAGE_SHIFT;
1400 	for (i = 0; i < n; i++)
1401 		dmirror_mkentry(dmirror, range, perm + i, range->hmm_pfns[i]);
1402 
1403 	mutex_unlock(&dmirror->mutex);
1404 out:
1405 	mmu_interval_notifier_remove(range->notifier);
1406 	return ret;
1407 }
1408 
1409 static int dmirror_snapshot(struct dmirror *dmirror,
1410 			    struct hmm_dmirror_cmd *cmd)
1411 {
1412 	struct mm_struct *mm = dmirror->notifier.mm;
1413 	unsigned long start, end;
1414 	unsigned long size = cmd->npages << PAGE_SHIFT;
1415 	unsigned long addr;
1416 	unsigned long next;
1417 	unsigned long pfns[32];
1418 	unsigned char perm[32];
1419 	char __user *uptr;
1420 	struct hmm_range range = {
1421 		.hmm_pfns = pfns,
1422 		.dev_private_owner = dmirror->mdevice,
1423 	};
1424 	int ret = 0;
1425 
1426 	start = cmd->addr;
1427 	end = start + size;
1428 	if (end < start)
1429 		return -EINVAL;
1430 
1431 	/* Since the mm is for the mirrored process, get a reference first. */
1432 	if (!mmget_not_zero(mm))
1433 		return -EINVAL;
1434 
1435 	/*
1436 	 * Register a temporary notifier to detect invalidations even if it
1437 	 * overlaps with other mmu_interval_notifiers.
1438 	 */
1439 	uptr = u64_to_user_ptr(cmd->ptr);
1440 	for (addr = start; addr < end; addr = next) {
1441 		unsigned long n;
1442 
1443 		next = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
1444 		range.start = addr;
1445 		range.end = next;
1446 
1447 		ret = dmirror_range_snapshot(dmirror, &range, perm);
1448 		if (ret)
1449 			break;
1450 
1451 		n = (range.end - range.start) >> PAGE_SHIFT;
1452 		if (copy_to_user(uptr, perm, n)) {
1453 			ret = -EFAULT;
1454 			break;
1455 		}
1456 
1457 		cmd->cpages += n;
1458 		uptr += n;
1459 	}
1460 	mmput(mm);
1461 
1462 	return ret;
1463 }
1464 
1465 /* Removes free pages from the free list so they can't be re-allocated */
1466 static void dmirror_remove_free_pages(struct dmirror_chunk *devmem)
1467 {
1468 	struct dmirror_device *mdevice = devmem->mdevice;
1469 	struct page *page;
1470 	struct folio *folio;
1471 
1472 
1473 	for (folio = mdevice->free_folios; folio; folio = folio_zone_device_data(folio))
1474 		if (dmirror_page_to_chunk(folio_page(folio, 0)) == devmem)
1475 			mdevice->free_folios = folio_zone_device_data(folio);
1476 	for (page = mdevice->free_pages; page; page = page->zone_device_data)
1477 		if (dmirror_page_to_chunk(page) == devmem)
1478 			mdevice->free_pages = page->zone_device_data;
1479 }
1480 
1481 static void dmirror_device_remove_chunks(struct dmirror_device *mdevice)
1482 {
1483 	unsigned int i;
1484 
1485 	mutex_lock(&mdevice->devmem_lock);
1486 	if (mdevice->devmem_chunks) {
1487 		for (i = 0; i < mdevice->devmem_count; i++) {
1488 			struct dmirror_chunk *devmem =
1489 				mdevice->devmem_chunks[i];
1490 
1491 			spin_lock(&mdevice->lock);
1492 			devmem->remove = true;
1493 			dmirror_remove_free_pages(devmem);
1494 			spin_unlock(&mdevice->lock);
1495 
1496 			dmirror_device_evict_chunk(devmem);
1497 			memunmap_pages(&devmem->pagemap);
1498 			if (devmem->pagemap.type == MEMORY_DEVICE_PRIVATE)
1499 				release_mem_region(devmem->pagemap.range.start,
1500 						   range_len(&devmem->pagemap.range));
1501 			kfree(devmem);
1502 		}
1503 		mdevice->devmem_count = 0;
1504 		mdevice->devmem_capacity = 0;
1505 		mdevice->free_pages = NULL;
1506 		mdevice->free_folios = NULL;
1507 		kfree(mdevice->devmem_chunks);
1508 		mdevice->devmem_chunks = NULL;
1509 	}
1510 	mutex_unlock(&mdevice->devmem_lock);
1511 }
1512 
1513 static long dmirror_fops_unlocked_ioctl(struct file *filp,
1514 					unsigned int command,
1515 					unsigned long arg)
1516 {
1517 	void __user *uarg = (void __user *)arg;
1518 	struct hmm_dmirror_cmd cmd;
1519 	struct dmirror *dmirror;
1520 	int ret;
1521 
1522 	dmirror = filp->private_data;
1523 	if (!dmirror)
1524 		return -EINVAL;
1525 
1526 	if (copy_from_user(&cmd, uarg, sizeof(cmd)))
1527 		return -EFAULT;
1528 
1529 	if (cmd.addr & ~PAGE_MASK)
1530 		return -EINVAL;
1531 	if (cmd.addr >= (cmd.addr + (cmd.npages << PAGE_SHIFT)))
1532 		return -EINVAL;
1533 
1534 	cmd.cpages = 0;
1535 	cmd.faults = 0;
1536 
1537 	switch (command) {
1538 	case HMM_DMIRROR_READ:
1539 		ret = dmirror_read(dmirror, &cmd);
1540 		break;
1541 
1542 	case HMM_DMIRROR_WRITE:
1543 		ret = dmirror_write(dmirror, &cmd);
1544 		break;
1545 
1546 	case HMM_DMIRROR_MIGRATE_TO_DEV:
1547 		ret = dmirror_migrate_to_device(dmirror, &cmd);
1548 		break;
1549 
1550 	case HMM_DMIRROR_MIGRATE_TO_SYS:
1551 		ret = dmirror_migrate_to_system(dmirror, &cmd);
1552 		break;
1553 
1554 	case HMM_DMIRROR_EXCLUSIVE:
1555 		ret = dmirror_exclusive(dmirror, &cmd);
1556 		break;
1557 
1558 	case HMM_DMIRROR_CHECK_EXCLUSIVE:
1559 		ret = dmirror_check_atomic(dmirror, cmd.addr,
1560 					cmd.addr + (cmd.npages << PAGE_SHIFT));
1561 		break;
1562 
1563 	case HMM_DMIRROR_SNAPSHOT:
1564 		ret = dmirror_snapshot(dmirror, &cmd);
1565 		break;
1566 
1567 	case HMM_DMIRROR_RELEASE:
1568 		dmirror_device_remove_chunks(dmirror->mdevice);
1569 		ret = 0;
1570 		break;
1571 	case HMM_DMIRROR_FLAGS:
1572 		dmirror->flags = cmd.npages;
1573 		ret = 0;
1574 		break;
1575 
1576 	default:
1577 		return -EINVAL;
1578 	}
1579 	if (ret)
1580 		return ret;
1581 
1582 	if (copy_to_user(uarg, &cmd, sizeof(cmd)))
1583 		return -EFAULT;
1584 
1585 	return 0;
1586 }
1587 
1588 static int dmirror_fops_mmap(struct file *file, struct vm_area_struct *vma)
1589 {
1590 	unsigned long addr;
1591 
1592 	for (addr = vma->vm_start; addr < vma->vm_end; addr += PAGE_SIZE) {
1593 		struct page *page;
1594 		int ret;
1595 
1596 		page = alloc_page(GFP_KERNEL | __GFP_ZERO);
1597 		if (!page)
1598 			return -ENOMEM;
1599 
1600 		ret = vm_insert_page(vma, addr, page);
1601 		if (ret) {
1602 			__free_page(page);
1603 			return ret;
1604 		}
1605 		put_page(page);
1606 	}
1607 
1608 	return 0;
1609 }
1610 
1611 static const struct file_operations dmirror_fops = {
1612 	.open		= dmirror_fops_open,
1613 	.release	= dmirror_fops_release,
1614 	.mmap		= dmirror_fops_mmap,
1615 	.unlocked_ioctl = dmirror_fops_unlocked_ioctl,
1616 	.llseek		= default_llseek,
1617 	.owner		= THIS_MODULE,
1618 };
1619 
1620 static void dmirror_devmem_free(struct folio *folio)
1621 {
1622 	struct page *page = &folio->page;
1623 	struct page *rpage = BACKING_PAGE(page);
1624 	struct dmirror_device *mdevice;
1625 	struct folio *rfolio = page_folio(rpage);
1626 	unsigned int order = folio_order(rfolio);
1627 
1628 	if (rpage != page) {
1629 		if (order)
1630 			__free_pages(rpage, order);
1631 		else
1632 			__free_page(rpage);
1633 		rpage = NULL;
1634 	}
1635 
1636 	mdevice = dmirror_page_to_device(page);
1637 	spin_lock(&mdevice->lock);
1638 
1639 	/* Return page to our allocator if not freeing the chunk */
1640 	if (!dmirror_page_to_chunk(page)->remove) {
1641 		mdevice->cfree += 1 << order;
1642 		if (order) {
1643 			page->zone_device_data = mdevice->free_folios;
1644 			mdevice->free_folios = page_folio(page);
1645 		} else {
1646 			page->zone_device_data = mdevice->free_pages;
1647 			mdevice->free_pages = page;
1648 		}
1649 	}
1650 	spin_unlock(&mdevice->lock);
1651 }
1652 
1653 static vm_fault_t dmirror_devmem_fault(struct vm_fault *vmf)
1654 {
1655 	struct migrate_vma args = { 0 };
1656 	struct page *rpage;
1657 	struct dmirror *dmirror;
1658 	vm_fault_t ret = 0;
1659 	unsigned int order, nr;
1660 
1661 	/*
1662 	 * Normally, a device would use the page->zone_device_data to point to
1663 	 * the mirror but here we use it to hold the page for the simulated
1664 	 * device memory and that page holds the pointer to the mirror.
1665 	 */
1666 	rpage = folio_zone_device_data(page_folio(vmf->page));
1667 	dmirror = rpage->zone_device_data;
1668 
1669 	/* FIXME demonstrate how we can adjust migrate range */
1670 	order = folio_order(page_folio(vmf->page));
1671 	nr = 1 << order;
1672 
1673 	/*
1674 	 * When folios are partially mapped, we can't rely on the folio
1675 	 * order of vmf->page as the folio might not be fully split yet
1676 	 */
1677 	if (vmf->pte) {
1678 		order = 0;
1679 		nr = 1;
1680 	}
1681 
1682 	/*
1683 	 * Consider a per-cpu cache of src and dst pfns, but with
1684 	 * large number of cpus that might not scale well.
1685 	 */
1686 	args.start = ALIGN_DOWN(vmf->address, (PAGE_SIZE << order));
1687 	args.vma = vmf->vma;
1688 	args.end = args.start + (PAGE_SIZE << order);
1689 
1690 	nr = (args.end - args.start) >> PAGE_SHIFT;
1691 	args.src = kcalloc(nr, sizeof(unsigned long), GFP_KERNEL);
1692 	args.dst = kcalloc(nr, sizeof(unsigned long), GFP_KERNEL);
1693 	args.pgmap_owner = dmirror->mdevice;
1694 	args.flags = dmirror_select_device(dmirror);
1695 	args.fault_page = vmf->page;
1696 
1697 	if (!args.src || !args.dst) {
1698 		ret = VM_FAULT_OOM;
1699 		goto err;
1700 	}
1701 
1702 	if (order)
1703 		args.flags |= MIGRATE_VMA_SELECT_COMPOUND;
1704 
1705 	/*
1706 	 * In practice migrate_vma_setup() should never fail unless the
1707 	 * test is wrong as it just tests some static VMA properties.
1708 	 */
1709 	if (migrate_vma_setup(&args)) {
1710 		ret = VM_FAULT_SIGBUS;
1711 		goto err;
1712 	}
1713 
1714 	ret = dmirror_devmem_fault_alloc_and_copy(&args, dmirror);
1715 	if (ret) {
1716 		migrate_vma_finalize(&args);
1717 		goto err;
1718 	}
1719 	migrate_vma_pages(&args);
1720 	/*
1721 	 * No device finalize step is needed since
1722 	 * dmirror_devmem_fault_alloc_and_copy() will have already
1723 	 * invalidated the device page table.
1724 	 */
1725 	migrate_vma_finalize(&args);
1726 err:
1727 	kfree(args.src);
1728 	kfree(args.dst);
1729 	return ret;
1730 }
1731 
1732 static void dmirror_devmem_folio_split(struct folio *head, struct folio *tail)
1733 {
1734 	struct page *rpage = BACKING_PAGE(folio_page(head, 0));
1735 	struct page *rpage_tail;
1736 	struct folio *rfolio;
1737 	unsigned long offset = 0;
1738 
1739 	if (!rpage) {
1740 		tail->page.zone_device_data = NULL;
1741 		return;
1742 	}
1743 
1744 	rfolio = page_folio(rpage);
1745 
1746 	if (tail == NULL) {
1747 		folio_reset_order(rfolio);
1748 		rfolio->mapping = NULL;
1749 		folio_set_count(rfolio, 1);
1750 		return;
1751 	}
1752 
1753 	offset = folio_pfn(tail) - folio_pfn(head);
1754 
1755 	rpage_tail = folio_page(rfolio, offset);
1756 	tail->page.zone_device_data = rpage_tail;
1757 	rpage_tail->zone_device_data = rpage->zone_device_data;
1758 	clear_compound_head(rpage_tail);
1759 	rpage_tail->mapping = NULL;
1760 
1761 	folio_page(tail, 0)->mapping = folio_page(head, 0)->mapping;
1762 	tail->pgmap = head->pgmap;
1763 	folio_set_count(page_folio(rpage_tail), 1);
1764 }
1765 
1766 static const struct dev_pagemap_ops dmirror_devmem_ops = {
1767 	.folio_free	= dmirror_devmem_free,
1768 	.migrate_to_ram	= dmirror_devmem_fault,
1769 	.folio_split	= dmirror_devmem_folio_split,
1770 };
1771 
1772 static void dmirror_device_release(struct device *dev)
1773 {
1774 	struct dmirror_device *mdevice = container_of(dev, struct dmirror_device, device);
1775 
1776 	dmirror_device_remove_chunks(mdevice);
1777 }
1778 
1779 static int dmirror_device_init(struct dmirror_device *mdevice, int id)
1780 {
1781 	dev_t dev;
1782 	int ret;
1783 
1784 	dev = MKDEV(MAJOR(dmirror_dev), id);
1785 	mutex_init(&mdevice->devmem_lock);
1786 	spin_lock_init(&mdevice->lock);
1787 
1788 	cdev_init(&mdevice->cdevice, &dmirror_fops);
1789 	mdevice->cdevice.owner = THIS_MODULE;
1790 	mdevice->device.release = dmirror_device_release;
1791 
1792 	device_initialize(&mdevice->device);
1793 	mdevice->device.devt = dev;
1794 
1795 	ret = dev_set_name(&mdevice->device, "hmm_dmirror%u", id);
1796 	if (ret)
1797 		goto put_device;
1798 
1799 	/* Build a list of free ZONE_DEVICE struct pages */
1800 	ret = dmirror_allocate_chunk(mdevice, NULL, false);
1801 	if (ret)
1802 		goto put_device;
1803 
1804 	ret = cdev_device_add(&mdevice->cdevice, &mdevice->device);
1805 	if (ret)
1806 		goto put_device;
1807 
1808 	return 0;
1809 
1810 put_device:
1811 	put_device(&mdevice->device);
1812 	return ret;
1813 }
1814 
1815 static void dmirror_device_remove(struct dmirror_device *mdevice)
1816 {
1817 	cdev_device_del(&mdevice->cdevice, &mdevice->device);
1818 	put_device(&mdevice->device);
1819 }
1820 
1821 static int __init hmm_dmirror_init(void)
1822 {
1823 	int ret;
1824 	int id = 0;
1825 	int ndevices = 0;
1826 
1827 	ret = alloc_chrdev_region(&dmirror_dev, 0, DMIRROR_NDEVICES,
1828 				  "HMM_DMIRROR");
1829 	if (ret)
1830 		goto err_unreg;
1831 
1832 	memset(dmirror_devices, 0, DMIRROR_NDEVICES * sizeof(dmirror_devices[0]));
1833 	dmirror_devices[ndevices++].zone_device_type =
1834 				HMM_DMIRROR_MEMORY_DEVICE_PRIVATE;
1835 	dmirror_devices[ndevices++].zone_device_type =
1836 				HMM_DMIRROR_MEMORY_DEVICE_PRIVATE;
1837 	if (spm_addr_dev0 && spm_addr_dev1) {
1838 		dmirror_devices[ndevices++].zone_device_type =
1839 					HMM_DMIRROR_MEMORY_DEVICE_COHERENT;
1840 		dmirror_devices[ndevices++].zone_device_type =
1841 					HMM_DMIRROR_MEMORY_DEVICE_COHERENT;
1842 	}
1843 	for (id = 0; id < ndevices; id++) {
1844 		ret = dmirror_device_init(dmirror_devices + id, id);
1845 		if (ret)
1846 			goto err_chrdev;
1847 	}
1848 
1849 	pr_info("HMM test module loaded. This is only for testing HMM.\n");
1850 	return 0;
1851 
1852 err_chrdev:
1853 	while (--id >= 0)
1854 		dmirror_device_remove(dmirror_devices + id);
1855 	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
1856 err_unreg:
1857 	return ret;
1858 }
1859 
1860 static void __exit hmm_dmirror_exit(void)
1861 {
1862 	int id;
1863 
1864 	for (id = 0; id < DMIRROR_NDEVICES; id++)
1865 		if (dmirror_devices[id].zone_device_type)
1866 			dmirror_device_remove(dmirror_devices + id);
1867 	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
1868 }
1869 
1870 module_init(hmm_dmirror_init);
1871 module_exit(hmm_dmirror_exit);
1872 MODULE_DESCRIPTION("HMM (Heterogeneous Memory Management) test module");
1873 MODULE_LICENSE("GPL");
1874