xref: /linux/tools/testing/selftests/x86/lam.c (revision 74f1af95820fc2ee580a775a3a17c416db30b38c)
1 // SPDX-License-Identifier: GPL-2.0
2 #define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <sched.h>
#include <setjmp.h>
#include <signal.h>
#include <time.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <sys/utsname.h>
#include <sys/wait.h>

#include <sys/uio.h>
#include <linux/io_uring.h>
#include "../kselftest.h"
22 
23 #ifndef __x86_64__
24 # error This test is 64-bit only
25 #endif
26 
/* LAM modes, these definitions were copied from kernel code */
#define LAM_NONE                0
#define LAM_U57_BITS            6

/* With LAM_U57 the metadata (tag) lives in pointer bits 62:57 */
#define LAM_U57_MASK            (0x3fULL << 57)
/* arch prctl for LAM */
#define ARCH_GET_UNTAG_MASK     0x4001
#define ARCH_ENABLE_TAGGED_ADDR 0x4002
#define ARCH_GET_MAX_TAG_BITS   0x4003
#define ARCH_FORCE_TAGGED_SVA	0x4004

/* Specified test function bits */
#define FUNC_MALLOC             0x1
#define FUNC_BITS               0x2
#define FUNC_MMAP               0x4
#define FUNC_SYSCALL            0x8
#define FUNC_URING              0x10
#define FUNC_INHERITE           0x20
#define FUNC_PASID              0x40

/* get_user() pointer test cases */
#define GET_USER_USER           0
#define GET_USER_KERNEL_TOP     1
#define GET_USER_KERNEL_BOT     2
#define GET_USER_KERNEL         3

/* All FUNC_* bits ORed together */
#define TEST_MASK               0x7f
/* Sign-extension bits of a kernel address under 5- resp. 4-level paging */
#define L5_SIGN_EXT_MASK        (0xFFUL << 56)
#define L4_SIGN_EXT_MASK        (0x1FFFFUL << 47)

#define LOW_ADDR                (0x1UL << 30)
/* Above the 4-level user limit: only mappable when LA57 is enabled */
#define HIGH_ADDR               (0x3UL << 48)

#define MALLOC_LEN              32

#define PAGE_SIZE               (4 << 10)

#define STACK_SIZE		65536

/* Compiler barrier used around io_uring ring head/tail updates */
#define barrier() ({						\
		   __asm__ __volatile__("" : : : "memory");	\
})

#define URING_QUEUE_SZ 1
#define URING_BLOCK_SZ 2048

/* Pasid test define */
#define LAM_CMD_BIT 0x1
#define PAS_CMD_BIT 0x2
#define SVA_CMD_BIT 0x4

/* Pack three command nibbles (executed low nibble first) into one value */
#define PAS_CMD(cmd1, cmd2, cmd3) (((cmd3) << 8) | ((cmd2) << 4) | ((cmd1) << 0))
79 
/* One selftest case; executed in a forked child via run_test()/fork_test() */
struct testcases {
	unsigned int later;	/* 0: enable LAM before the op, else after; doubles as GET_USER_* selector */
	int expected; /* 2: SIGSEGV Error; 1: other errors */
	unsigned long lam;	/* LAM mode to request (LAM_U57_BITS or 0) */
	uint64_t addr;	/* mmap target address, mmap tests only */
	uint64_t cmd;	/* PAS_CMD() encoded sequence, pasid tests only */
	int (*test_func)(struct testcases *test);
	const char *msg;	/* TAP result description */
};
89 
/* Used by CQ of uring, source file handler and file's size */
struct file_io {
	int file_fd;
	off_t file_sz;
	/* one entry per URING_BLOCK_SZ-sized block of the file */
	struct iovec iovecs[];
};
96 
/* Userspace view of one io_uring ring: pointers into the mmap'ed area */
struct io_uring_queue {
	unsigned int *head;
	unsigned int *tail;
	unsigned int *ring_mask;
	unsigned int *ring_entries;
	unsigned int *flags;
	unsigned int *array;
	/* SQ uses .sqes, CQ uses .cqes */
	union {
		struct io_uring_cqe *cqes;
		struct io_uring_sqe *sqes;
	} queue;
	size_t ring_sz;
};
110 
/* An io_uring instance: its fd plus submission and completion rings */
struct io_ring {
	int ring_fd;
	struct io_uring_queue sq_ring;
	struct io_uring_queue cq_ring;
};
116 
/* Number of cases queued by run_test(); feeds ksft_set_plan() in main() */
int tests_cnt;
/* Jump target armed by sigsetjmp() in each test; taken by segv_handler() */
jmp_buf segv_env;
119 
/*
 * SIGSEGV handler: report the fault and jump back to the sigsetjmp()
 * point armed by the current test (which then records result 2).
 */
