xref: /linux/lib/test_hmm.c (revision bcb6058a4b4596f12065276faeb9363dc4887ea9)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * This is a module to test the HMM (Heterogeneous Memory Management)
4  * mirror and zone device private memory migration APIs of the kernel.
5  * Userspace programs can register with the driver to mirror their own address
6  * space and can use the device to read/write any valid virtual address.
7  */
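/*
 * Illustrative sketch (not part of the driver): roughly how a userspace test
 * drives this module through the ioctl interface declared in test_hmm_uapi.h.
 * The device node name and buffer handling are assumptions for the example;
 * see the hmm-tests.c selftest for the real usage.
 *
 *	struct hmm_dmirror_cmd cmd = { 0 };
 *	int fd = open("/dev/hmm_dmirror0", O_RDWR);
 *	char *buf = malloc(4096);
 *
 *	cmd.addr = (uintptr_t)addr;	// page-aligned VA in this process
 *	cmd.ptr = (uintptr_t)buf;	// where the mirrored data is copied
 *	cmd.npages = 1;
 *	ioctl(fd, HMM_DMIRROR_READ, &cmd);
 *	// cmd.cpages and cmd.faults report how many pages were copied and
 *	// how many device faults were needed to do it.
 */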
8 #include <linux/init.h>
9 #include <linux/fs.h>
10 #include <linux/mm.h>
11 #include <linux/module.h>
12 #include <linux/kernel.h>
13 #include <linux/cdev.h>
14 #include <linux/device.h>
15 #include <linux/memremap.h>
16 #include <linux/mutex.h>
17 #include <linux/rwsem.h>
18 #include <linux/sched.h>
19 #include <linux/slab.h>
20 #include <linux/highmem.h>
21 #include <linux/delay.h>
22 #include <linux/pagemap.h>
23 #include <linux/hmm.h>
24 #include <linux/vmalloc.h>
25 #include <linux/swap.h>
26 #include <linux/swapops.h>
27 #include <linux/sched/mm.h>
28 #include <linux/platform_device.h>
29 #include <linux/rmap.h>
30 #include <linux/mmu_notifier.h>
31 #include <linux/migrate.h>
32 
33 #include "test_hmm_uapi.h"
34 
35 #define DMIRROR_NDEVICES		4
36 #define DMIRROR_RANGE_FAULT_TIMEOUT	1000
37 #define DEVMEM_CHUNK_SIZE		(256 * 1024 * 1024U)
38 #define DEVMEM_CHUNKS_RESERVE		16
39 
40 /*
41  * For device_private pages, dpage is just a dummy struct page
42  * representing a piece of device memory. dmirror_devmem_alloc_page
43  * allocates a real system memory page as backing storage to fake a
44  * real device. zone_device_data points to that backing page. But
45  * for device_coherent memory, the struct page represents real
46  * physical CPU-accessible memory that we can use directly.
47  */
48 #define BACKING_PAGE(page) (is_device_private_page((page)) ? \
49 			   (page)->zone_device_data : (page))
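/*
 * Example (sketch) of how BACKING_PAGE() is meant to be used when touching
 * the data behind a simulated device page; "dst_page" is a placeholder:
 *
 *	struct page *rpage = BACKING_PAGE(dpage);
 *
 *	// device private:  rpage is the system page stored in
 *	//                  dpage->zone_device_data
 *	// device coherent: rpage == dpage, the memory is CPU-addressable
 *	copy_highpage(dst_page, rpage);
 */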
50 
51 static unsigned long spm_addr_dev0;
52 module_param(spm_addr_dev0, long, 0644);
53 MODULE_PARM_DESC(spm_addr_dev0,
54 		"Specify the start address for SPM (special purpose memory) used for device 0. Setting this selects the coherent device type. Make sure spm_addr_dev1 is set too. The minimum SPM size is DEVMEM_CHUNK_SIZE.");
55 
56 static unsigned long spm_addr_dev1;
57 module_param(spm_addr_dev1, long, 0644);
58 MODULE_PARM_DESC(spm_addr_dev1,
59 		"Specify the start address for SPM (special purpose memory) used for device 1. Setting this selects the coherent device type. Make sure spm_addr_dev0 is set too. The minimum SPM size is DEVMEM_CHUNK_SIZE.");
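/*
 * Example invocation (the addresses are placeholders): loading the module
 * with coherent (SPM) chunks for devices 0 and 1. Each region must be
 * CPU-addressable and at least DEVMEM_CHUNK_SIZE bytes:
 *
 *	modprobe test_hmm spm_addr_dev0=0x100000000 spm_addr_dev1=0x110000000
 */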
60 
61 static const struct dev_pagemap_ops dmirror_devmem_ops;
62 static const struct mmu_interval_notifier_ops dmirror_min_ops;
63 static dev_t dmirror_dev;
64 
65 struct dmirror_device;
66 
67 struct dmirror_bounce {
68 	void			*ptr;
69 	unsigned long		size;
70 	unsigned long		addr;
71 	unsigned long		cpages;
72 };
73 
74 #define DPT_XA_TAG_ATOMIC 1UL
75 #define DPT_XA_TAG_WRITE 3UL
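/*
 * Sketch of how these tags are used with the dmirror->pt XArray, which acts
 * as the device's mirror of the process page table (indexed by the virtual
 * page number, i.e. addr >> PAGE_SHIFT):
 *
 *	// store a writable translation
 *	entry = xa_tag_pointer(page, DPT_XA_TAG_WRITE);
 *	xa_store(&dmirror->pt, addr >> PAGE_SHIFT, entry, GFP_ATOMIC);
 *
 *	// look it up again
 *	entry = xa_load(&dmirror->pt, addr >> PAGE_SHIFT);
 *	page = xa_untag_pointer(entry);
 *	writable = xa_pointer_tag(entry) == DPT_XA_TAG_WRITE;
 */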
76 
77 /*
78  * Data structure to track address ranges and register for mmu interval
79  * notifier updates.
80  */
81 struct dmirror_interval {
82 	struct mmu_interval_notifier	notifier;
83 	struct dmirror			*dmirror;
84 };
85 
86 /*
87  * Data attached to the open device file.
88  * Note that it might be shared after a fork().
89  */
90 struct dmirror {
91 	struct dmirror_device		*mdevice;
92 	struct xarray			pt;
93 	struct mmu_interval_notifier	notifier;
94 	struct mutex			mutex;
95 	__u64			flags;
96 };
97 
98 /*
99  * ZONE_DEVICE pages for migration and simulating device memory.
100  */
101 struct dmirror_chunk {
102 	struct dev_pagemap	pagemap;
103 	struct dmirror_device	*mdevice;
104 	bool remove;
105 };
106 
107 /*
108  * Per device data.
109  */
110 struct dmirror_device {
111 	struct cdev		cdevice;
112 	unsigned int            zone_device_type;
113 	struct device		device;
114 
115 	unsigned int		devmem_capacity;
116 	unsigned int		devmem_count;
117 	struct dmirror_chunk	**devmem_chunks;
118 	struct mutex		devmem_lock;	/* protects the above */
119 
120 	unsigned long		calloc;
121 	unsigned long		cfree;
122 	struct page		*free_pages;
123 	struct folio		*free_folios;
124 	spinlock_t		lock;		/* protects the above */
125 };
126 
127 static struct dmirror_device dmirror_devices[DMIRROR_NDEVICES];
128 
129 static int dmirror_bounce_init(struct dmirror_bounce *bounce,
130 			       unsigned long addr,
131 			       unsigned long size)
132 {
133 	bounce->addr = addr;
134 	bounce->size = size;
135 	bounce->cpages = 0;
136 	bounce->ptr = vmalloc(size);
137 	if (!bounce->ptr)
138 		return -ENOMEM;
139 	return 0;
140 }
141 
142 static bool dmirror_is_private_zone(struct dmirror_device *mdevice)
143 {
144 	return (mdevice->zone_device_type ==
145 		HMM_DMIRROR_MEMORY_DEVICE_PRIVATE);
146 }
147 
148 static enum migrate_vma_direction
149 dmirror_select_device(struct dmirror *dmirror)
150 {
151 	return (dmirror->mdevice->zone_device_type ==
152 		HMM_DMIRROR_MEMORY_DEVICE_PRIVATE) ?
153 		MIGRATE_VMA_SELECT_DEVICE_PRIVATE :
154 		MIGRATE_VMA_SELECT_DEVICE_COHERENT;
155 }
156 
157 static void dmirror_bounce_fini(struct dmirror_bounce *bounce)
158 {
159 	vfree(bounce->ptr);
160 }
161 
162 static int dmirror_fops_open(struct inode *inode, struct file *filp)
163 {
164 	struct cdev *cdev = inode->i_cdev;
165 	struct dmirror *dmirror;
166 	int ret;
167 
168 	/* Mirror this process address space */
169 	dmirror = kzalloc(sizeof(*dmirror), GFP_KERNEL);
170 	if (dmirror == NULL)
171 		return -ENOMEM;
172 
173 	dmirror->mdevice = container_of(cdev, struct dmirror_device, cdevice);
174 	mutex_init(&dmirror->mutex);
175 	xa_init(&dmirror->pt);
176 
177 	ret = mmu_interval_notifier_insert(&dmirror->notifier, current->mm,
178 				0, ULONG_MAX & PAGE_MASK, &dmirror_min_ops);
179 	if (ret) {
180 		kfree(dmirror);
181 		return ret;
182 	}
183 
184 	filp->private_data = dmirror;
185 	return 0;
186 }
187 
188 static int dmirror_fops_release(struct inode *inode, struct file *filp)
189 {
190 	struct dmirror *dmirror = filp->private_data;
191 
192 	mmu_interval_notifier_remove(&dmirror->notifier);
193 	xa_destroy(&dmirror->pt);
194 	kfree(dmirror);
195 	return 0;
196 }
197 
198 static struct dmirror_chunk *dmirror_page_to_chunk(struct page *page)
199 {
200 	return container_of(page_pgmap(page), struct dmirror_chunk,
201 			    pagemap);
202 }
203 
204 static struct dmirror_device *dmirror_page_to_device(struct page *page)
205 
206 {
207 	return dmirror_page_to_chunk(page)->mdevice;
208 }
209 
210 static int dmirror_do_fault(struct dmirror *dmirror, struct hmm_range *range)
211 {
212 	unsigned long *pfns = range->hmm_pfns;
213 	unsigned long pfn;
214 
215 	for (pfn = (range->start >> PAGE_SHIFT);
216 	     pfn < (range->end >> PAGE_SHIFT);
217 	     pfn++, pfns++) {
218 		struct page *page;
219 		void *entry;
220 
221 		/*
222 		 * Since we asked for hmm_range_fault() to populate pages,
223 		 * it shouldn't return an error entry on success.
224 		 */
225 		WARN_ON(*pfns & HMM_PFN_ERROR);
226 		WARN_ON(!(*pfns & HMM_PFN_VALID));
227 
228 		page = hmm_pfn_to_page(*pfns);
229 		WARN_ON(!page);
230 
231 		entry = page;
232 		if (*pfns & HMM_PFN_WRITE)
233 			entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
234 		else if (WARN_ON(range->default_flags & HMM_PFN_WRITE))
235 			return -EFAULT;
236 		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
237 		if (xa_is_err(entry))
238 			return xa_err(entry);
239 	}
240 
241 	return 0;
242 }
243 
244 static void dmirror_do_update(struct dmirror *dmirror, unsigned long start,
245 			      unsigned long end)
246 {
247 	unsigned long pfn;
248 	void *entry;
249 
250 	/*
251 	 * The XArray doesn't hold references to pages since it relies on
252 	 * the mmu notifier to clear page pointers when they become stale.
253 	 * Therefore, it is OK to just clear the entry.
254 	 */
255 	xa_for_each_range(&dmirror->pt, pfn, entry, start >> PAGE_SHIFT,
256 			  end >> PAGE_SHIFT)
257 		xa_erase(&dmirror->pt, pfn);
258 }
259 
260 static bool dmirror_interval_invalidate(struct mmu_interval_notifier *mni,
261 				const struct mmu_notifier_range *range,
262 				unsigned long cur_seq)
263 {
264 	struct dmirror *dmirror = container_of(mni, struct dmirror, notifier);
265 
266 	/*
267 	 * Ignore invalidation callbacks for device private pages since
268 	 * the invalidation is handled as part of the migration process.
269 	 */
270 	if (range->event == MMU_NOTIFY_MIGRATE &&
271 	    range->owner == dmirror->mdevice)
272 		return true;
273 
274 	if (mmu_notifier_range_blockable(range))
275 		mutex_lock(&dmirror->mutex);
276 	else if (!mutex_trylock(&dmirror->mutex))
277 		return false;
278 
279 	mmu_interval_set_seq(mni, cur_seq);
280 	dmirror_do_update(dmirror, range->start, range->end);
281 
282 	mutex_unlock(&dmirror->mutex);
283 	return true;
284 }
285 
286 static const struct mmu_interval_notifier_ops dmirror_min_ops = {
287 	.invalidate = dmirror_interval_invalidate,
288 };
289 
290 static int dmirror_range_fault(struct dmirror *dmirror,
291 				struct hmm_range *range)
292 {
293 	struct mm_struct *mm = dmirror->notifier.mm;
294 	unsigned long timeout =
295 		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
296 	int ret;
297 
298 	while (true) {
299 		if (time_after(jiffies, timeout)) {
300 			ret = -EBUSY;
301 			goto out;
302 		}
303 
304 		range->notifier_seq = mmu_interval_read_begin(range->notifier);
305 		mmap_read_lock(mm);
306 		ret = hmm_range_fault(range);
307 		mmap_read_unlock(mm);
308 		if (ret) {
309 			if (ret == -EBUSY)
310 				continue;
311 			goto out;
312 		}
313 
314 		mutex_lock(&dmirror->mutex);
315 		if (mmu_interval_read_retry(range->notifier,
316 					    range->notifier_seq)) {
317 			mutex_unlock(&dmirror->mutex);
318 			continue;
319 		}
320 		break;
321 	}
322 
323 	ret = dmirror_do_fault(dmirror, range);
324 
325 	mutex_unlock(&dmirror->mutex);
326 out:
327 	return ret;
328 }
329 
330 static int dmirror_fault(struct dmirror *dmirror, unsigned long start,
331 			 unsigned long end, bool write)
332 {
333 	struct mm_struct *mm = dmirror->notifier.mm;
334 	unsigned long addr;
335 	unsigned long pfns[32];
336 	struct hmm_range range = {
337 		.notifier = &dmirror->notifier,
338 		.hmm_pfns = pfns,
339 		.pfn_flags_mask = 0,
340 		.default_flags =
341 			HMM_PFN_REQ_FAULT | (write ? HMM_PFN_REQ_WRITE : 0),
342 		.dev_private_owner = dmirror->mdevice,
343 	};
344 	int ret = 0;
345 
346 	/* Since the mm is for the mirrored process, get a reference first. */
347 	if (!mmget_not_zero(mm))
348 		return 0;
349 
350 	for (addr = start; addr < end; addr = range.end) {
351 		range.start = addr;
352 		range.end = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
353 
354 		ret = dmirror_range_fault(dmirror, &range);
355 		if (ret)
356 			break;
357 	}
358 
359 	mmput(mm);
360 	return ret;
361 }
362 
363 static int dmirror_do_read(struct dmirror *dmirror, unsigned long start,
364 			   unsigned long end, struct dmirror_bounce *bounce)
365 {
366 	unsigned long pfn;
367 	void *ptr;
368 
369 	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
370 
371 	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
372 		void *entry;
373 		struct page *page;
374 
375 		entry = xa_load(&dmirror->pt, pfn);
376 		page = xa_untag_pointer(entry);
377 		if (!page)
378 			return -ENOENT;
379 
380 		memcpy_from_page(ptr, page, 0, PAGE_SIZE);
381 
382 		ptr += PAGE_SIZE;
383 		bounce->cpages++;
384 	}
385 
386 	return 0;
387 }
388 
389 static int dmirror_read(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
390 {
391 	struct dmirror_bounce bounce;
392 	unsigned long start, end;
393 	unsigned long size = cmd->npages << PAGE_SHIFT;
394 	int ret;
395 
396 	start = cmd->addr;
397 	end = start + size;
398 	if (end < start)
399 		return -EINVAL;
400 
401 	ret = dmirror_bounce_init(&bounce, start, size);
402 	if (ret)
403 		return ret;
404 
405 	while (1) {
406 		mutex_lock(&dmirror->mutex);
407 		ret = dmirror_do_read(dmirror, start, end, &bounce);
408 		mutex_unlock(&dmirror->mutex);
409 		if (ret != -ENOENT)
410 			break;
411 
412 		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
413 		ret = dmirror_fault(dmirror, start, end, false);
414 		if (ret)
415 			break;
416 		cmd->faults++;
417 	}
418 
419 	if (ret == 0) {
420 		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
421 				 bounce.size))
422 			ret = -EFAULT;
423 	}
424 	cmd->cpages = bounce.cpages;
425 	dmirror_bounce_fini(&bounce);
426 	return ret;
427 }
428 
429 static int dmirror_do_write(struct dmirror *dmirror, unsigned long start,
430 			    unsigned long end, struct dmirror_bounce *bounce)
431 {
432 	unsigned long pfn;
433 	void *ptr;
434 
435 	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
436 
437 	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
438 		void *entry;
439 		struct page *page;
440 
441 		entry = xa_load(&dmirror->pt, pfn);
442 		page = xa_untag_pointer(entry);
443 		if (!page || xa_pointer_tag(entry) != DPT_XA_TAG_WRITE)
444 			return -ENOENT;
445 
446 		memcpy_to_page(page, 0, ptr, PAGE_SIZE);
447 
448 		ptr += PAGE_SIZE;
449 		bounce->cpages++;
450 	}
451 
452 	return 0;
453 }
454 
455 static int dmirror_write(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
456 {
457 	struct dmirror_bounce bounce;
458 	unsigned long start, end;
459 	unsigned long size = cmd->npages << PAGE_SHIFT;
460 	int ret;
461 
462 	start = cmd->addr;
463 	end = start + size;
464 	if (end < start)
465 		return -EINVAL;
466 
467 	ret = dmirror_bounce_init(&bounce, start, size);
468 	if (ret)
469 		return ret;
470 	if (copy_from_user(bounce.ptr, u64_to_user_ptr(cmd->ptr),
471 			   bounce.size)) {
472 		ret = -EFAULT;
473 		goto fini;
474 	}
475 
476 	while (1) {
477 		mutex_lock(&dmirror->mutex);
478 		ret = dmirror_do_write(dmirror, start, end, &bounce);
479 		mutex_unlock(&dmirror->mutex);
480 		if (ret != -ENOENT)
481 			break;
482 
483 		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
484 		ret = dmirror_fault(dmirror, start, end, true);
485 		if (ret)
486 			break;
487 		cmd->faults++;
488 	}
489 
490 fini:
491 	cmd->cpages = bounce.cpages;
492 	dmirror_bounce_fini(&bounce);
493 	return ret;
494 }
495 
496 static int dmirror_allocate_chunk(struct dmirror_device *mdevice,
497 				  struct page **ppage, bool is_large)
498 {
499 	struct dmirror_chunk *devmem;
500 	struct resource *res = NULL;
501 	unsigned long pfn;
502 	unsigned long pfn_first;
503 	unsigned long pfn_last;
504 	void *ptr;
505 	int ret = -ENOMEM;
506 
507 	devmem = kzalloc(sizeof(*devmem), GFP_KERNEL);
508 	if (!devmem)
509 		return ret;
510 
511 	switch (mdevice->zone_device_type) {
512 	case HMM_DMIRROR_MEMORY_DEVICE_PRIVATE:
513 		res = request_free_mem_region(&iomem_resource, DEVMEM_CHUNK_SIZE,
514 					      "hmm_dmirror");
515 		if (IS_ERR_OR_NULL(res))
516 			goto err_devmem;
517 		devmem->pagemap.range.start = res->start;
518 		devmem->pagemap.range.end = res->end;
519 		devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
520 		break;
521 	case HMM_DMIRROR_MEMORY_DEVICE_COHERENT:
522 		devmem->pagemap.range.start = (MINOR(mdevice->cdevice.dev) - 2) ?
523 							spm_addr_dev0 :
524 							spm_addr_dev1;
525 		devmem->pagemap.range.end = devmem->pagemap.range.start +
526 					    DEVMEM_CHUNK_SIZE - 1;
527 		devmem->pagemap.type = MEMORY_DEVICE_COHERENT;
528 		break;
529 	default:
530 		ret = -EINVAL;
531 		goto err_devmem;
532 	}
533 
534 	devmem->pagemap.nr_range = 1;
535 	devmem->pagemap.ops = &dmirror_devmem_ops;
536 	devmem->pagemap.owner = mdevice;
537 
538 	mutex_lock(&mdevice->devmem_lock);
539 
540 	if (mdevice->devmem_count == mdevice->devmem_capacity) {
541 		struct dmirror_chunk **new_chunks;
542 		unsigned int new_capacity;
543 
544 		new_capacity = mdevice->devmem_capacity +
545 				DEVMEM_CHUNKS_RESERVE;
546 		new_chunks = krealloc(mdevice->devmem_chunks,
547 				sizeof(new_chunks[0]) * new_capacity,
548 				GFP_KERNEL);
549 		if (!new_chunks)
550 			goto err_release;
551 		mdevice->devmem_capacity = new_capacity;
552 		mdevice->devmem_chunks = new_chunks;
553 	}
554 	ptr = memremap_pages(&devmem->pagemap, numa_node_id());
555 	if (IS_ERR_OR_NULL(ptr)) {
556 		if (ptr)
557 			ret = PTR_ERR(ptr);
558 		else
559 			ret = -EFAULT;
560 		goto err_release;
561 	}
562 
563 	devmem->mdevice = mdevice;
564 	pfn_first = devmem->pagemap.range.start >> PAGE_SHIFT;
565 	pfn_last = pfn_first + (range_len(&devmem->pagemap.range) >> PAGE_SHIFT);
566 	mdevice->devmem_chunks[mdevice->devmem_count++] = devmem;
567 
568 	mutex_unlock(&mdevice->devmem_lock);
569 
570 	pr_info("added new %u MB chunk (total %u chunks, %u MB) PFNs [0x%lx 0x%lx)\n",
571 		DEVMEM_CHUNK_SIZE / (1024 * 1024),
572 		mdevice->devmem_count,
573 		mdevice->devmem_count * (DEVMEM_CHUNK_SIZE / (1024 * 1024)),
574 		pfn_first, pfn_last);
575 
576 	spin_lock(&mdevice->lock);
577 	for (pfn = pfn_first; pfn < pfn_last; ) {
578 		struct page *page = pfn_to_page(pfn);
579 
580 		if (is_large && IS_ALIGNED(pfn, HPAGE_PMD_NR)
581 			&& (pfn + HPAGE_PMD_NR <= pfn_last)) {
582 			page->zone_device_data = mdevice->free_folios;
583 			mdevice->free_folios = page_folio(page);
584 			pfn += HPAGE_PMD_NR;
585 			continue;
586 		}
587 
588 		page->zone_device_data = mdevice->free_pages;
589 		mdevice->free_pages = page;
590 		pfn++;
591 	}
592 
593 	ret = 0;
594 	if (ppage) {
595 		if (is_large) {
596 			if (!mdevice->free_folios) {
597 				ret = -ENOMEM;
598 				goto err_unlock;
599 			}
600 			*ppage = folio_page(mdevice->free_folios, 0);
601 			mdevice->free_folios = (*ppage)->zone_device_data;
602 			mdevice->calloc += HPAGE_PMD_NR;
603 		} else if (mdevice->free_pages) {
604 			*ppage = mdevice->free_pages;
605 			mdevice->free_pages = (*ppage)->zone_device_data;
606 			mdevice->calloc++;
607 		} else {
608 			ret = -ENOMEM;
609 			goto err_unlock;
610 		}
611 	}
612 err_unlock:
613 	spin_unlock(&mdevice->lock);
614 
615 	return ret;
616 
617 err_release:
618 	mutex_unlock(&mdevice->devmem_lock);
619 	if (res && devmem->pagemap.type == MEMORY_DEVICE_PRIVATE)
620 		release_mem_region(devmem->pagemap.range.start,
621 				   range_len(&devmem->pagemap.range));
622 err_devmem:
623 	kfree(devmem);
624 
625 	return ret;
626 }
627 
628 static struct page *dmirror_devmem_alloc_page(struct dmirror *dmirror,
629 					      bool is_large)
630 {
631 	struct page *dpage = NULL;
632 	struct page *rpage = NULL;
633 	unsigned int order = is_large ? HPAGE_PMD_ORDER : 0;
634 	struct dmirror_device *mdevice = dmirror->mdevice;
635 
636 	/*
637 	 * For ZONE_DEVICE private type, this is a fake device so we allocate
638 	 * real system memory to store our device memory.
639 	 * For ZONE_DEVICE coherent type we use the actual dpage to store the
640 	 * data and ignore rpage.
641 	 */
642 	if (dmirror_is_private_zone(mdevice)) {
643 		rpage = folio_page(folio_alloc(GFP_HIGHUSER, order), 0);
644 		if (!rpage)
645 			return NULL;
646 	}
647 	spin_lock(&mdevice->lock);
648 
649 	if (is_large && mdevice->free_folios) {
650 		dpage = folio_page(mdevice->free_folios, 0);
651 		mdevice->free_folios = dpage->zone_device_data;
652 		mdevice->calloc += 1 << order;
653 		spin_unlock(&mdevice->lock);
654 	} else if (!is_large && mdevice->free_pages) {
655 		dpage = mdevice->free_pages;
656 		mdevice->free_pages = dpage->zone_device_data;
657 		mdevice->calloc++;
658 		spin_unlock(&mdevice->lock);
659 	} else {
660 		spin_unlock(&mdevice->lock);
661 		if (dmirror_allocate_chunk(mdevice, &dpage, is_large))
662 			goto error;
663 	}
664 
665 	zone_device_folio_init(page_folio(dpage),
666 			       page_pgmap(folio_page(page_folio(dpage), 0)),
667 			       order);
668 	dpage->zone_device_data = rpage;
669 	return dpage;
670 
671 error:
672 	if (rpage)
673 		__free_pages(rpage, order);
674 	return NULL;
675 }
676 
677 static void dmirror_migrate_alloc_and_copy(struct migrate_vma *args,
678 					   struct dmirror *dmirror)
679 {
680 	const unsigned long *src = args->src;
681 	unsigned long *dst = args->dst;
682 	unsigned long addr;
683 
684 	for (addr = args->start; addr < args->end; ) {
685 		struct page *spage;
686 		struct page *dpage;
687 		struct page *rpage;
688 		bool is_large = *src & MIGRATE_PFN_COMPOUND;
689 		int write = (*src & MIGRATE_PFN_WRITE) ? MIGRATE_PFN_WRITE : 0;
690 		unsigned long nr = 1;
691 
692 		if (!(*src & MIGRATE_PFN_MIGRATE))
693 			goto next;
694 
695 		/*
696 		 * Note that spage might be NULL which is OK since it is an
697 		 * unallocated pte_none() or read-only zero page.
698 		 */
699 		spage = migrate_pfn_to_page(*src);
700 		if (WARN(spage && is_zone_device_page(spage),
701 		     "page already in device spage pfn: 0x%lx\n",
702 		     page_to_pfn(spage)))
703 			goto next;
704 
705 		if (dmirror->flags & HMM_DMIRROR_FLAG_FAIL_ALLOC) {
706 			dmirror->flags &= ~HMM_DMIRROR_FLAG_FAIL_ALLOC;
707 			dpage = NULL;
708 		} else
709 			dpage = dmirror_devmem_alloc_page(dmirror, is_large);
710 
711 		if (!dpage) {
712 			struct folio *folio;
713 			unsigned long i;
714 			unsigned long spfn = *src >> MIGRATE_PFN_SHIFT;
715 			struct page *src_page;
716 
717 			if (!is_large)
718 				goto next;
719 
720 			if (!spage && is_large) {
721 				nr = HPAGE_PMD_NR;
722 			} else {
723 				folio = page_folio(spage);
724 				nr = folio_nr_pages(folio);
725 			}
726 
727 			for (i = 0; i < nr && addr < args->end; i++) {
728 				dpage = dmirror_devmem_alloc_page(dmirror, false);
729 				rpage = BACKING_PAGE(dpage);
730 				rpage->zone_device_data = dmirror;
731 
732 				*dst = migrate_pfn(page_to_pfn(dpage)) | write;
733 				src_page = pfn_to_page(spfn + i);
734 
735 				if (spage)
736 					copy_highpage(rpage, src_page);
737 				else
738 					clear_highpage(rpage);
739 				src++;
740 				dst++;
741 				addr += PAGE_SIZE;
742 			}
743 			continue;
744 		}
745 
746 		rpage = BACKING_PAGE(dpage);
747 
748 		/*
749 		 * Normally, a device would use the page->zone_device_data to
750 		 * point to the mirror but here we use it to hold the page for
751 		 * the simulated device memory and that page holds the pointer
752 		 * to the mirror.
753 		 */
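		/*
		 * Resulting pointer chain for the private case, as a sketch:
		 *	dpage->zone_device_data == rpage   (system RAM backing)
		 *	rpage->zone_device_data == dmirror
		 * For coherent memory BACKING_PAGE(dpage) == dpage, so the
		 * dmirror pointer lives on the device page itself.
		 */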
754 		rpage->zone_device_data = dmirror;
755 
756 		pr_debug("migrating from sys to dev pfn src: 0x%lx pfn dst: 0x%lx\n",
757 			 page_to_pfn(spage), page_to_pfn(dpage));
758 
759 		*dst = migrate_pfn(page_to_pfn(dpage)) | write;
760 
761 		if (is_large) {
762 			int i;
763 			struct folio *folio = page_folio(dpage);
764 			*dst |= MIGRATE_PFN_COMPOUND;
765 
766 			if (folio_test_large(folio)) {
767 				for (i = 0; i < folio_nr_pages(folio); i++) {
768 					struct page *dst_page =
769 						pfn_to_page(page_to_pfn(rpage) + i);
770 					struct page *src_page =
771 						pfn_to_page(page_to_pfn(spage) + i);
772 
773 					if (spage)
774 						copy_highpage(dst_page, src_page);
775 					else
776 						clear_highpage(dst_page);
777 					src++;
778 					dst++;
779 					addr += PAGE_SIZE;
780 				}
781 				continue;
782 			}
783 		}
784 
785 		if (spage)
786 			copy_highpage(rpage, spage);
787 		else
788 			clear_highpage(rpage);
789 
790 next:
791 		src++;
792 		dst++;
793 		addr += PAGE_SIZE;
794 	}
795 }
796 
797 static int dmirror_check_atomic(struct dmirror *dmirror, unsigned long start,
798 			     unsigned long end)
799 {
800 	unsigned long pfn;
801 
802 	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
803 		void *entry;
804 
805 		entry = xa_load(&dmirror->pt, pfn);
806 		if (xa_pointer_tag(entry) == DPT_XA_TAG_ATOMIC)
807 			return -EPERM;
808 	}
809 
810 	return 0;
811 }
812 
813 static int dmirror_atomic_map(unsigned long addr, struct page *page,
814 		struct dmirror *dmirror)
815 {
816 	void *entry;
817 
818 	/* Map the migrated pages into the device's page tables. */
819 	mutex_lock(&dmirror->mutex);
820 
821 	entry = xa_tag_pointer(page, DPT_XA_TAG_ATOMIC);
822 	entry = xa_store(&dmirror->pt, addr >> PAGE_SHIFT, entry, GFP_ATOMIC);
823 	if (xa_is_err(entry)) {
824 		mutex_unlock(&dmirror->mutex);
825 		return xa_err(entry);
826 	}
827 
828 	mutex_unlock(&dmirror->mutex);
829 	return 0;
830 }
831 
832 static int dmirror_migrate_finalize_and_map(struct migrate_vma *args,
833 					    struct dmirror *dmirror)
834 {
835 	unsigned long start = args->start;
836 	unsigned long end = args->end;
837 	const unsigned long *src = args->src;
838 	const unsigned long *dst = args->dst;
839 	unsigned long pfn;
840 	const unsigned long start_pfn = start >> PAGE_SHIFT;
841 	const unsigned long end_pfn = end >> PAGE_SHIFT;
842 
843 	/* Map the migrated pages into the device's page tables. */
844 	mutex_lock(&dmirror->mutex);
845 
846 	for (pfn = start_pfn; pfn < end_pfn; pfn++, src++, dst++) {
847 		struct page *dpage;
848 		void *entry;
849 		int nr, i;
850 		struct page *rpage;
851 
852 		if (!(*src & MIGRATE_PFN_MIGRATE))
853 			continue;
854 
855 		dpage = migrate_pfn_to_page(*dst);
856 		if (!dpage)
857 			continue;
858 
859 		if (*dst & MIGRATE_PFN_COMPOUND)
860 			nr = folio_nr_pages(page_folio(dpage));
861 		else
862 			nr = 1;
863 
864 		WARN_ON_ONCE(end_pfn < start_pfn + nr);
865 
866 		rpage = BACKING_PAGE(dpage);
867 		VM_WARN_ON(folio_nr_pages(page_folio(rpage)) != nr);
868 
869 		for (i = 0; i < nr; i++) {
870 			entry = folio_page(page_folio(rpage), i);
871 			if (*dst & MIGRATE_PFN_WRITE)
872 				entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
873 			entry = xa_store(&dmirror->pt, pfn + i, entry, GFP_ATOMIC);
874 			if (xa_is_err(entry)) {
875 				mutex_unlock(&dmirror->mutex);
876 				return xa_err(entry);
877 			}
878 		}
879 	}
880 
881 	mutex_unlock(&dmirror->mutex);
882 	return 0;
883 }
884 
885 static int dmirror_exclusive(struct dmirror *dmirror,
886 			     struct hmm_dmirror_cmd *cmd)
887 {
888 	unsigned long start, end, addr;
889 	unsigned long size = cmd->npages << PAGE_SHIFT;
890 	struct mm_struct *mm = dmirror->notifier.mm;
891 	struct dmirror_bounce bounce;
892 	int ret = 0;
893 
894 	start = cmd->addr;
895 	end = start + size;
896 	if (end < start)
897 		return -EINVAL;
898 
899 	/* Since the mm is for the mirrored process, get a reference first. */
900 	if (!mmget_not_zero(mm))
901 		return -EINVAL;
902 
903 	mmap_read_lock(mm);
904 	for (addr = start; !ret && addr < end; addr += PAGE_SIZE) {
905 		struct folio *folio;
906 		struct page *page;
907 
908 		page = make_device_exclusive(mm, addr, NULL, &folio);
909 		if (IS_ERR(page)) {
910 			ret = PTR_ERR(page);
911 			break;
912 		}
913 
914 		ret = dmirror_atomic_map(addr, page, dmirror);
915 		folio_unlock(folio);
916 		folio_put(folio);
917 	}
918 	mmap_read_unlock(mm);
919 	mmput(mm);
920 
921 	if (ret)
922 		return ret;
923 
924 	/* Return the migrated data for verification. */
925 	ret = dmirror_bounce_init(&bounce, start, size);
926 	if (ret)
927 		return ret;
928 	mutex_lock(&dmirror->mutex);
929 	ret = dmirror_do_read(dmirror, start, end, &bounce);
930 	mutex_unlock(&dmirror->mutex);
931 	if (ret == 0) {
932 		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
933 				 bounce.size))
934 			ret = -EFAULT;
935 	}
936 
937 	cmd->cpages = bounce.cpages;
938 	dmirror_bounce_fini(&bounce);
939 	return ret;
940 }
941 
942 static vm_fault_t dmirror_devmem_fault_alloc_and_copy(struct migrate_vma *args,
943 						      struct dmirror *dmirror)
944 {
945 	const unsigned long *src = args->src;
946 	unsigned long *dst = args->dst;
947 	unsigned long start = args->start;
948 	unsigned long end = args->end;
949 	unsigned long addr;
950 	unsigned int order = 0;
951 	int i;
952 
953 	for (addr = start; addr < end; ) {
954 		struct page *dpage, *spage;
955 
956 		spage = migrate_pfn_to_page(*src);
957 		if (!spage || !(*src & MIGRATE_PFN_MIGRATE)) {
958 			addr += PAGE_SIZE;
959 			goto next;
960 		}
961 
962 		if (WARN_ON(!is_device_private_page(spage) &&
963 			    !is_device_coherent_page(spage))) {
964 			addr += PAGE_SIZE;
965 			goto next;
966 		}
967 
968 		spage = BACKING_PAGE(spage);
969 		order = folio_order(page_folio(spage));
970 		if (order)
971 			*dst = MIGRATE_PFN_COMPOUND;
972 		if (*src & MIGRATE_PFN_WRITE)
973 			*dst |= MIGRATE_PFN_WRITE;
974 
975 		if (dmirror->flags & HMM_DMIRROR_FLAG_FAIL_ALLOC) {
976 			dmirror->flags &= ~HMM_DMIRROR_FLAG_FAIL_ALLOC;
977 			*dst &= ~MIGRATE_PFN_COMPOUND;
978 			dpage = NULL;
979 		} else if (order) {
980 			dpage = folio_page(vma_alloc_folio(GFP_HIGHUSER_MOVABLE,
981 						order, args->vma, addr), 0);
982 		} else {
983 			dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
984 		}
985 
986 		if (!dpage && !order)
987 			return VM_FAULT_OOM;
988 
989 		pr_debug("migrating from sys to dev pfn src: 0x%lx pfn dst: 0x%lx\n",
990 				page_to_pfn(spage), page_to_pfn(dpage));
991 
992 		if (dpage) {
993 			lock_page(dpage);
994 			*dst |= migrate_pfn(page_to_pfn(dpage));
995 		}
996 
997 		for (i = 0; i < (1 << order); i++) {
998 			struct page *src_page;
999 			struct page *dst_page;
1000 
1001 			/* Try with smaller pages if large allocation fails */
1002 			if (!dpage && order) {
1003 				dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
1004 				lock_page(dpage);
1005 				dst[i] = migrate_pfn(page_to_pfn(dpage));
1006 				dst_page = pfn_to_page(page_to_pfn(dpage));
1007 				dpage = NULL; /* For the next iteration */
1008 			} else {
1009 				dst_page = pfn_to_page(page_to_pfn(dpage) + i);
1010 			}
1011 
1012 			src_page = pfn_to_page(page_to_pfn(spage) + i);
1013 
1014 			xa_erase(&dmirror->pt, addr >> PAGE_SHIFT);
1015 			addr += PAGE_SIZE;
1016 			copy_highpage(dst_page, src_page);
1017 		}
1018 next:
1019 		src += 1 << order;
1020 		dst += 1 << order;
1021 	}
1022 	return 0;
1023 }
1024 
1025 static unsigned long
1026 dmirror_successful_migrated_pages(struct migrate_vma *migrate)
1027 {
1028 	unsigned long cpages = 0;
1029 	unsigned long i;
1030 
1031 	for (i = 0; i < migrate->npages; i++) {
1032 		if (migrate->src[i] & MIGRATE_PFN_VALID &&
1033 		    migrate->src[i] & MIGRATE_PFN_MIGRATE)
1034 			cpages++;
1035 	}
1036 	return cpages;
1037 }
1038 
1039 static int dmirror_migrate_to_system(struct dmirror *dmirror,
1040 				     struct hmm_dmirror_cmd *cmd)
1041 {
1042 	unsigned long start, end, addr;
1043 	unsigned long size = cmd->npages << PAGE_SHIFT;
1044 	struct mm_struct *mm = dmirror->notifier.mm;
1045 	struct vm_area_struct *vma;
1046 	struct migrate_vma args = { 0 };
1047 	unsigned long next;
1048 	int ret;
1049 	unsigned long *src_pfns;
1050 	unsigned long *dst_pfns;
1051 
1052 	start = cmd->addr;
1053 	end = start + size;
1054 	if (end < start)
1055 		return -EINVAL;
1056 
1057 	/* Since the mm is for the mirrored process, get a reference first. */
1058 	if (!mmget_not_zero(mm))
1059 		return -EINVAL;
1060 
1061 	src_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*src_pfns), GFP_KERNEL | __GFP_NOFAIL);
1062 	dst_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*dst_pfns), GFP_KERNEL | __GFP_NOFAIL);
1063 
1064 	cmd->cpages = 0;
1065 	mmap_read_lock(mm);
1066 	for (addr = start; addr < end; addr = next) {
1067 		vma = vma_lookup(mm, addr);
1068 		if (!vma || !(vma->vm_flags & VM_READ)) {
1069 			ret = -EINVAL;
1070 			goto out;
1071 		}
1072 		next = min(end, addr + (PTRS_PER_PTE << PAGE_SHIFT));
1073 		if (next > vma->vm_end)
1074 			next = vma->vm_end;
1075 
1076 		args.vma = vma;
1077 		args.src = src_pfns;
1078 		args.dst = dst_pfns;
1079 		args.start = addr;
1080 		args.end = next;
1081 		args.pgmap_owner = dmirror->mdevice;
1082 		args.flags = dmirror_select_device(dmirror) | MIGRATE_VMA_SELECT_COMPOUND;
1083 
1084 		ret = migrate_vma_setup(&args);
1085 		if (ret)
1086 			goto out;
1087 
1088 		pr_debug("Migrating from device mem to sys mem\n");
1089 		dmirror_devmem_fault_alloc_and_copy(&args, dmirror);
1090 
1091 		migrate_vma_pages(&args);
1092 		cmd->cpages += dmirror_successful_migrated_pages(&args);
1093 		migrate_vma_finalize(&args);
1094 	}
1095 out:
1096 	mmap_read_unlock(mm);
1097 	mmput(mm);
1098 	kvfree(src_pfns);
1099 	kvfree(dst_pfns);
1100 
1101 	return ret;
1102 }
1103 
1104 static int dmirror_migrate_to_device(struct dmirror *dmirror,
1105 				struct hmm_dmirror_cmd *cmd)
1106 {
1107 	unsigned long start, end, addr;
1108 	unsigned long size = cmd->npages << PAGE_SHIFT;
1109 	struct mm_struct *mm = dmirror->notifier.mm;
1110 	struct vm_area_struct *vma;
1111 	struct dmirror_bounce bounce;
1112 	struct migrate_vma args = { 0 };
1113 	unsigned long next;
1114 	int ret;
1115 	unsigned long *src_pfns = NULL;
1116 	unsigned long *dst_pfns = NULL;
1117 
1118 	start = cmd->addr;
1119 	end = start + size;
1120 	if (end < start)
1121 		return -EINVAL;
1122 
1123 	/* Since the mm is for the mirrored process, get a reference first. */
1124 	if (!mmget_not_zero(mm))
1125 		return -EINVAL;
1126 
1127 	ret = -ENOMEM;
1128 	src_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*src_pfns),
1129 			  GFP_KERNEL | __GFP_NOFAIL);
1130 	if (!src_pfns)
1131 		goto free_mem;
1132 
1133 	dst_pfns = kvcalloc(PTRS_PER_PTE, sizeof(*dst_pfns),
1134 			  GFP_KERNEL | __GFP_NOFAIL);
1135 	if (!dst_pfns)
1136 		goto free_mem;
1137 
1138 	ret = 0;
1139 	mmap_read_lock(mm);
1140 	for (addr = start; addr < end; addr = next) {
1141 		vma = vma_lookup(mm, addr);
1142 		if (!vma || !(vma->vm_flags & VM_READ)) {
1143 			ret = -EINVAL;
1144 			goto out;
1145 		}
1146 		next = min(end, addr + (PTRS_PER_PTE << PAGE_SHIFT));
1147 		if (next > vma->vm_end)
1148 			next = vma->vm_end;
1149 
1150 		args.vma = vma;
1151 		args.src = src_pfns;
1152 		args.dst = dst_pfns;
1153 		args.start = addr;
1154 		args.end = next;
1155 		args.pgmap_owner = dmirror->mdevice;
1156 		args.flags = MIGRATE_VMA_SELECT_SYSTEM |
1157 				MIGRATE_VMA_SELECT_COMPOUND;
1158 		ret = migrate_vma_setup(&args);
1159 		if (ret)
1160 			goto out;
1161 
1162 		pr_debug("Migrating from sys mem to device mem\n");
1163 		dmirror_migrate_alloc_and_copy(&args, dmirror);
1164 		migrate_vma_pages(&args);
1165 		dmirror_migrate_finalize_and_map(&args, dmirror);
1166 		migrate_vma_finalize(&args);
1167 	}
1168 	mmap_read_unlock(mm);
1169 	mmput(mm);
1170 
1171 	/*
1172 	 * Return the migrated data for verification.
1173 	 * Only pages that ended up in the device zone are copied back.
1174 	 */
1175 	ret = dmirror_bounce_init(&bounce, start, size);
1176 	if (ret)
1177 		goto free_mem;
1178 	mutex_lock(&dmirror->mutex);
1179 	ret = dmirror_do_read(dmirror, start, end, &bounce);
1180 	mutex_unlock(&dmirror->mutex);
1181 	if (ret == 0) {
1182 		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
1183 				 bounce.size))
1184 			ret = -EFAULT;
1185 	}
1186 	cmd->cpages = bounce.cpages;
1187 	dmirror_bounce_fini(&bounce);
1188 	goto free_mem;
1189 
1190 out:
1191 	mmap_read_unlock(mm);
1192 	mmput(mm);
1193 free_mem:
1194 	kvfree(src_pfns);
1195 	kvfree(dst_pfns);
1196 	return ret;
1197 }
1198 
1199 static void dmirror_mkentry(struct dmirror *dmirror, struct hmm_range *range,
1200 			    unsigned char *perm, unsigned long entry)
1201 {
1202 	struct page *page;
1203 
1204 	if (entry & HMM_PFN_ERROR) {
1205 		*perm = HMM_DMIRROR_PROT_ERROR;
1206 		return;
1207 	}
1208 	if (!(entry & HMM_PFN_VALID)) {
1209 		*perm = HMM_DMIRROR_PROT_NONE;
1210 		return;
1211 	}
1212 
1213 	page = hmm_pfn_to_page(entry);
1214 	if (is_device_private_page(page)) {
1215 		/* Is the page migrated to this device or some other? */
1216 		if (dmirror->mdevice == dmirror_page_to_device(page))
1217 			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL;
1218 		else
1219 			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_REMOTE;
1220 	} else if (is_device_coherent_page(page)) {
1221 		/* Is the page migrated to this device or some other? */
1222 		if (dmirror->mdevice == dmirror_page_to_device(page))
1223 			*perm = HMM_DMIRROR_PROT_DEV_COHERENT_LOCAL;
1224 		else
1225 			*perm = HMM_DMIRROR_PROT_DEV_COHERENT_REMOTE;
1226 	} else if (is_zero_pfn(page_to_pfn(page)))
1227 		*perm = HMM_DMIRROR_PROT_ZERO;
1228 	else
1229 		*perm = HMM_DMIRROR_PROT_NONE;
1230 	if (entry & HMM_PFN_WRITE)
1231 		*perm |= HMM_DMIRROR_PROT_WRITE;
1232 	else
1233 		*perm |= HMM_DMIRROR_PROT_READ;
1234 	if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PMD_SHIFT)
1235 		*perm |= HMM_DMIRROR_PROT_PMD;
1236 	else if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PUD_SHIFT)
1237 		*perm |= HMM_DMIRROR_PROT_PUD;
1238 }
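/*
 * Illustrative examples of the snapshot byte encoding built above (the values
 * are OR-ed together exactly as in dmirror_mkentry()):
 *
 *	writable, PMD-mapped system page:
 *		HMM_DMIRROR_PROT_NONE | HMM_DMIRROR_PROT_WRITE |
 *		HMM_DMIRROR_PROT_PMD
 *	read-only page resident in this device's private memory:
 *		HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL | HMM_DMIRROR_PROT_READ
 */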
1239 
1240 static bool dmirror_snapshot_invalidate(struct mmu_interval_notifier *mni,
1241 				const struct mmu_notifier_range *range,
1242 				unsigned long cur_seq)
1243 {
1244 	struct dmirror_interval *dmi =
1245 		container_of(mni, struct dmirror_interval, notifier);
1246 	struct dmirror *dmirror = dmi->dmirror;
1247 
1248 	if (mmu_notifier_range_blockable(range))
1249 		mutex_lock(&dmirror->mutex);
1250 	else if (!mutex_trylock(&dmirror->mutex))
1251 		return false;
1252 
1253 	/*
1254 	 * Snapshots only need to set the sequence number since any
1255 	 * invalidation in the interval invalidates the whole snapshot.
1256 	 */
1257 	mmu_interval_set_seq(mni, cur_seq);
1258 
1259 	mutex_unlock(&dmirror->mutex);
1260 	return true;
1261 }
1262 
1263 static const struct mmu_interval_notifier_ops dmirror_mrn_ops = {
1264 	.invalidate = dmirror_snapshot_invalidate,
1265 };
1266 
1267 static int dmirror_range_snapshot(struct dmirror *dmirror,
1268 				  struct hmm_range *range,
1269 				  unsigned char *perm)
1270 {
1271 	struct mm_struct *mm = dmirror->notifier.mm;
1272 	struct dmirror_interval notifier;
1273 	unsigned long timeout =
1274 		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
1275 	unsigned long i;
1276 	unsigned long n;
1277 	int ret = 0;
1278 
1279 	notifier.dmirror = dmirror;
1280 	range->notifier = &notifier.notifier;
1281 
1282 	ret = mmu_interval_notifier_insert(range->notifier, mm,
1283 			range->start, range->end - range->start,
1284 			&dmirror_mrn_ops);
1285 	if (ret)
1286 		return ret;
1287 
1288 	while (true) {
1289 		if (time_after(jiffies, timeout)) {
1290 			ret = -EBUSY;
1291 			goto out;
1292 		}
1293 
1294 		range->notifier_seq = mmu_interval_read_begin(range->notifier);
1295 
1296 		mmap_read_lock(mm);
1297 		ret = hmm_range_fault(range);
1298 		mmap_read_unlock(mm);
1299 		if (ret) {
1300 			if (ret == -EBUSY)
1301 				continue;
1302 			goto out;
1303 		}
1304 
1305 		mutex_lock(&dmirror->mutex);
1306 		if (mmu_interval_read_retry(range->notifier,
1307 					    range->notifier_seq)) {
1308 			mutex_unlock(&dmirror->mutex);
1309 			continue;
1310 		}
1311 		break;
1312 	}
1313 
1314 	n = (range->end - range->start) >> PAGE_SHIFT;
1315 	for (i = 0; i < n; i++)
1316 		dmirror_mkentry(dmirror, range, perm + i, range->hmm_pfns[i]);
1317 
1318 	mutex_unlock(&dmirror->mutex);
1319 out:
1320 	mmu_interval_notifier_remove(range->notifier);
1321 	return ret;
1322 }
1323 
1324 static int dmirror_snapshot(struct dmirror *dmirror,
1325 			    struct hmm_dmirror_cmd *cmd)
1326 {
1327 	struct mm_struct *mm = dmirror->notifier.mm;
1328 	unsigned long start, end;
1329 	unsigned long size = cmd->npages << PAGE_SHIFT;
1330 	unsigned long addr;
1331 	unsigned long next;
1332 	unsigned long pfns[32];
1333 	unsigned char perm[32];
1334 	char __user *uptr;
1335 	struct hmm_range range = {
1336 		.hmm_pfns = pfns,
1337 		.dev_private_owner = dmirror->mdevice,
1338 	};
1339 	int ret = 0;
1340 
1341 	start = cmd->addr;
1342 	end = start + size;
1343 	if (end < start)
1344 		return -EINVAL;
1345 
1346 	/* Since the mm is for the mirrored process, get a reference first. */
1347 	if (!mmget_not_zero(mm))
1348 		return -EINVAL;
1349 
1350 	/*
1351 	 * Register a temporary notifier to detect invalidations even if it
1352 	 * overlaps with other mmu_interval_notifiers.
1353 	 */
1354 	uptr = u64_to_user_ptr(cmd->ptr);
1355 	for (addr = start; addr < end; addr = next) {
1356 		unsigned long n;
1357 
1358 		next = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
1359 		range.start = addr;
1360 		range.end = next;
1361 
1362 		ret = dmirror_range_snapshot(dmirror, &range, perm);
1363 		if (ret)
1364 			break;
1365 
1366 		n = (range.end - range.start) >> PAGE_SHIFT;
1367 		if (copy_to_user(uptr, perm, n)) {
1368 			ret = -EFAULT;
1369 			break;
1370 		}
1371 
1372 		cmd->cpages += n;
1373 		uptr += n;
1374 	}
1375 	mmput(mm);
1376 
1377 	return ret;
1378 }
1379 
1380 static void dmirror_device_evict_chunk(struct dmirror_chunk *chunk)
1381 {
1382 	unsigned long start_pfn = chunk->pagemap.range.start >> PAGE_SHIFT;
1383 	unsigned long end_pfn = chunk->pagemap.range.end >> PAGE_SHIFT;
1384 	unsigned long npages = end_pfn - start_pfn + 1;
1385 	unsigned long i;
1386 	unsigned long *src_pfns;
1387 	unsigned long *dst_pfns;
1388 	unsigned int order = 0;
1389 
1390 	src_pfns = kvcalloc(npages, sizeof(*src_pfns), GFP_KERNEL | __GFP_NOFAIL);
1391 	dst_pfns = kvcalloc(npages, sizeof(*dst_pfns), GFP_KERNEL | __GFP_NOFAIL);
1392 
1393 	migrate_device_range(src_pfns, start_pfn, npages);
1394 	for (i = 0; i < npages; i++) {
1395 		struct page *dpage, *spage;
1396 
1397 		spage = migrate_pfn_to_page(src_pfns[i]);
1398 		if (!spage || !(src_pfns[i] & MIGRATE_PFN_MIGRATE))
1399 			continue;
1400 
1401 		if (WARN_ON(!is_device_private_page(spage) &&
1402 			    !is_device_coherent_page(spage)))
1403 			continue;
1404 
1405 		order = folio_order(page_folio(spage));
1406 		spage = BACKING_PAGE(spage);
1407 		if (src_pfns[i] & MIGRATE_PFN_COMPOUND) {
1408 			dpage = folio_page(folio_alloc(GFP_HIGHUSER_MOVABLE,
1409 					      order), 0);
1410 		} else {
1411 			dpage = alloc_page(GFP_HIGHUSER_MOVABLE | __GFP_NOFAIL);
1412 			order = 0;
1413 		}
1414 
1415 		/* TODO Support splitting here */
1416 		lock_page(dpage);
1417 		dst_pfns[i] = migrate_pfn(page_to_pfn(dpage));
1418 		if (src_pfns[i] & MIGRATE_PFN_WRITE)
1419 			dst_pfns[i] |= MIGRATE_PFN_WRITE;
1420 		if (order)
1421 			dst_pfns[i] |= MIGRATE_PFN_COMPOUND;
1422 		folio_copy(page_folio(dpage), page_folio(spage));
1423 	}
1424 	migrate_device_pages(src_pfns, dst_pfns, npages);
1425 	migrate_device_finalize(src_pfns, dst_pfns, npages);
1426 	kvfree(src_pfns);
1427 	kvfree(dst_pfns);
1428 }
1429 
1430 /* Removes free pages from the free list so they can't be re-allocated */
1431 static void dmirror_remove_free_pages(struct dmirror_chunk *devmem)
1432 {
1433 	struct dmirror_device *mdevice = devmem->mdevice;
1434 	struct page *page;
1435 	struct folio *folio;
1436 
1437 
1438 	for (folio = mdevice->free_folios; folio; folio = folio_zone_device_data(folio))
1439 		if (dmirror_page_to_chunk(folio_page(folio, 0)) == devmem)
1440 			mdevice->free_folios = folio_zone_device_data(folio);
1441 	for (page = mdevice->free_pages; page; page = page->zone_device_data)
1442 		if (dmirror_page_to_chunk(page) == devmem)
1443 			mdevice->free_pages = page->zone_device_data;
1444 }
1445 
1446 static void dmirror_device_remove_chunks(struct dmirror_device *mdevice)
1447 {
1448 	unsigned int i;
1449 
1450 	mutex_lock(&mdevice->devmem_lock);
1451 	if (mdevice->devmem_chunks) {
1452 		for (i = 0; i < mdevice->devmem_count; i++) {
1453 			struct dmirror_chunk *devmem =
1454 				mdevice->devmem_chunks[i];
1455 
1456 			spin_lock(&mdevice->lock);
1457 			devmem->remove = true;
1458 			dmirror_remove_free_pages(devmem);
1459 			spin_unlock(&mdevice->lock);
1460 
1461 			dmirror_device_evict_chunk(devmem);
1462 			memunmap_pages(&devmem->pagemap);
1463 			if (devmem->pagemap.type == MEMORY_DEVICE_PRIVATE)
1464 				release_mem_region(devmem->pagemap.range.start,
1465 						   range_len(&devmem->pagemap.range));
1466 			kfree(devmem);
1467 		}
1468 		mdevice->devmem_count = 0;
1469 		mdevice->devmem_capacity = 0;
1470 		mdevice->free_pages = NULL;
1471 		mdevice->free_folios = NULL;
1472 		kfree(mdevice->devmem_chunks);
1473 		mdevice->devmem_chunks = NULL;
1474 	}
1475 	mutex_unlock(&mdevice->devmem_lock);
1476 }
1477 
1478 static long dmirror_fops_unlocked_ioctl(struct file *filp,
1479 					unsigned int command,
1480 					unsigned long arg)
1481 {
1482 	void __user *uarg = (void __user *)arg;
1483 	struct hmm_dmirror_cmd cmd;
1484 	struct dmirror *dmirror;
1485 	int ret;
1486 
1487 	dmirror = filp->private_data;
1488 	if (!dmirror)
1489 		return -EINVAL;
1490 
1491 	if (copy_from_user(&cmd, uarg, sizeof(cmd)))
1492 		return -EFAULT;
1493 
1494 	if (cmd.addr & ~PAGE_MASK)
1495 		return -EINVAL;
1496 	if (cmd.addr >= (cmd.addr + (cmd.npages << PAGE_SHIFT)))
1497 		return -EINVAL;
1498 
1499 	cmd.cpages = 0;
1500 	cmd.faults = 0;
1501 
1502 	switch (command) {
1503 	case HMM_DMIRROR_READ:
1504 		ret = dmirror_read(dmirror, &cmd);
1505 		break;
1506 
1507 	case HMM_DMIRROR_WRITE:
1508 		ret = dmirror_write(dmirror, &cmd);
1509 		break;
1510 
1511 	case HMM_DMIRROR_MIGRATE_TO_DEV:
1512 		ret = dmirror_migrate_to_device(dmirror, &cmd);
1513 		break;
1514 
1515 	case HMM_DMIRROR_MIGRATE_TO_SYS:
1516 		ret = dmirror_migrate_to_system(dmirror, &cmd);
1517 		break;
1518 
1519 	case HMM_DMIRROR_EXCLUSIVE:
1520 		ret = dmirror_exclusive(dmirror, &cmd);
1521 		break;
1522 
1523 	case HMM_DMIRROR_CHECK_EXCLUSIVE:
1524 		ret = dmirror_check_atomic(dmirror, cmd.addr,
1525 					cmd.addr + (cmd.npages << PAGE_SHIFT));
1526 		break;
1527 
1528 	case HMM_DMIRROR_SNAPSHOT:
1529 		ret = dmirror_snapshot(dmirror, &cmd);
1530 		break;
1531 
1532 	case HMM_DMIRROR_RELEASE:
1533 		dmirror_device_remove_chunks(dmirror->mdevice);
1534 		ret = 0;
1535 		break;
1536 	case HMM_DMIRROR_FLAGS:
1537 		dmirror->flags = cmd.npages;
1538 		ret = 0;
1539 		break;
1540 
1541 	default:
1542 		return -EINVAL;
1543 	}
1544 	if (ret)
1545 		return ret;
1546 
1547 	if (copy_to_user(uarg, &cmd, sizeof(cmd)))
1548 		return -EFAULT;
1549 
1550 	return 0;
1551 }
1552 
1553 static int dmirror_fops_mmap(struct file *file, struct vm_area_struct *vma)
1554 {
1555 	unsigned long addr;
1556 
1557 	for (addr = vma->vm_start; addr < vma->vm_end; addr += PAGE_SIZE) {
1558 		struct page *page;
1559 		int ret;
1560 
1561 		page = alloc_page(GFP_KERNEL | __GFP_ZERO);
1562 		if (!page)
1563 			return -ENOMEM;
1564 
1565 		ret = vm_insert_page(vma, addr, page);
1566 		if (ret) {
1567 			__free_page(page);
1568 			return ret;
1569 		}
1570 		put_page(page);
1571 	}
1572 
1573 	return 0;
1574 }
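/*
 * Usage sketch (userspace side, names are placeholders): the mmap() handler
 * above backs the VMA with freshly zeroed kernel-allocated pages, so a test
 * can map the device and hand the resulting buffer to the ioctls above:
 *
 *	buf = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
 */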
1575 
1576 static const struct file_operations dmirror_fops = {
1577 	.open		= dmirror_fops_open,
1578 	.release	= dmirror_fops_release,
1579 	.mmap		= dmirror_fops_mmap,
1580 	.unlocked_ioctl = dmirror_fops_unlocked_ioctl,
1581 	.llseek		= default_llseek,
1582 	.owner		= THIS_MODULE,
1583 };
1584 
1585 static void dmirror_devmem_free(struct folio *folio)
1586 {
1587 	struct page *page = &folio->page;
1588 	struct page *rpage = BACKING_PAGE(page);
1589 	struct dmirror_device *mdevice;
1590 	struct folio *rfolio = page_folio(rpage);
1591 	unsigned int order = folio_order(rfolio);
1592 
1593 	if (rpage != page) {
1594 		if (order)
1595 			__free_pages(rpage, order);
1596 		else
1597 			__free_page(rpage);
1598 		rpage = NULL;
1599 	}
1600 
1601 	mdevice = dmirror_page_to_device(page);
1602 	spin_lock(&mdevice->lock);
1603 
1604 	/* Return page to our allocator if not freeing the chunk */
1605 	if (!dmirror_page_to_chunk(page)->remove) {
1606 		mdevice->cfree += 1 << order;
1607 		if (order) {
1608 			page->zone_device_data = mdevice->free_folios;
1609 			mdevice->free_folios = page_folio(page);
1610 		} else {
1611 			page->zone_device_data = mdevice->free_pages;
1612 			mdevice->free_pages = page;
1613 		}
1614 	}
1615 	spin_unlock(&mdevice->lock);
1616 }
1617 
1618 static vm_fault_t dmirror_devmem_fault(struct vm_fault *vmf)
1619 {
1620 	struct migrate_vma args = { 0 };
1621 	struct page *rpage;
1622 	struct dmirror *dmirror;
1623 	vm_fault_t ret = 0;
1624 	unsigned int order, nr;
1625 
1626 	/*
1627 	 * Normally, a device would use the page->zone_device_data to point to
1628 	 * the mirror but here we use it to hold the page for the simulated
1629 	 * device memory and that page holds the pointer to the mirror.
1630 	 */
1631 	rpage = folio_zone_device_data(page_folio(vmf->page));
1632 	dmirror = rpage->zone_device_data;
1633 
1634 	/* FIXME demonstrate how we can adjust migrate range */
1635 	order = folio_order(page_folio(vmf->page));
1636 	nr = 1 << order;
1637 
1638 	/*
1639 	 * When folios are partially mapped, we can't rely on the folio
1640 	 * order of vmf->page as the folio might not be fully split yet
1641 	 */
1642 	if (vmf->pte) {
1643 		order = 0;
1644 		nr = 1;
1645 	}
1646 
1647 	/*
1648 	 * Consider a per-cpu cache of src and dst pfns, but with
1649 	 * large number of cpus that might not scale well.
1650 	 */
1651 	args.start = ALIGN_DOWN(vmf->address, (PAGE_SIZE << order));
1652 	args.vma = vmf->vma;
1653 	args.end = args.start + (PAGE_SIZE << order);
1654 
1655 	nr = (args.end - args.start) >> PAGE_SHIFT;
1656 	args.src = kcalloc(nr, sizeof(unsigned long), GFP_KERNEL);
1657 	args.dst = kcalloc(nr, sizeof(unsigned long), GFP_KERNEL);
1658 	args.pgmap_owner = dmirror->mdevice;
1659 	args.flags = dmirror_select_device(dmirror);
1660 	args.fault_page = vmf->page;
1661 
1662 	if (!args.src || !args.dst) {
1663 		ret = VM_FAULT_OOM;
1664 		goto err;
1665 	}
1666 
1667 	if (order)
1668 		args.flags |= MIGRATE_VMA_SELECT_COMPOUND;
1669 
1670 	if (migrate_vma_setup(&args))
1671 		return VM_FAULT_SIGBUS;
1672 
1673 	ret = dmirror_devmem_fault_alloc_and_copy(&args, dmirror);
1674 	if (ret)
1675 		goto err;
1676 	migrate_vma_pages(&args);
1677 	/*
1678 	 * No device finalize step is needed since
1679 	 * dmirror_devmem_fault_alloc_and_copy() will have already
1680 	 * invalidated the device page table.
1681 	 */
1682 	migrate_vma_finalize(&args);
1683 err:
1684 	kfree(args.src);
1685 	kfree(args.dst);
1686 	return ret;
1687 }
1688 
1689 static void dmirror_devmem_folio_split(struct folio *head, struct folio *tail)
1690 {
1691 	struct page *rpage = BACKING_PAGE(folio_page(head, 0));
1692 	struct page *rpage_tail;
1693 	struct folio *rfolio;
1694 	unsigned long offset = 0;
1695 
1696 	if (!rpage) {
1697 		tail->page.zone_device_data = NULL;
1698 		return;
1699 	}
1700 
1701 	rfolio = page_folio(rpage);
1702 
1703 	if (tail == NULL) {
1704 		folio_reset_order(rfolio);
1705 		rfolio->mapping = NULL;
1706 		folio_set_count(rfolio, 1);
1707 		return;
1708 	}
1709 
1710 	offset = folio_pfn(tail) - folio_pfn(head);
1711 
1712 	rpage_tail = folio_page(rfolio, offset);
1713 	tail->page.zone_device_data = rpage_tail;
1714 	rpage_tail->zone_device_data = rpage->zone_device_data;
1715 	clear_compound_head(rpage_tail);
1716 	rpage_tail->mapping = NULL;
1717 
1718 	folio_page(tail, 0)->mapping = folio_page(head, 0)->mapping;
1719 	tail->pgmap = head->pgmap;
1720 	folio_set_count(page_folio(rpage_tail), 1);
1721 }
1722 
1723 static const struct dev_pagemap_ops dmirror_devmem_ops = {
1724 	.folio_free	= dmirror_devmem_free,
1725 	.migrate_to_ram	= dmirror_devmem_fault,
1726 	.folio_split	= dmirror_devmem_folio_split,
1727 };
1728 
1729 static int dmirror_device_init(struct dmirror_device *mdevice, int id)
1730 {
1731 	dev_t dev;
1732 	int ret;
1733 
1734 	dev = MKDEV(MAJOR(dmirror_dev), id);
1735 	mutex_init(&mdevice->devmem_lock);
1736 	spin_lock_init(&mdevice->lock);
1737 
1738 	cdev_init(&mdevice->cdevice, &dmirror_fops);
1739 	mdevice->cdevice.owner = THIS_MODULE;
1740 	device_initialize(&mdevice->device);
1741 	mdevice->device.devt = dev;
1742 
1743 	ret = dev_set_name(&mdevice->device, "hmm_dmirror%u", id);
1744 	if (ret)
1745 		goto put_device;
1746 
1747 	ret = cdev_device_add(&mdevice->cdevice, &mdevice->device);
1748 	if (ret)
1749 		goto put_device;
1750 
1751 	/* Build a list of free ZONE_DEVICE struct pages */
1752 	return dmirror_allocate_chunk(mdevice, NULL, false);
1753 
1754 put_device:
1755 	put_device(&mdevice->device);
1756 	return ret;
1757 }
1758 
1759 static void dmirror_device_remove(struct dmirror_device *mdevice)
1760 {
1761 	dmirror_device_remove_chunks(mdevice);
1762 	cdev_device_del(&mdevice->cdevice, &mdevice->device);
1763 	put_device(&mdevice->device);
1764 }
1765 
1766 static int __init hmm_dmirror_init(void)
1767 {
1768 	int ret;
1769 	int id = 0;
1770 	int ndevices = 0;
1771 
1772 	ret = alloc_chrdev_region(&dmirror_dev, 0, DMIRROR_NDEVICES,
1773 				  "HMM_DMIRROR");
1774 	if (ret)
1775 		goto err_unreg;
1776 
1777 	memset(dmirror_devices, 0, DMIRROR_NDEVICES * sizeof(dmirror_devices[0]));
1778 	dmirror_devices[ndevices++].zone_device_type =
1779 				HMM_DMIRROR_MEMORY_DEVICE_PRIVATE;
1780 	dmirror_devices[ndevices++].zone_device_type =
1781 				HMM_DMIRROR_MEMORY_DEVICE_PRIVATE;
1782 	if (spm_addr_dev0 && spm_addr_dev1) {
1783 		dmirror_devices[ndevices++].zone_device_type =
1784 					HMM_DMIRROR_MEMORY_DEVICE_COHERENT;
1785 		dmirror_devices[ndevices++].zone_device_type =
1786 					HMM_DMIRROR_MEMORY_DEVICE_COHERENT;
1787 	}
1788 	for (id = 0; id < ndevices; id++) {
1789 		ret = dmirror_device_init(dmirror_devices + id, id);
1790 		if (ret)
1791 			goto err_chrdev;
1792 	}
1793 
1794 	pr_info("HMM test module loaded. This is only for testing HMM.\n");
1795 	return 0;
1796 
1797 err_chrdev:
1798 	while (--id >= 0)
1799 		dmirror_device_remove(dmirror_devices + id);
1800 	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
1801 err_unreg:
1802 	return ret;
1803 }
1804 
1805 static void __exit hmm_dmirror_exit(void)
1806 {
1807 	int id;
1808 
1809 	for (id = 0; id < DMIRROR_NDEVICES; id++)
1810 		if (dmirror_devices[id].zone_device_type)
1811 			dmirror_device_remove(dmirror_devices + id);
1812 	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
1813 }
1814 
1815 module_init(hmm_dmirror_init);
1816 module_exit(hmm_dmirror_exit);
1817 MODULE_DESCRIPTION("HMM (Heterogeneous Memory Management) test module");
1818 MODULE_LICENSE("GPL");
1819