static void segv_handler(int sig)
{
	ksft_print_msg("Get segmentation fault(%d).", sig);

	siglongjmp(segv_env, 1);
}
126 
127 static inline int lam_is_available(void)
128 {
129 	unsigned int cpuinfo[4];
130 	unsigned long bits = 0;
131 	int ret;
132 
133 	__cpuid_count(0x7, 1, cpuinfo[0], cpuinfo[1], cpuinfo[2], cpuinfo[3]);
134 
135 	/* Check if cpu supports LAM */
136 	if (!(cpuinfo[0] & (1 << 26))) {
137 		ksft_print_msg("LAM is not supported!\n");
138 		return 0;
139 	}
140 
141 	/* Return 0 if CONFIG_ADDRESS_MASKING is not set */
142 	ret = syscall(SYS_arch_prctl, ARCH_GET_MAX_TAG_BITS, &bits);
143 	if (ret) {
144 		ksft_print_msg("LAM is disabled in the kernel!\n");
145 		return 0;
146 	}
147 
148 	return 1;
149 }
150 
151 static inline int la57_enabled(void)
152 {
153 	int ret;
154 	void *p;
155 
156 	p = mmap((void *)HIGH_ADDR, PAGE_SIZE, PROT_READ | PROT_WRITE,
157 		 MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0);
158 
159 	ret = p == MAP_FAILED ? 0 : 1;
160 
161 	munmap(p, PAGE_SIZE);
162 	return ret;
163 }
164 
165 /*
166  * Set tagged address and read back untag mask.
167  * check if the untagged mask is expected.
168  *
169  * @return:
170  * 0: Set LAM mode successfully
171  * others: failed to set LAM
172  */
173 static int set_lam(unsigned long lam)
174 {
175 	int ret = 0;
176 	uint64_t ptr = 0;
177 
178 	if (lam != LAM_U57_BITS && lam != LAM_NONE)
179 		return -1;
180 
181 	/* Skip check return */
182 	syscall(SYS_arch_prctl, ARCH_ENABLE_TAGGED_ADDR, lam);
183 
184 	/* Get untagged mask */
185 	syscall(SYS_arch_prctl, ARCH_GET_UNTAG_MASK, &ptr);
186 
187 	/* Check mask returned is expected */
188 	if (lam == LAM_U57_BITS)
189 		ret = (ptr != ~(LAM_U57_MASK));
190 	else if (lam == LAM_NONE)
191 		ret = (ptr != -1ULL);
192 
193 	return ret;
194 }
195 
196 static unsigned long get_default_tag_bits(void)
197 {
198 	pid_t pid;
199 	int lam = LAM_NONE;
200 	int ret = 0;
201 
202 	pid = fork();
203 	if (pid < 0) {
204 		perror("Fork failed.");
205 	} else if (pid == 0) {
206 		/* Set LAM mode in child process */
207 		if (set_lam(LAM_U57_BITS) == 0)
208 			lam = LAM_U57_BITS;
209 		else
210 			lam = LAM_NONE;
211 		exit(lam);
212 	} else {
213 		wait(&ret);
214 		lam = WEXITSTATUS(ret);
215 	}
216 
217 	return lam;
218 }
219 
220 /*
221  * Set tagged address and read back untag mask.
222  * check if the untag mask is expected.
223  */
224 static int get_lam(void)
225 {
226 	uint64_t ptr = 0;
227 	int ret = -1;
228 	/* Get untagged mask */
229 	if (syscall(SYS_arch_prctl, ARCH_GET_UNTAG_MASK, &ptr) == -1)
230 		return -1;
231 
232 	/* Check mask returned is expected */
233 	if (ptr == ~(LAM_U57_MASK))
234 		ret = LAM_U57_BITS;
235 	else if (ptr == -1ULL)
236 		ret = LAM_NONE;
237 
238 
239 	return ret;
240 }
241 
242 /* According to LAM mode, set metadata in high bits */
243 static uint64_t set_metadata(uint64_t src, unsigned long lam)
244 {
245 	uint64_t metadata;
246 
247 	srand(time(NULL));
248 
249 	switch (lam) {
250 	case LAM_U57_BITS: /* Set metadata in bits 62:57 */
251 		/* Get a random non-zero value as metadata */
252 		metadata = (rand() % ((1UL << LAM_U57_BITS) - 1) + 1) << 57;
253 		metadata |= (src & ~(LAM_U57_MASK));
254 		break;
255 	default:
256 		metadata = src;
257 		break;
258 	}
259 
260 	return metadata;
261 }
262 
/*
 * Write through the plain pointer, then write through a metadata-tagged
 * alias of it: with LAM active both must address the same memory.
 *
 * @return:
 * 0: tagged and untagged views agree (or tagging was a no-op)
 * 1: the two strings differ
 */
static int handle_lam_test(void *src, unsigned int lam)
{
	char *tagged;

	strcpy((char *)src, "USER POINTER");

	tagged = (char *)set_metadata((uint64_t)src, lam);
	if (tagged == (char *)src)
		return 0;

	/* Copy a string into the pointer with metadata */
	strcpy(tagged, "METADATA POINTER");

	return !!strcmp((char *)src, tagged);
}
286 
287 
288 int handle_max_bits(struct testcases *test)
289 {
290 	unsigned long exp_bits = get_default_tag_bits();
291 	unsigned long bits = 0;
292 
293 	if (exp_bits != LAM_NONE)
294 		exp_bits = LAM_U57_BITS;
295 
296 	/* Get LAM max tag bits */
297 	if (syscall(SYS_arch_prctl, ARCH_GET_MAX_TAG_BITS, &bits) == -1)
298 		return 1;
299 
300 	return (exp_bits != bits);
301 }
302 
/*
 * Test lam feature through dereference pointer get from malloc.
 * @return 0: Pass test. 1: Get failure during test 2: Get SIGSEGV
 */
static int handle_malloc(struct testcases *test)
{
	char *ptr = NULL;
	int ret = 0;

	/* "later == 0": enable LAM before dereferencing the tagged pointer */
	if (test->later == 0 && test->lam != 0)
		if (set_lam(test->lam) == -1)
			return 1;

	ptr = (char *)malloc(MALLOC_LEN);
	if (ptr == NULL) {
		perror("malloc() failure\n");
		return 1;
	}

	/* Set signal handler */
	if (sigsetjmp(segv_env, 1) == 0) {
		signal(SIGSEGV, segv_handler);
		ret = handle_lam_test(ptr, test->lam);
	} else {
		/* Landed here via siglongjmp(): the dereference faulted */
		ret = 2;
	}

	/*
	 * "later != 0": LAM was enabled only after the (expected-to-fault)
	 * dereference above; record an error if enabling itself failed.
	 */
	if (test->later != 0 && test->lam != 0)
		if (set_lam(test->lam) == -1 && ret == 0)
			ret = 1;

	free(ptr);

	return ret;
}
338 
/*
 * MMAP test: map test->addr, enable LAM before or after the mapping
 * depending on test->later, then dereference through a tagged pointer.
 * @return 0: pass; 1: failure; 2: SIGSEGV; 3: LA57 unsupported (skip)
 */
static int handle_mmap(struct testcases *test)
{
	void *ptr;
	unsigned int flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED;
	int ret = 0;

	if (test->later == 0 && test->lam != 0)
		if (set_lam(test->lam) != 0)
			return 1;

	ptr = mmap((void *)test->addr, PAGE_SIZE, PROT_READ | PROT_WRITE,
		   flags, -1, 0);
	if (ptr == MAP_FAILED) {
		/* HIGH_ADDR is only mappable with 5-level paging */
		if (test->addr == HIGH_ADDR)
			if (!la57_enabled())
				return 3; /* unsupport LA57 */
		return 1;
	}

	if (test->later != 0 && test->lam != 0)
		if (set_lam(test->lam) != 0)
			ret = 1;

	if (ret == 0) {
		if (sigsetjmp(segv_env, 1) == 0) {
			signal(SIGSEGV, segv_handler);
			ret = handle_lam_test(ptr, test->lam);
		} else {
			/* Returned here via siglongjmp(): dereference faulted */
			ret = 2;
		}
	}

	munmap(ptr, PAGE_SIZE);
	return ret;
}
374 
/*
 * SYSCALL test: pass a metadata-tagged struct pointer to uname(2).
 * @return 0: pass; 1: failure; 2: SIGSEGV was raised
 */
static int handle_syscall(struct testcases *test)
{
	struct utsname unme, *pu;
	int ret = 0;

	if (test->later == 0 && test->lam != 0)
		if (set_lam(test->lam) != 0)
			return 1;

	if (sigsetjmp(segv_env, 1) == 0) {
		signal(SIGSEGV, segv_handler);
		pu = (struct utsname *)set_metadata((uint64_t)&unme, test->lam);
		ret = uname(pu);
		if (ret < 0)
			ret = 1;
	} else {
		ret = 2;
	}

	/*
	 * Negative case ("later"): the tagged syscall above ran without LAM
	 * enabled, so it should NOT have succeeded. If set_lam() accepts the
	 * mode (!= -1) yet the syscall passed anyway, report failure.
	 */
	if (test->later != 0 && test->lam != 0)
		if (set_lam(test->lam) != -1 && ret == 0)
			ret = 1;

	return ret;
}
400 
401 static int get_user_syscall(struct testcases *test)
402 {
403 	uint64_t ptr_address, bitmask;
404 	int fd, ret = 0;
405 	void *ptr;
406 
407 	if (la57_enabled()) {
408 		bitmask = L5_SIGN_EXT_MASK;
409 		ptr_address = HIGH_ADDR;
410 	} else {
411 		bitmask = L4_SIGN_EXT_MASK;
412 		ptr_address = LOW_ADDR;
413 	}
414 
415 	ptr = mmap((void *)ptr_address, PAGE_SIZE, PROT_READ | PROT_WRITE,
416 		   MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0);
417 
418 	if (ptr == MAP_FAILED) {
419 		perror("failed to map byte to pass into get_user");
420 		return 1;
421 	}
422 
423 	if (set_lam(test->lam) != 0) {
424 		ret = 2;
425 		goto error;
426 	}
427 
428 	fd = memfd_create("lam_ioctl", 0);
429 	if (fd == -1) {
430 		munmap(ptr, PAGE_SIZE);
431 		exit(EXIT_FAILURE);
432 	}
433 
434 	switch (test->later) {
435 	case GET_USER_USER:
436 		/* Control group - properly tagged user pointer */
437 		ptr = (void *)set_metadata((uint64_t)ptr, test->lam);
438 		break;
439 	case GET_USER_KERNEL_TOP:
440 		/* Kernel address with top bit cleared */
441 		bitmask &= (bitmask >> 1);
442 		ptr = (void *)((uint64_t)ptr | bitmask);
443 		break;
444 	case GET_USER_KERNEL_BOT:
445 		/* Kernel address with bottom sign-extension bit cleared */
446 		bitmask &= (bitmask << 1);
447 		ptr = (void *)((uint64_t)ptr | bitmask);
448 		break;
449 	case GET_USER_KERNEL:
450 		/* Try to pass a kernel address */
451 		ptr = (void *)((uint64_t)ptr | bitmask);
452 		break;
453 	default:
454 		printf("Invalid test case value passed!\n");
455 		break;
456 	}
457 
458 	/*
459 	 * Use FIOASYNC ioctl because it utilizes get_user() internally and is
460 	 * very non-invasive to the system. Pass differently tagged pointers to
461 	 * get_user() in order to verify that valid user pointers are going
462 	 * through and invalid kernel/non-canonical pointers are not.
463 	 */
464 	if (ioctl(fd, FIOASYNC, ptr) != 0)
465 		ret = 1;
466 
467 	close(fd);
468 error:
469 	munmap(ptr, PAGE_SIZE);
470 	return ret;
471 }
472 
/* Thin wrapper around the raw io_uring_setup(2) syscall */
int sys_uring_setup(unsigned int entries, struct io_uring_params *p)
{
	long rc = syscall(__NR_io_uring_setup, entries, p);

	return (int)rc;
}
477 
/* Thin wrapper around the raw io_uring_enter(2) syscall (no sigset) */
int sys_uring_enter(int fd, unsigned int to, unsigned int min, unsigned int flags)
{
	long rc = syscall(__NR_io_uring_enter, fd, to, min, flags, NULL, 0);

	return (int)rc;
}
482 
483 /* Init submission queue and completion queue */
484 int mmap_io_uring(struct io_uring_params p, struct io_ring *s)
485 {
486 	struct io_uring_queue *sring = &s->sq_ring;
487 	struct io_uring_queue *cring = &s->cq_ring;
488 
489 	sring->ring_sz = p.sq_off.array + p.sq_entries * sizeof(unsigned int);
490 	cring->ring_sz = p.cq_off.cqes + p.cq_entries * sizeof(struct io_uring_cqe);
491 
492 	if (p.features & IORING_FEAT_SINGLE_MMAP) {
493 		if (cring->ring_sz > sring->ring_sz)
494 			sring->ring_sz = cring->ring_sz;
495 
496 		cring->ring_sz = sring->ring_sz;
497 	}
498 
499 	void *sq_ptr = mmap(0, sring->ring_sz, PROT_READ | PROT_WRITE,
500 			    MAP_SHARED | MAP_POPULATE, s->ring_fd,
501 			    IORING_OFF_SQ_RING);
502 
503 	if (sq_ptr == MAP_FAILED) {
504 		perror("sub-queue!");
505 		return 1;
506 	}
507 
508 	void *cq_ptr = sq_ptr;
509 
510 	if (!(p.features & IORING_FEAT_SINGLE_MMAP)) {
511 		cq_ptr = mmap(0, cring->ring_sz, PROT_READ | PROT_WRITE,
512 			      MAP_SHARED | MAP_POPULATE, s->ring_fd,
513 			      IORING_OFF_CQ_RING);
514 		if (cq_ptr == MAP_FAILED) {
515 			perror("cpl-queue!");
516 			munmap(sq_ptr, sring->ring_sz);
517 			return 1;
518 		}
519 	}
520 
521 	sring->head = sq_ptr + p.sq_off.head;
522 	sring->tail = sq_ptr + p.sq_off.tail;
523 	sring->ring_mask = sq_ptr + p.sq_off.ring_mask;
524 	sring->ring_entries = sq_ptr + p.sq_off.ring_entries;
525 	sring->flags = sq_ptr + p.sq_off.flags;
526 	sring->array = sq_ptr + p.sq_off.array;
527 
528 	/* Map a queue as mem map */
529 	s->sq_ring.queue.sqes = mmap(0, p.sq_entries * sizeof(struct io_uring_sqe),
530 				     PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE,
531 				     s->ring_fd, IORING_OFF_SQES);
532 	if (s->sq_ring.queue.sqes == MAP_FAILED) {
533 		munmap(sq_ptr, sring->ring_sz);
534 		if (sq_ptr != cq_ptr) {
535 			ksft_print_msg("failed to mmap uring queue!");
536 			munmap(cq_ptr, cring->ring_sz);
537 			return 1;
538 		}
539 	}
540 
541 	cring->head = cq_ptr + p.cq_off.head;
542 	cring->tail = cq_ptr + p.cq_off.tail;
543 	cring->ring_mask = cq_ptr + p.cq_off.ring_mask;
544 	cring->ring_entries = cq_ptr + p.cq_off.ring_entries;
545 	cring->queue.cqes = cq_ptr + p.cq_off.cqes;
546 
547 	return 0;
548 }
549 
550 /* Init io_uring queues */
551 int setup_io_uring(struct io_ring *s)
552 {
553 	struct io_uring_params para;
554 
555 	memset(&para, 0, sizeof(para));
556 	s->ring_fd = sys_uring_setup(URING_QUEUE_SZ, &para);
557 	if (s->ring_fd < 0)
558 		return 1;
559 
560 	return mmap_io_uring(para, s);
561 }
562 
563 /*
564  * Get data from completion queue. the data buffer saved the file data
565  * return 0: success; others: error;
566  */
567 int handle_uring_cq(struct io_ring *s)
568 {
569 	struct file_io *fi = NULL;
570 	struct io_uring_queue *cring = &s->cq_ring;
571 	struct io_uring_cqe *cqe;
572 	unsigned int head;
573 	off_t len = 0;
574 
575 	head = *cring->head;
576 
577 	do {
578 		barrier();
579 		if (head == *cring->tail)
580 			break;
581 		/* Get the entry */
582 		cqe = &cring->queue.cqes[head & *s->cq_ring.ring_mask];
583 		fi = (struct file_io *)cqe->user_data;
584 		if (cqe->res < 0)
585 			break;
586 
587 		int blocks = (int)(fi->file_sz + URING_BLOCK_SZ - 1) / URING_BLOCK_SZ;
588 
589 		for (int i = 0; i < blocks; i++)
590 			len += fi->iovecs[i].iov_len;
591 
592 		head++;
593 	} while (1);
594 
595 	*cring->head = head;
596 	barrier();
597 
598 	return (len != fi->file_sz);
599 }
600 
/*
 * Submit squeue. specify via IORING_OP_READV.
 * the buffer need to be set metadata according to LAM mode
 *
 * Allocates one URING_BLOCK_SZ-aligned buffer per file block, tags each
 * buffer pointer via set_metadata(), queues a single READV SQE covering
 * all iovecs and waits for one completion.
 * Returns 0 on success, 1 on allocation/submission failure.
 */
int handle_uring_sq(struct io_ring *ring, struct file_io *fi, unsigned long lam)
{
	int file_fd = fi->file_fd;
	struct io_uring_queue *sring = &ring->sq_ring;
	unsigned int index = 0, cur_block = 0, tail = 0, next_tail = 0;
	struct io_uring_sqe *sqe;

	off_t remain = fi->file_sz;
	int blocks = (int)(remain + URING_BLOCK_SZ - 1) / URING_BLOCK_SZ;

	/* Carve the file into block-sized, metadata-tagged read buffers */
	while (remain) {
		off_t bytes = remain;
		void *buf;

		if (bytes > URING_BLOCK_SZ)
			bytes = URING_BLOCK_SZ;

		fi->iovecs[cur_block].iov_len = bytes;

		if (posix_memalign(&buf, URING_BLOCK_SZ, URING_BLOCK_SZ))
			return 1;

		/* The kernel must untag these pointers when LAM is active */
		fi->iovecs[cur_block].iov_base = (void *)set_metadata((uint64_t)buf, lam);
		remain -= bytes;
		cur_block++;
	}

	next_tail = *sring->tail;
	tail = next_tail;
	next_tail++;

	barrier();

	index = tail & *ring->sq_ring.ring_mask;

	sqe = &ring->sq_ring.queue.sqes[index];
	sqe->fd = file_fd;
	sqe->flags = 0;
	sqe->opcode = IORING_OP_READV;
	sqe->addr = (unsigned long)fi->iovecs;
	sqe->len = blocks;
	sqe->off = 0;
	sqe->user_data = (uint64_t)fi;

	sring->array[index] = index;
	tail = next_tail;

	/* Publish the new tail so the kernel sees the queued SQE */
	if (*sring->tail != tail) {
		*sring->tail = tail;
		barrier();
	}

	/* Block until at least one completion is available */
	if (sys_uring_enter(ring->ring_fd, 1, 1, IORING_ENTER_GETEVENTS) < 0)
		return 1;

	return 0;
}
662 
663 /*
664  * Test LAM in async I/O and io_uring, read current binery through io_uring
665  * Set metadata in pointers to iovecs buffer.
666  */
667 int do_uring(unsigned long lam)
668 {
669 	struct io_ring *ring;
670 	struct file_io *fi;
671 	struct stat st;
672 	int ret = 1;
673 	char path[PATH_MAX] = {0};
674 
675 	/* get current process path */
676 	if (readlink("/proc/self/exe", path, PATH_MAX - 1) <= 0)
677 		return 1;
678 
679 	int file_fd = open(path, O_RDONLY);
680 
681 	if (file_fd < 0)
682 		return 1;
683 
684 	if (fstat(file_fd, &st) < 0)
685 		goto cleanup;
686 
687 	off_t file_sz = st.st_size;
688 
689 	int blocks = (int)(file_sz + URING_BLOCK_SZ - 1) / URING_BLOCK_SZ;
690 
691 	fi = malloc(sizeof(*fi) + sizeof(struct iovec) * blocks);
692 	if (!fi)
693 		goto cleanup;
694 
695 	fi->file_sz = file_sz;
696 	fi->file_fd = file_fd;
697 
698 	ring = malloc(sizeof(*ring));
699 	if (!ring) {
700 		free(fi);
701 		goto cleanup;
702 	}
703 
704 	memset(ring, 0, sizeof(struct io_ring));
705 
706 	if (setup_io_uring(ring))
707 		goto out;
708 
709 	if (handle_uring_sq(ring, fi, lam))
710 		goto out;
711 
712 	ret = handle_uring_cq(ring);
713 
714 out:
715 	free(ring);
716 
717 	for (int i = 0; i < blocks; i++) {
718 		if (fi->iovecs[i].iov_base) {
719 			uint64_t addr = ((uint64_t)fi->iovecs[i].iov_base);
720 
721 			switch (lam) {
722 			case LAM_U57_BITS: /* Clear bits 62:57 */
723 				addr = (addr & ~(LAM_U57_MASK));
724 				break;
725 			}
726 			free((void *)addr);
727 			fi->iovecs[i].iov_base = NULL;
728 		}
729 	}
730 
731 	free(fi);
732 cleanup:
733 	close(file_fd);
734 
735 	return ret;
736 }
737 
738 int handle_uring(struct testcases *test)
739 {
740 	int ret = 0;
741 
742 	if (test->later == 0 && test->lam != 0)
743 		if (set_lam(test->lam) != 0)
744 			return 1;
745 
746 	if (sigsetjmp(segv_env, 1) == 0) {
747 		signal(SIGSEGV, segv_handler);
748 		ret = do_uring(test->lam);
749 	} else {
750 		ret = 2;
751 	}
752 
753 	return ret;
754 }
755 
756 static int fork_test(struct testcases *test)
757 {
758 	int ret, child_ret;
759 	pid_t pid;
760 
761 	pid = fork();
762 	if (pid < 0) {
763 		perror("Fork failed.");
764 		ret = 1;
765 	} else if (pid == 0) {
766 		ret = test->test_func(test);
767 		exit(ret);
768 	} else {
769 		wait(&child_ret);
770 		ret = WEXITSTATUS(child_ret);
771 	}
772 
773 	return ret;
774 }
775 
776 static int handle_execve(struct testcases *test)
777 {
778 	int ret, child_ret;
779 	int lam = test->lam;
780 	pid_t pid;
781 
782 	pid = fork();
783 	if (pid < 0) {
784 		perror("Fork failed.");
785 		ret = 1;
786 	} else if (pid == 0) {
787 		char path[PATH_MAX] = {0};
788 
789 		/* Set LAM mode in parent process */
790 		if (set_lam(lam) != 0)
791 			return 1;
792 
793 		/* Get current binary's path and the binary was run by execve */
794 		if (readlink("/proc/self/exe", path, PATH_MAX - 1) <= 0)
795 			exit(-1);
796 
797 		/* run binary to get LAM mode and return to parent process */
798 		if (execlp(path, path, "-t 0x0", NULL) < 0) {
799 			perror("error on exec");
800 			exit(-1);
801 		}
802 	} else {
803 		wait(&child_ret);
804 		ret = WEXITSTATUS(child_ret);
805 		if (ret != LAM_NONE)
806 			return 1;
807 	}
808 
809 	return 0;
810 }
811 
812 static int handle_inheritance(struct testcases *test)
813 {
814 	int ret, child_ret;
815 	int lam = test->lam;
816 	pid_t pid;
817 
818 	/* Set LAM mode in parent process */
819 	if (set_lam(lam) != 0)
820 		return 1;
821 
822 	pid = fork();
823 	if (pid < 0) {
824 		perror("Fork failed.");
825 		return 1;
826 	} else if (pid == 0) {
827 		/* Set LAM mode in parent process */
828 		int child_lam = get_lam();
829 
830 		exit(child_lam);
831 	} else {
832 		wait(&child_ret);
833 		ret = WEXITSTATUS(child_ret);
834 
835 		if (lam != ret)
836 			return 1;
837 	}
838 
839 	return 0;
840 }
841 
/* clone() entry point: report the thread's current LAM mode as exit code */
static int thread_fn_get_lam(void *arg)
{
	(void)arg;

	return get_lam();
}
846 
847 static int thread_fn_set_lam(void *arg)
848 {
849 	struct testcases *test = arg;
850 
851 	return set_lam(test->lam);
852 }
853 
/*
 * THREAD test: spawn a CLONE_VM thread that reports its LAM mode.
 *
 * later == 0: enable LAM first; the thread must observe it (expect match).
 * later != 0: the thread starts with LAM_NONE, then enabling LAM in the
 * parent must fail because the process is already multi-threaded
 * (negative case, paired with .expected = 1).
 *
 * @return 0: outcome matches; 1: mismatch or error
 */
static int handle_thread(struct testcases *test)
{
	char stack[STACK_SIZE];
	int ret, child_ret;
	int lam = 0;
	pid_t pid;

	/* Set LAM mode in parent process */
	if (!test->later) {
		lam = test->lam;
		if (set_lam(lam) != 0)
			return 1;
	}

	/* The stack grows down: pass the high end of the buffer to clone() */
	pid = clone(thread_fn_get_lam, stack + STACK_SIZE,
		    SIGCHLD | CLONE_FILES | CLONE_FS | CLONE_VM, NULL);
	if (pid < 0) {
		perror("Clone failed.");
		return 1;
	}

	waitpid(pid, &child_ret, 0);
	ret = WEXITSTATUS(child_ret);

	/* Thread must report the mode set above (LAM_NONE in the later case) */
	if (lam != ret)
		return 1;

	if (test->later) {
		/* Enabling LAM with a live thread should fail -> report 1 */
		if (set_lam(test->lam) != 0)
			return 1;
	}

	return 0;
}
888 
889 static int handle_thread_enable(struct testcases *test)
890 {
891 	char stack[STACK_SIZE];
892 	int ret, child_ret;
893 	int lam = test->lam;
894 	pid_t pid;
895 
896 	pid = clone(thread_fn_set_lam, stack + STACK_SIZE,
897 		    SIGCHLD | CLONE_FILES | CLONE_FS | CLONE_VM, test);
898 	if (pid < 0) {
899 		perror("Clone failed.");
900 		return 1;
901 	}
902 
903 	waitpid(pid, &child_ret, 0);
904 	ret = WEXITSTATUS(child_ret);
905 
906 	if (lam != ret)
907 		return 1;
908 
909 	return 0;
910 }
/*
 * Fork-and-run every testcase in test[0..count), report each result via
 * the kselftest TAP helpers and bump the global tests_cnt for the plan.
 */
static void run_test(struct testcases *test, int count)
{
	int i, ret = 0;

	for (i = 0; i < count; i++) {
		struct testcases *t = test + i;

		/* fork a process to run test case */
		tests_cnt++;
		ret = fork_test(t);

		/* return 3 is not support LA57, the case should be skipped */
		if (ret == 3) {
			ksft_test_result_skip("%s", t->msg);
			continue;
		}

		/* A non-zero result passes only when it matches t->expected */
		if (ret != 0)
			ret = (t->expected == ret);
		else
			ret = !(t->expected);

		ksft_test_result(ret, "%s", t->msg);
	}
}
936 
/* io_uring testcases, selected by FUNC_URING */
static struct testcases uring_cases[] = {
	{
		.later = 0,
		.lam = LAM_U57_BITS,
		.test_func = handle_uring,
		.msg = "URING: LAM_U57. Dereferencing pointer with metadata\n",
	},
	{
		.later = 1,
		.expected = 1,
		.lam = LAM_U57_BITS,
		.test_func = handle_uring,
		.msg = "URING:[Negative] Disable LAM. Dereferencing pointer with metadata.\n",
	},
};
952 
/* malloc-pointer testcases, selected by FUNC_MALLOC */
static struct testcases malloc_cases[] = {
	{
		.later = 0,
		.lam = LAM_U57_BITS,
		.test_func = handle_malloc,
		.msg = "MALLOC: LAM_U57. Dereferencing pointer with metadata\n",
	},
	{
		.later = 1,
		.expected = 2,	/* tagged dereference without LAM must SIGSEGV */
		.lam = LAM_U57_BITS,
		.test_func = handle_malloc,
		.msg = "MALLOC:[Negative] Disable LAM. Dereferencing pointer with metadata.\n",
	},
};
968 
/* max-tag-bits testcase, selected by FUNC_BITS */
static struct testcases bits_cases[] = {
	{
		.test_func = handle_max_bits,
		.msg = "BITS: Check default tag bits\n",
	},
};
975 
/* syscall and get_user() testcases, selected by FUNC_SYSCALL */
static struct testcases syscall_cases[] = {
	{
		.later = 0,
		.lam = LAM_U57_BITS,
		.test_func = handle_syscall,
		.msg = "SYSCALL: LAM_U57. syscall with metadata\n",
	},
	{
		.later = 1,
		.expected = 1,
		.lam = LAM_U57_BITS,
		.test_func = handle_syscall,
		.msg = "SYSCALL:[Negative] Disable LAM. Dereferencing pointer with metadata.\n",
	},
	/* For get_user_syscall, .later carries the GET_USER_* variant */
	{
		.later = GET_USER_USER,
		.lam = LAM_U57_BITS,
		.test_func = get_user_syscall,
		.msg = "GET_USER: get_user() and pass a properly tagged user pointer.\n",
	},
	{
		.later = GET_USER_KERNEL_TOP,
		.expected = 1,
		.lam = LAM_U57_BITS,
		.test_func = get_user_syscall,
		.msg = "GET_USER:[Negative] get_user() with a kernel pointer and the top bit cleared.\n",
	},
	{
		.later = GET_USER_KERNEL_BOT,
		.expected = 1,
		.lam = LAM_U57_BITS,
		.test_func = get_user_syscall,
		.msg = "GET_USER:[Negative] get_user() with a kernel pointer and the bottom sign-extension bit cleared.\n",
	},
	{
		.later = GET_USER_KERNEL,
		.expected = 1,
		.lam = LAM_U57_BITS,
		.test_func = get_user_syscall,
		.msg = "GET_USER:[Negative] get_user() and pass a kernel pointer.\n",
	},
};
1018 
/* mmap testcases, selected by FUNC_MMAP */
static struct testcases mmap_cases[] = {
	{
		.later = 1,
		.expected = 0,
		.lam = LAM_U57_BITS,
		.addr = HIGH_ADDR,
		.test_func = handle_mmap,
		.msg = "MMAP: First mmap high address, then set LAM_U57.\n",
	},
	{
		.later = 0,
		.expected = 0,
		.lam = LAM_U57_BITS,
		.addr = HIGH_ADDR,
		.test_func = handle_mmap,
		.msg = "MMAP: First LAM_U57, then High address.\n",
	},
	{
		.later = 0,
		.expected = 0,
		.lam = LAM_U57_BITS,
		.addr = LOW_ADDR,
		.test_func = handle_mmap,
		.msg = "MMAP: First LAM_U57, then Low address.\n",
	},
};
1045 
/* fork/thread/execve inheritance testcases, selected by FUNC_INHERITE */
static struct testcases inheritance_cases[] = {
	{
		.expected = 0,
		.lam = LAM_U57_BITS,
		.test_func = handle_inheritance,
		.msg = "FORK: LAM_U57, child process should get LAM mode same as parent\n",
	},
	{
		.expected = 0,
		.lam = LAM_U57_BITS,
		.test_func = handle_thread,
		.msg = "THREAD: LAM_U57, child thread should get LAM mode same as parent\n",
	},
	{
		.expected = 1,
		.lam = LAM_U57_BITS,
		.test_func = handle_thread_enable,
		.msg = "THREAD: [NEGATIVE] Enable LAM in child.\n",
	},
	{
		.expected = 1,
		.later = 1,
		.lam = LAM_U57_BITS,
		.test_func = handle_thread,
		.msg = "THREAD: [NEGATIVE] Enable LAM in parent after thread created.\n",
	},
	{
		.expected = 0,
		.lam = LAM_U57_BITS,
		.test_func = handle_execve,
		.msg = "EXECVE: LAM_U57, child process should get disabled LAM mode\n",
	},
};
1079 
1080 static void cmd_help(void)
1081 {
1082 	printf("usage: lam [-h] [-t test list]\n");
1083 	printf("\t-t test list: run tests specified in the test list, default:0x%x\n", TEST_MASK);
1084 	printf("\t\t0x1:malloc; 0x2:max_bits; 0x4:mmap; 0x8:syscall; 0x10:io_uring; 0x20:inherit;\n");
1085 	printf("\t-h: help\n");
1086 }
1087 
/* Check for file existence: 1 when stat() succeeds on @fileName, else 0 */
uint8_t file_Exists(const char *fileName)
{
	struct stat sb;

	return stat(fileName, &sb) == 0 ? 1 : 0;
}
1097 
/* Sysfs idxd files: shell commands run by Dsa_Init_Sysfs() to configure
 * a shared DSA work queue (wq0.1) - order matters, bind comes last.
 */
const char *dsa_configs[] = {
	"echo 1 > /sys/bus/dsa/devices/dsa0/wq0.1/group_id",
	"echo shared > /sys/bus/dsa/devices/dsa0/wq0.1/mode",
	"echo 10 > /sys/bus/dsa/devices/dsa0/wq0.1/priority",
	"echo 16 > /sys/bus/dsa/devices/dsa0/wq0.1/size",
	"echo 15 > /sys/bus/dsa/devices/dsa0/wq0.1/threshold",
	"echo user > /sys/bus/dsa/devices/dsa0/wq0.1/type",
	"echo MyApp1 > /sys/bus/dsa/devices/dsa0/wq0.1/name",
	"echo 1 > /sys/bus/dsa/devices/dsa0/engine0.1/group_id",
	"echo dsa0 > /sys/bus/dsa/drivers/idxd/bind",
	/* bind files and devices, generated a device file in /dev */
	"echo wq0.1 > /sys/bus/dsa/drivers/user/bind",
};

/* DSA device file */
const char *dsaDeviceFile = "/dev/dsa/wq0.1";
/* sysfs attribute reporting whether PASID (SVA) is enabled for dsa0 */
const char *dsaPasidEnable = "/sys/bus/dsa/devices/dsa0/pasid_enabled";
1117 
1118 /*
1119  * DSA depends on kernel cmdline "intel_iommu=on,sm_on"
1120  * return pasid_enabled (0: disable 1:enable)
1121  */
1122 int Check_DSA_Kernel_Setting(void)
1123 {
1124 	char command[256] = "";
1125 	char buf[256] = "";
1126 	char *ptr;
1127 	int rv = -1;
1128 
1129 	snprintf(command, sizeof(command) - 1, "cat %s", dsaPasidEnable);
1130 
1131 	FILE *cmd = popen(command, "r");
1132 
1133 	if (cmd) {
1134 		while (fgets(buf, sizeof(buf) - 1, cmd) != NULL);
1135 
1136 		pclose(cmd);
1137 		rv = strtol(buf, &ptr, 16);
1138 	}
1139 
1140 	return rv;
1141 }
1142 
/*
 * Config DSA's sysfs files as shared DSA's WQ.
 * Generated a device file /dev/dsa/wq0.1
 * Return:  0 OK; 1 Failed; 3 Skip(SVA disabled).
 */
int Dsa_Init_Sysfs(void)
{
	uint len = ARRAY_SIZE(dsa_configs);
	const char **p = dsa_configs;

	/* Already configured by a previous run */
	if (file_Exists(dsaDeviceFile) == 1)
		return 0;

	/* check the idxd driver */
	if (file_Exists(dsaPasidEnable) != 1) {
		printf("Please make sure idxd driver was loaded\n");
		return 3;
	}

	/* Check SVA feature */
	if (Check_DSA_Kernel_Setting() != 1) {
		printf("Please enable SVA.(Add intel_iommu=on,sm_on in kernel cmdline)\n");
		return 3;
	}

	/* Apply each sysfs echo command; any failure aborts the setup */
	for (int i = 0; i < len; i++) {
		if (system(p[i]))
			return 1;
	}

	/* After config, /dev/dsa/wq0.1 should be generated */
	return (file_Exists(dsaDeviceFile) != 1);
}
1177 
1178 /*
1179  * Open DSA device file, triger API: iommu_sva_alloc_pasid
1180  */
1181 void *allocate_dsa_pasid(void)
1182 {
1183 	int fd;
1184 	void *wq;
1185 
1186 	fd = open(dsaDeviceFile, O_RDWR);
1187 	if (fd < 0) {
1188 		perror("open");
1189 		return MAP_FAILED;
1190 	}
1191 
1192 	wq = mmap(NULL, 0x1000, PROT_WRITE,
1193 			   MAP_SHARED | MAP_POPULATE, fd, 0);
1194 	close(fd);
1195 	if (wq == MAP_FAILED)
1196 		perror("mmap");
1197 
1198 	return wq;
1199 }
1200 
1201 int set_force_svm(void)
1202 {
1203 	int ret = 0;
1204 
1205 	ret = syscall(SYS_arch_prctl, ARCH_FORCE_TAGGED_SVA);
1206 
1207 	return ret;
1208 }
1209 
/*
 * PASID test: execute the three operations encoded in test->cmd, one
 * nibble per iteration (low nibble first): LAM_CMD_BIT = enable LAM,
 * PAS_CMD_BIT = allocate a PASID via the DSA device, SVA_CMD_BIT =
 * force tagged SVA. Some orderings are expected to fail (see
 * pasid_cases). Returns 0 on full success, 1 on failure, or 3 (skip)
 * when DSA/SVA is unavailable.
 */
int handle_pasid(struct testcases *test)
{
	uint tmp = test->cmd;
	uint runed = 0x0;	/* bitmask of operations already executed */
	int ret = 0;
	void *wq = NULL;

	ret = Dsa_Init_Sysfs();
	if (ret != 0)
		return ret;

	/* Consume one command nibble per iteration; each op may run once */
	for (int i = 0; i < 3; i++) {
		int err = 0;

		if (tmp & 0x1) {
			/* run set lam mode*/
			if ((runed & 0x1) == 0)	{
				err = set_lam(LAM_U57_BITS);
				runed = runed | 0x1;
			} else
				err = 1;
		} else if (tmp & 0x4) {
			/* run force svm */
			if ((runed & 0x4) == 0)	{
				err = set_force_svm();
				runed = runed | 0x4;
			} else
				err = 1;
		} else if (tmp & 0x2) {
			/* run allocate pasid */
			if ((runed & 0x2) == 0) {
				runed = runed | 0x2;
				wq = allocate_dsa_pasid();
				if (wq == MAP_FAILED)
					err = 1;
			} else
				err = 1;
		}

		ret = ret + err;
		if (ret > 0)
			break;

		tmp = tmp >> 4;
	}

	if (wq != MAP_FAILED && wq != NULL)
		if (munmap(wq, 0x1000))
			printf("munmap failed %d\n", errno);

	/* All three operations must have executed for the sequence to pass */
	if (runed != 0x7)
		ret = 1;

	return (ret != 0);
}
1265 
/*
 * Pasid test depends on idxd and SVA, kernel should enable iommu and sm.
 * command line(intel_iommu=on,sm_on)
 * PAS_CMD() packs the execution order low nibble first.
 */
static struct testcases pasid_cases[] = {
	{
		.expected = 1,
		.cmd = PAS_CMD(LAM_CMD_BIT, PAS_CMD_BIT, SVA_CMD_BIT),
		.test_func = handle_pasid,
		.msg = "PASID: [Negative] Execute LAM, PASID, SVA in sequence\n",
	},
	{
		.expected = 0,
		.cmd = PAS_CMD(LAM_CMD_BIT, SVA_CMD_BIT, PAS_CMD_BIT),
		.test_func = handle_pasid,
		.msg = "PASID: Execute LAM, SVA, PASID in sequence\n",
	},
	{
		.expected = 1,
		.cmd = PAS_CMD(PAS_CMD_BIT, LAM_CMD_BIT, SVA_CMD_BIT),
		.test_func = handle_pasid,
		.msg = "PASID: [Negative] Execute PASID, LAM, SVA in sequence\n",
	},
	{
		.expected = 0,
		.cmd = PAS_CMD(PAS_CMD_BIT, SVA_CMD_BIT, LAM_CMD_BIT),
		.test_func = handle_pasid,
		.msg = "PASID: Execute PASID, SVA, LAM in sequence\n",
	},
	{
		.expected = 0,
		.cmd = PAS_CMD(SVA_CMD_BIT, LAM_CMD_BIT, PAS_CMD_BIT),
		.test_func = handle_pasid,
		.msg = "PASID: Execute SVA, LAM, PASID in sequence\n",
	},
	{
		.expected = 0,
		.cmd = PAS_CMD(SVA_CMD_BIT, PAS_CMD_BIT, LAM_CMD_BIT),
		.test_func = handle_pasid,
		.msg = "PASID: Execute SVA, PASID, LAM in sequence\n",
	},
};
1308 
int main(int argc, char **argv)
{
	int c = 0;
	unsigned int tests = TEST_MASK;	/* bitmask of FUNC_* groups to run */

	tests_cnt = 0;

	/* Skip the whole selftest when the CPU or kernel lacks LAM */
	if (!lam_is_available())
		return KSFT_SKIP;

	while ((c = getopt(argc, argv, "ht:")) != -1) {
		switch (c) {
		case 't':
			tests = strtoul(optarg, NULL, 16);
			if (tests && !(tests & TEST_MASK)) {
				ksft_print_msg("Invalid argument!\n");
				return -1;
			}
			break;
		case 'h':
			cmd_help();
			return 0;
		default:
			ksft_print_msg("Invalid argument\n");
			return -1;
		}
	}

	/*
	 * When tests is 0, it is not a real test case;
	 * the option used by test case(execve) to check the lam mode in
	 * process generated by execve, the process read back lam mode and
	 * check with lam mode in parent process.
	 */
	if (!tests)
		return (get_lam());

	/* Run test cases */
	if (tests & FUNC_MALLOC)
		run_test(malloc_cases, ARRAY_SIZE(malloc_cases));

	if (tests & FUNC_BITS)
		run_test(bits_cases, ARRAY_SIZE(bits_cases));

	if (tests & FUNC_MMAP)
		run_test(mmap_cases, ARRAY_SIZE(mmap_cases));

	if (tests & FUNC_SYSCALL)
		run_test(syscall_cases, ARRAY_SIZE(syscall_cases));

	if (tests & FUNC_URING)
		run_test(uring_cases, ARRAY_SIZE(uring_cases));

	if (tests & FUNC_INHERITE)
		run_test(inheritance_cases, ARRAY_SIZE(inheritance_cases));

	if (tests & FUNC_PASID)
		run_test(pasid_cases, ARRAY_SIZE(pasid_cases));

	/*
	 * The plan is only known after the runs: tests_cnt is accumulated
	 * dynamically inside run_test().
	 */
	ksft_set_plan(tests_cnt);

	ksft_exit_pass();
}
1372