xref: /linux/tools/testing/selftests/x86/lam.c (revision c532de5a67a70f8533d495f8f2aaa9a0491c3ad0)
1 // SPDX-License-Identifier: GPL-2.0
#define _GNU_SOURCE
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <limits.h>
#include <sched.h>
#include <setjmp.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <sys/uio.h>
#include <sys/utsname.h>
#include <sys/wait.h>

#include <linux/io_uring.h>
#include "../kselftest.h"
21 
#ifndef __x86_64__
# error This test is 64-bit only
#endif

/*
 * LAM modes (values mirror the kernel's definitions). In LAM_U57 mode the
 * hardware ignores pointer bits 62:57, which LAM_U57_MASK covers.
 */
#define LAM_NONE		0
#define LAM_U57_BITS		6
#define LAM_U57_MASK		(0x3fULL << 57)

/* arch_prctl() option codes for querying and controlling LAM */
#define ARCH_GET_UNTAG_MASK	0x4001
#define ARCH_ENABLE_TAGGED_ADDR	0x4002
#define ARCH_GET_MAX_TAG_BITS	0x4003
#define ARCH_FORCE_TAGGED_SVA	0x4004

/* Test-group bits accepted by "-t"; TEST_MASK selects all of them */
#define FUNC_MALLOC		0x1
#define FUNC_BITS		0x2
#define FUNC_MMAP		0x4
#define FUNC_SYSCALL		0x8
#define FUNC_URING		0x10
#define FUNC_INHERITE		0x20
#define FUNC_PASID		0x40

#define TEST_MASK		0x7f

/* Fixed mmap() targets: 1 GiB, and an address above the 47-bit range */
#define LOW_ADDR		(0x1UL << 30)
#define HIGH_ADDR		(0x3UL << 48)

#define MALLOC_LEN		32

#define PAGE_SIZE		(4 << 10)

#define STACK_SIZE		65536

/* Compiler-only barrier: stops the compiler reordering memory accesses */
#define barrier() ({						\
		   __asm__ __volatile__("" : : : "memory");	\
})

/* io_uring test geometry: ring entries and read-buffer chunk size */
#define URING_QUEUE_SZ		1
#define URING_BLOCK_SZ		2048

/* PASID test commands; PAS_CMD() packs three of them, first in the low nibble */
#define LAM_CMD_BIT		0x1
#define PAS_CMD_BIT		0x2
#define SVA_CMD_BIT		0x4

#define PAS_CMD(cmd1, cmd2, cmd3) (((cmd3) << 8) | ((cmd2) << 4) | ((cmd1) << 0))
70 
/* One test case; fork_test() runs each case in its own child process. */
struct testcases {
	/* Ordering knob interpreted by the handler; 0 usually means "enable LAM first" */
	unsigned int later;
	/* Expected handler result: 0 pass, 1 other errors, 2 SIGSEGV */
	int expected;
	/* LAM mode the handler works with */
	unsigned long lam;
	/* Fixed mapping address (mmap cases) */
	uint64_t addr;
	/* PAS_CMD()-packed command sequence (PASID cases) */
	uint64_t cmd;
	/* Handler implementing the case; its return value is the result */
	int (*test_func)(struct testcases *test);
	/* TAP description */
	const char *msg;
};
80 
/*
 * Read request carried through io_uring: the source file, its size, and a
 * trailing array of iovecs (one per URING_BLOCK_SZ chunk). A pointer to it
 * travels in the SQE user_data and is read back from the CQE.
 */
struct file_io {
	int file_fd;		/* file being read */
	off_t file_sz;		/* size of that file in bytes */
	struct iovec iovecs[];	/* flexible array, one entry per chunk */
};
87 
/*
 * Userspace view of one io_uring ring (SQ or CQ). mmap_io_uring() points
 * these fields into the ring mappings shared with the kernel.
 */
struct io_uring_queue {
	unsigned int *head;
	unsigned int *tail;
	unsigned int *ring_mask;
	unsigned int *ring_entries;
	unsigned int *flags;
	unsigned int *array;		/* SQ only: slot -> SQE index */
	union {
		struct io_uring_cqe *cqes;	/* CQ: completion entries */
		struct io_uring_sqe *sqes;	/* SQ: separately mapped SQE array */
	} queue;
	size_t ring_sz;			/* bytes in the ring mapping */
};

/* One io_uring instance: its fd plus submission and completion rings */
struct io_ring {
	int ring_fd;
	struct io_uring_queue sq_ring;
	struct io_uring_queue cq_ring;
};
107 
/* Number of cases run so far; becomes the TAP plan once all groups ran */
int tests_cnt;
/* Recovery point for segv_handler(), armed by the handlers via sigsetjmp() */
jmp_buf segv_env;
110 
111 static void segv_handler(int sig)
112 {
113 	ksft_print_msg("Get segmentation fault(%d).", sig);
114 
115 	siglongjmp(segv_env, 1);
116 }
117 
/* Check LAM in CPUID.(EAX=07H, ECX=01H):EAX.[bit 26]; non-zero if present */
static inline int cpu_has_lam(void)
{
	unsigned int eax, ebx, ecx, edx;

	__cpuid_count(0x7, 1, eax, ebx, ecx, edx);

	return eax & (1 << 26);
}
126 
/* Check 5-level paging in CPUID.(EAX=07H, ECX=00H):ECX.[bit 16]; non-zero if present */
static inline int cpu_has_la57(void)
{
	unsigned int eax, ebx, ecx, edx;

	__cpuid_count(0x7, 0, eax, ebx, ecx, edx);

	return ecx & (1 << 16);
}
136 
137 /*
138  * Set tagged address and read back untag mask.
139  * check if the untagged mask is expected.
140  *
141  * @return:
142  * 0: Set LAM mode successfully
143  * others: failed to set LAM
144  */
145 static int set_lam(unsigned long lam)
146 {
147 	int ret = 0;
148 	uint64_t ptr = 0;
149 
150 	if (lam != LAM_U57_BITS && lam != LAM_NONE)
151 		return -1;
152 
153 	/* Skip check return */
154 	syscall(SYS_arch_prctl, ARCH_ENABLE_TAGGED_ADDR, lam);
155 
156 	/* Get untagged mask */
157 	syscall(SYS_arch_prctl, ARCH_GET_UNTAG_MASK, &ptr);
158 
159 	/* Check mask returned is expected */
160 	if (lam == LAM_U57_BITS)
161 		ret = (ptr != ~(LAM_U57_MASK));
162 	else if (lam == LAM_NONE)
163 		ret = (ptr != -1ULL);
164 
165 	return ret;
166 }
167 
168 static unsigned long get_default_tag_bits(void)
169 {
170 	pid_t pid;
171 	int lam = LAM_NONE;
172 	int ret = 0;
173 
174 	pid = fork();
175 	if (pid < 0) {
176 		perror("Fork failed.");
177 	} else if (pid == 0) {
178 		/* Set LAM mode in child process */
179 		if (set_lam(LAM_U57_BITS) == 0)
180 			lam = LAM_U57_BITS;
181 		else
182 			lam = LAM_NONE;
183 		exit(lam);
184 	} else {
185 		wait(&ret);
186 		lam = WEXITSTATUS(ret);
187 	}
188 
189 	return lam;
190 }
191 
192 /*
193  * Set tagged address and read back untag mask.
194  * check if the untag mask is expected.
195  */
196 static int get_lam(void)
197 {
198 	uint64_t ptr = 0;
199 	int ret = -1;
200 	/* Get untagged mask */
201 	if (syscall(SYS_arch_prctl, ARCH_GET_UNTAG_MASK, &ptr) == -1)
202 		return -1;
203 
204 	/* Check mask returned is expected */
205 	if (ptr == ~(LAM_U57_MASK))
206 		ret = LAM_U57_BITS;
207 	else if (ptr == -1ULL)
208 		ret = LAM_NONE;
209 
210 
211 	return ret;
212 }
213 
214 /* According to LAM mode, set metadata in high bits */
215 static uint64_t set_metadata(uint64_t src, unsigned long lam)
216 {
217 	uint64_t metadata;
218 
219 	srand(time(NULL));
220 
221 	switch (lam) {
222 	case LAM_U57_BITS: /* Set metadata in bits 62:57 */
223 		/* Get a random non-zero value as metadata */
224 		metadata = (rand() % ((1UL << LAM_U57_BITS) - 1) + 1) << 57;
225 		metadata |= (src & ~(LAM_U57_MASK));
226 		break;
227 	default:
228 		metadata = src;
229 		break;
230 	}
231 
232 	return metadata;
233 }
234 
/*
 * Write through @src, then through a tagged alias of it, and compare.
 *
 * @return:
 * 0: both views agree, or no metadata was applied for @lam
 * 1: the contents differ
 */
static int handle_lam_test(void *src, unsigned int lam)
{
	char *plain = src;
	char *tagged;

	strcpy(plain, "USER POINTER");

	tagged = (char *)set_metadata((uint64_t)src, lam);
	if (tagged == plain)
		return 0;

	/* Writing through the tagged alias; expected to fault without LAM */
	strcpy(tagged, "METADATA POINTER");

	return strcmp(plain, tagged) != 0;
}
258 
259 
260 int handle_max_bits(struct testcases *test)
261 {
262 	unsigned long exp_bits = get_default_tag_bits();
263 	unsigned long bits = 0;
264 
265 	if (exp_bits != LAM_NONE)
266 		exp_bits = LAM_U57_BITS;
267 
268 	/* Get LAM max tag bits */
269 	if (syscall(SYS_arch_prctl, ARCH_GET_MAX_TAG_BITS, &bits) == -1)
270 		return 1;
271 
272 	return (exp_bits != bits);
273 }
274 
275 /*
276  * Test lam feature through dereference pointer get from malloc.
277  * @return 0: Pass test. 1: Get failure during test 2: Get SIGSEGV
278  */
279 static int handle_malloc(struct testcases *test)
280 {
281 	char *ptr = NULL;
282 	int ret = 0;
283 
284 	if (test->later == 0 && test->lam != 0)
285 		if (set_lam(test->lam) == -1)
286 			return 1;
287 
288 	ptr = (char *)malloc(MALLOC_LEN);
289 	if (ptr == NULL) {
290 		perror("malloc() failure\n");
291 		return 1;
292 	}
293 
294 	/* Set signal handler */
295 	if (sigsetjmp(segv_env, 1) == 0) {
296 		signal(SIGSEGV, segv_handler);
297 		ret = handle_lam_test(ptr, test->lam);
298 	} else {
299 		ret = 2;
300 	}
301 
302 	if (test->later != 0 && test->lam != 0)
303 		if (set_lam(test->lam) == -1 && ret == 0)
304 			ret = 1;
305 
306 	free(ptr);
307 
308 	return ret;
309 }
310 
311 static int handle_mmap(struct testcases *test)
312 {
313 	void *ptr;
314 	unsigned int flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED;
315 	int ret = 0;
316 
317 	if (test->later == 0 && test->lam != 0)
318 		if (set_lam(test->lam) != 0)
319 			return 1;
320 
321 	ptr = mmap((void *)test->addr, PAGE_SIZE, PROT_READ | PROT_WRITE,
322 		   flags, -1, 0);
323 	if (ptr == MAP_FAILED) {
324 		if (test->addr == HIGH_ADDR)
325 			if (!cpu_has_la57())
326 				return 3; /* unsupport LA57 */
327 		return 1;
328 	}
329 
330 	if (test->later != 0 && test->lam != 0)
331 		if (set_lam(test->lam) != 0)
332 			ret = 1;
333 
334 	if (ret == 0) {
335 		if (sigsetjmp(segv_env, 1) == 0) {
336 			signal(SIGSEGV, segv_handler);
337 			ret = handle_lam_test(ptr, test->lam);
338 		} else {
339 			ret = 2;
340 		}
341 	}
342 
343 	munmap(ptr, PAGE_SIZE);
344 	return ret;
345 }
346 
347 static int handle_syscall(struct testcases *test)
348 {
349 	struct utsname unme, *pu;
350 	int ret = 0;
351 
352 	if (test->later == 0 && test->lam != 0)
353 		if (set_lam(test->lam) != 0)
354 			return 1;
355 
356 	if (sigsetjmp(segv_env, 1) == 0) {
357 		signal(SIGSEGV, segv_handler);
358 		pu = (struct utsname *)set_metadata((uint64_t)&unme, test->lam);
359 		ret = uname(pu);
360 		if (ret < 0)
361 			ret = 1;
362 	} else {
363 		ret = 2;
364 	}
365 
366 	if (test->later != 0 && test->lam != 0)
367 		if (set_lam(test->lam) != -1 && ret == 0)
368 			ret = 1;
369 
370 	return ret;
371 }
372 
/* io_uring_setup(2) via syscall(); returns the ring fd, or -1 with errno set. */
int sys_uring_setup(unsigned int entries, struct io_uring_params *p)
{
	long fd = syscall(__NR_io_uring_setup, entries, p);

	return (int)fd;
}
377 
/*
 * io_uring_enter(2) via syscall(): submit @to entries and, with
 * IORING_ENTER_GETEVENTS in @flags, wait for @min completions.
 * No signal mask is passed. Returns the syscall result (-1 on error).
 */
int sys_uring_enter(int fd, unsigned int to, unsigned int min, unsigned int flags)
{
	long rc = syscall(__NR_io_uring_enter, fd, to, min, flags, NULL, 0);

	return (int)rc;
}
382 
383 /* Init submission queue and completion queue */
384 int mmap_io_uring(struct io_uring_params p, struct io_ring *s)
385 {
386 	struct io_uring_queue *sring = &s->sq_ring;
387 	struct io_uring_queue *cring = &s->cq_ring;
388 
389 	sring->ring_sz = p.sq_off.array + p.sq_entries * sizeof(unsigned int);
390 	cring->ring_sz = p.cq_off.cqes + p.cq_entries * sizeof(struct io_uring_cqe);
391 
392 	if (p.features & IORING_FEAT_SINGLE_MMAP) {
393 		if (cring->ring_sz > sring->ring_sz)
394 			sring->ring_sz = cring->ring_sz;
395 
396 		cring->ring_sz = sring->ring_sz;
397 	}
398 
399 	void *sq_ptr = mmap(0, sring->ring_sz, PROT_READ | PROT_WRITE,
400 			    MAP_SHARED | MAP_POPULATE, s->ring_fd,
401 			    IORING_OFF_SQ_RING);
402 
403 	if (sq_ptr == MAP_FAILED) {
404 		perror("sub-queue!");
405 		return 1;
406 	}
407 
408 	void *cq_ptr = sq_ptr;
409 
410 	if (!(p.features & IORING_FEAT_SINGLE_MMAP)) {
411 		cq_ptr = mmap(0, cring->ring_sz, PROT_READ | PROT_WRITE,
412 			      MAP_SHARED | MAP_POPULATE, s->ring_fd,
413 			      IORING_OFF_CQ_RING);
414 		if (cq_ptr == MAP_FAILED) {
415 			perror("cpl-queue!");
416 			munmap(sq_ptr, sring->ring_sz);
417 			return 1;
418 		}
419 	}
420 
421 	sring->head = sq_ptr + p.sq_off.head;
422 	sring->tail = sq_ptr + p.sq_off.tail;
423 	sring->ring_mask = sq_ptr + p.sq_off.ring_mask;
424 	sring->ring_entries = sq_ptr + p.sq_off.ring_entries;
425 	sring->flags = sq_ptr + p.sq_off.flags;
426 	sring->array = sq_ptr + p.sq_off.array;
427 
428 	/* Map a queue as mem map */
429 	s->sq_ring.queue.sqes = mmap(0, p.sq_entries * sizeof(struct io_uring_sqe),
430 				     PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE,
431 				     s->ring_fd, IORING_OFF_SQES);
432 	if (s->sq_ring.queue.sqes == MAP_FAILED) {
433 		munmap(sq_ptr, sring->ring_sz);
434 		if (sq_ptr != cq_ptr) {
435 			ksft_print_msg("failed to mmap uring queue!");
436 			munmap(cq_ptr, cring->ring_sz);
437 			return 1;
438 		}
439 	}
440 
441 	cring->head = cq_ptr + p.cq_off.head;
442 	cring->tail = cq_ptr + p.cq_off.tail;
443 	cring->ring_mask = cq_ptr + p.cq_off.ring_mask;
444 	cring->ring_entries = cq_ptr + p.cq_off.ring_entries;
445 	cring->queue.cqes = cq_ptr + p.cq_off.cqes;
446 
447 	return 0;
448 }
449 
450 /* Init io_uring queues */
451 int setup_io_uring(struct io_ring *s)
452 {
453 	struct io_uring_params para;
454 
455 	memset(&para, 0, sizeof(para));
456 	s->ring_fd = sys_uring_setup(URING_QUEUE_SZ, &para);
457 	if (s->ring_fd < 0)
458 		return 1;
459 
460 	return mmap_io_uring(para, s);
461 }
462 
463 /*
464  * Get data from completion queue. the data buffer saved the file data
465  * return 0: success; others: error;
466  */
467 int handle_uring_cq(struct io_ring *s)
468 {
469 	struct file_io *fi = NULL;
470 	struct io_uring_queue *cring = &s->cq_ring;
471 	struct io_uring_cqe *cqe;
472 	unsigned int head;
473 	off_t len = 0;
474 
475 	head = *cring->head;
476 
477 	do {
478 		barrier();
479 		if (head == *cring->tail)
480 			break;
481 		/* Get the entry */
482 		cqe = &cring->queue.cqes[head & *s->cq_ring.ring_mask];
483 		fi = (struct file_io *)cqe->user_data;
484 		if (cqe->res < 0)
485 			break;
486 
487 		int blocks = (int)(fi->file_sz + URING_BLOCK_SZ - 1) / URING_BLOCK_SZ;
488 
489 		for (int i = 0; i < blocks; i++)
490 			len += fi->iovecs[i].iov_len;
491 
492 		head++;
493 	} while (1);
494 
495 	*cring->head = head;
496 	barrier();
497 
498 	return (len != fi->file_sz);
499 }
500 
501 /*
502  * Submit squeue. specify via IORING_OP_READV.
503  * the buffer need to be set metadata according to LAM mode
504  */
505 int handle_uring_sq(struct io_ring *ring, struct file_io *fi, unsigned long lam)
506 {
507 	int file_fd = fi->file_fd;
508 	struct io_uring_queue *sring = &ring->sq_ring;
509 	unsigned int index = 0, cur_block = 0, tail = 0, next_tail = 0;
510 	struct io_uring_sqe *sqe;
511 
512 	off_t remain = fi->file_sz;
513 	int blocks = (int)(remain + URING_BLOCK_SZ - 1) / URING_BLOCK_SZ;
514 
515 	while (remain) {
516 		off_t bytes = remain;
517 		void *buf;
518 
519 		if (bytes > URING_BLOCK_SZ)
520 			bytes = URING_BLOCK_SZ;
521 
522 		fi->iovecs[cur_block].iov_len = bytes;
523 
524 		if (posix_memalign(&buf, URING_BLOCK_SZ, URING_BLOCK_SZ))
525 			return 1;
526 
527 		fi->iovecs[cur_block].iov_base = (void *)set_metadata((uint64_t)buf, lam);
528 		remain -= bytes;
529 		cur_block++;
530 	}
531 
532 	next_tail = *sring->tail;
533 	tail = next_tail;
534 	next_tail++;
535 
536 	barrier();
537 
538 	index = tail & *ring->sq_ring.ring_mask;
539 
540 	sqe = &ring->sq_ring.queue.sqes[index];
541 	sqe->fd = file_fd;
542 	sqe->flags = 0;
543 	sqe->opcode = IORING_OP_READV;
544 	sqe->addr = (unsigned long)fi->iovecs;
545 	sqe->len = blocks;
546 	sqe->off = 0;
547 	sqe->user_data = (uint64_t)fi;
548 
549 	sring->array[index] = index;
550 	tail = next_tail;
551 
552 	if (*sring->tail != tail) {
553 		*sring->tail = tail;
554 		barrier();
555 	}
556 
557 	if (sys_uring_enter(ring->ring_fd, 1, 1, IORING_ENTER_GETEVENTS) < 0)
558 		return 1;
559 
560 	return 0;
561 }
562 
563 /*
564  * Test LAM in async I/O and io_uring, read current binery through io_uring
565  * Set metadata in pointers to iovecs buffer.
566  */
567 int do_uring(unsigned long lam)
568 {
569 	struct io_ring *ring;
570 	struct file_io *fi;
571 	struct stat st;
572 	int ret = 1;
573 	char path[PATH_MAX] = {0};
574 
575 	/* get current process path */
576 	if (readlink("/proc/self/exe", path, PATH_MAX - 1) <= 0)
577 		return 1;
578 
579 	int file_fd = open(path, O_RDONLY);
580 
581 	if (file_fd < 0)
582 		return 1;
583 
584 	if (fstat(file_fd, &st) < 0)
585 		return 1;
586 
587 	off_t file_sz = st.st_size;
588 
589 	int blocks = (int)(file_sz + URING_BLOCK_SZ - 1) / URING_BLOCK_SZ;
590 
591 	fi = malloc(sizeof(*fi) + sizeof(struct iovec) * blocks);
592 	if (!fi)
593 		return 1;
594 
595 	fi->file_sz = file_sz;
596 	fi->file_fd = file_fd;
597 
598 	ring = malloc(sizeof(*ring));
599 	if (!ring)
600 		return 1;
601 
602 	memset(ring, 0, sizeof(struct io_ring));
603 
604 	if (setup_io_uring(ring))
605 		goto out;
606 
607 	if (handle_uring_sq(ring, fi, lam))
608 		goto out;
609 
610 	ret = handle_uring_cq(ring);
611 
612 out:
613 	free(ring);
614 
615 	for (int i = 0; i < blocks; i++) {
616 		if (fi->iovecs[i].iov_base) {
617 			uint64_t addr = ((uint64_t)fi->iovecs[i].iov_base);
618 
619 			switch (lam) {
620 			case LAM_U57_BITS: /* Clear bits 62:57 */
621 				addr = (addr & ~(LAM_U57_MASK));
622 				break;
623 			}
624 			free((void *)addr);
625 			fi->iovecs[i].iov_base = NULL;
626 		}
627 	}
628 
629 	free(fi);
630 
631 	return ret;
632 }
633 
634 int handle_uring(struct testcases *test)
635 {
636 	int ret = 0;
637 
638 	if (test->later == 0 && test->lam != 0)
639 		if (set_lam(test->lam) != 0)
640 			return 1;
641 
642 	if (sigsetjmp(segv_env, 1) == 0) {
643 		signal(SIGSEGV, segv_handler);
644 		ret = do_uring(test->lam);
645 	} else {
646 		ret = 2;
647 	}
648 
649 	return ret;
650 }
651 
652 static int fork_test(struct testcases *test)
653 {
654 	int ret, child_ret;
655 	pid_t pid;
656 
657 	pid = fork();
658 	if (pid < 0) {
659 		perror("Fork failed.");
660 		ret = 1;
661 	} else if (pid == 0) {
662 		ret = test->test_func(test);
663 		exit(ret);
664 	} else {
665 		wait(&child_ret);
666 		ret = WEXITSTATUS(child_ret);
667 	}
668 
669 	return ret;
670 }
671 
672 static int handle_execve(struct testcases *test)
673 {
674 	int ret, child_ret;
675 	int lam = test->lam;
676 	pid_t pid;
677 
678 	pid = fork();
679 	if (pid < 0) {
680 		perror("Fork failed.");
681 		ret = 1;
682 	} else if (pid == 0) {
683 		char path[PATH_MAX] = {0};
684 
685 		/* Set LAM mode in parent process */
686 		if (set_lam(lam) != 0)
687 			return 1;
688 
689 		/* Get current binary's path and the binary was run by execve */
690 		if (readlink("/proc/self/exe", path, PATH_MAX - 1) <= 0)
691 			exit(-1);
692 
693 		/* run binary to get LAM mode and return to parent process */
694 		if (execlp(path, path, "-t 0x0", NULL) < 0) {
695 			perror("error on exec");
696 			exit(-1);
697 		}
698 	} else {
699 		wait(&child_ret);
700 		ret = WEXITSTATUS(child_ret);
701 		if (ret != LAM_NONE)
702 			return 1;
703 	}
704 
705 	return 0;
706 }
707 
708 static int handle_inheritance(struct testcases *test)
709 {
710 	int ret, child_ret;
711 	int lam = test->lam;
712 	pid_t pid;
713 
714 	/* Set LAM mode in parent process */
715 	if (set_lam(lam) != 0)
716 		return 1;
717 
718 	pid = fork();
719 	if (pid < 0) {
720 		perror("Fork failed.");
721 		return 1;
722 	} else if (pid == 0) {
723 		/* Set LAM mode in parent process */
724 		int child_lam = get_lam();
725 
726 		exit(child_lam);
727 	} else {
728 		wait(&child_ret);
729 		ret = WEXITSTATUS(child_ret);
730 
731 		if (lam != ret)
732 			return 1;
733 	}
734 
735 	return 0;
736 }
737 
/* clone() entry point: exit with the LAM mode this thread observes. */
static int thread_fn_get_lam(void *arg)
{
	(void)arg;	/* unused */

	return get_lam();
}
742 
743 static int thread_fn_set_lam(void *arg)
744 {
745 	struct testcases *test = arg;
746 
747 	return set_lam(test->lam);
748 }
749 
750 static int handle_thread(struct testcases *test)
751 {
752 	char stack[STACK_SIZE];
753 	int ret, child_ret;
754 	int lam = 0;
755 	pid_t pid;
756 
757 	/* Set LAM mode in parent process */
758 	if (!test->later) {
759 		lam = test->lam;
760 		if (set_lam(lam) != 0)
761 			return 1;
762 	}
763 
764 	pid = clone(thread_fn_get_lam, stack + STACK_SIZE,
765 		    SIGCHLD | CLONE_FILES | CLONE_FS | CLONE_VM, NULL);
766 	if (pid < 0) {
767 		perror("Clone failed.");
768 		return 1;
769 	}
770 
771 	waitpid(pid, &child_ret, 0);
772 	ret = WEXITSTATUS(child_ret);
773 
774 	if (lam != ret)
775 		return 1;
776 
777 	if (test->later) {
778 		if (set_lam(test->lam) != 0)
779 			return 1;
780 	}
781 
782 	return 0;
783 }
784 
785 static int handle_thread_enable(struct testcases *test)
786 {
787 	char stack[STACK_SIZE];
788 	int ret, child_ret;
789 	int lam = test->lam;
790 	pid_t pid;
791 
792 	pid = clone(thread_fn_set_lam, stack + STACK_SIZE,
793 		    SIGCHLD | CLONE_FILES | CLONE_FS | CLONE_VM, test);
794 	if (pid < 0) {
795 		perror("Clone failed.");
796 		return 1;
797 	}
798 
799 	waitpid(pid, &child_ret, 0);
800 	ret = WEXITSTATUS(child_ret);
801 
802 	if (lam != ret)
803 		return 1;
804 
805 	return 0;
806 }
807 static void run_test(struct testcases *test, int count)
808 {
809 	int i, ret = 0;
810 
811 	for (i = 0; i < count; i++) {
812 		struct testcases *t = test + i;
813 
814 		/* fork a process to run test case */
815 		tests_cnt++;
816 		ret = fork_test(t);
817 
818 		/* return 3 is not support LA57, the case should be skipped */
819 		if (ret == 3) {
820 			ksft_test_result_skip("%s", t->msg);
821 			continue;
822 		}
823 
824 		if (ret != 0)
825 			ret = (t->expected == ret);
826 		else
827 			ret = !(t->expected);
828 
829 		ksft_test_result(ret, "%s", t->msg);
830 	}
831 }
832 
833 static struct testcases uring_cases[] = {
834 	{
835 		.later = 0,
836 		.lam = LAM_U57_BITS,
837 		.test_func = handle_uring,
838 		.msg = "URING: LAM_U57. Dereferencing pointer with metadata\n",
839 	},
840 	{
841 		.later = 1,
842 		.expected = 1,
843 		.lam = LAM_U57_BITS,
844 		.test_func = handle_uring,
845 		.msg = "URING:[Negative] Disable LAM. Dereferencing pointer with metadata.\n",
846 	},
847 };
848 
849 static struct testcases malloc_cases[] = {
850 	{
851 		.later = 0,
852 		.lam = LAM_U57_BITS,
853 		.test_func = handle_malloc,
854 		.msg = "MALLOC: LAM_U57. Dereferencing pointer with metadata\n",
855 	},
856 	{
857 		.later = 1,
858 		.expected = 2,
859 		.lam = LAM_U57_BITS,
860 		.test_func = handle_malloc,
861 		.msg = "MALLOC:[Negative] Disable LAM. Dereferencing pointer with metadata.\n",
862 	},
863 };
864 
865 static struct testcases bits_cases[] = {
866 	{
867 		.test_func = handle_max_bits,
868 		.msg = "BITS: Check default tag bits\n",
869 	},
870 };
871 
872 static struct testcases syscall_cases[] = {
873 	{
874 		.later = 0,
875 		.lam = LAM_U57_BITS,
876 		.test_func = handle_syscall,
877 		.msg = "SYSCALL: LAM_U57. syscall with metadata\n",
878 	},
879 	{
880 		.later = 1,
881 		.expected = 1,
882 		.lam = LAM_U57_BITS,
883 		.test_func = handle_syscall,
884 		.msg = "SYSCALL:[Negative] Disable LAM. Dereferencing pointer with metadata.\n",
885 	},
886 };
887 
888 static struct testcases mmap_cases[] = {
889 	{
890 		.later = 1,
891 		.expected = 0,
892 		.lam = LAM_U57_BITS,
893 		.addr = HIGH_ADDR,
894 		.test_func = handle_mmap,
895 		.msg = "MMAP: First mmap high address, then set LAM_U57.\n",
896 	},
897 	{
898 		.later = 0,
899 		.expected = 0,
900 		.lam = LAM_U57_BITS,
901 		.addr = HIGH_ADDR,
902 		.test_func = handle_mmap,
903 		.msg = "MMAP: First LAM_U57, then High address.\n",
904 	},
905 	{
906 		.later = 0,
907 		.expected = 0,
908 		.lam = LAM_U57_BITS,
909 		.addr = LOW_ADDR,
910 		.test_func = handle_mmap,
911 		.msg = "MMAP: First LAM_U57, then Low address.\n",
912 	},
913 };
914 
915 static struct testcases inheritance_cases[] = {
916 	{
917 		.expected = 0,
918 		.lam = LAM_U57_BITS,
919 		.test_func = handle_inheritance,
920 		.msg = "FORK: LAM_U57, child process should get LAM mode same as parent\n",
921 	},
922 	{
923 		.expected = 0,
924 		.lam = LAM_U57_BITS,
925 		.test_func = handle_thread,
926 		.msg = "THREAD: LAM_U57, child thread should get LAM mode same as parent\n",
927 	},
928 	{
929 		.expected = 1,
930 		.lam = LAM_U57_BITS,
931 		.test_func = handle_thread_enable,
932 		.msg = "THREAD: [NEGATIVE] Enable LAM in child.\n",
933 	},
934 	{
935 		.expected = 1,
936 		.later = 1,
937 		.lam = LAM_U57_BITS,
938 		.test_func = handle_thread,
939 		.msg = "THREAD: [NEGATIVE] Enable LAM in parent after thread created.\n",
940 	},
941 	{
942 		.expected = 0,
943 		.lam = LAM_U57_BITS,
944 		.test_func = handle_execve,
945 		.msg = "EXECVE: LAM_U57, child process should get disabled LAM mode\n",
946 	},
947 };
948 
949 static void cmd_help(void)
950 {
951 	printf("usage: lam [-h] [-t test list]\n");
952 	printf("\t-t test list: run tests specified in the test list, default:0x%x\n", TEST_MASK);
953 	printf("\t\t0x1:malloc; 0x2:max_bits; 0x4:mmap; 0x8:syscall; 0x10:io_uring; 0x20:inherit;\n");
954 	printf("\t-h: help\n");
955 }
956 
/* Return 1 if stat() succeeds on @fileName, 0 otherwise. */
uint8_t file_Exists(const char *fileName)
{
	struct stat st;

	return stat(fileName, &st) == 0 ? 1 : 0;
}
966 
/*
 * Shell commands, run in order, that configure dsa0 work queue wq0.1 as a
 * shared user queue and bind it (creating its /dev/dsa node).
 */
const char *dsa_configs[] = {
	/* work-queue attributes */
	"echo 1 > /sys/bus/dsa/devices/dsa0/wq0.1/group_id",
	"echo shared > /sys/bus/dsa/devices/dsa0/wq0.1/mode",
	"echo 10 > /sys/bus/dsa/devices/dsa0/wq0.1/priority",
	"echo 16 > /sys/bus/dsa/devices/dsa0/wq0.1/size",
	"echo 15 > /sys/bus/dsa/devices/dsa0/wq0.1/threshold",
	"echo user > /sys/bus/dsa/devices/dsa0/wq0.1/type",
	"echo MyApp1 > /sys/bus/dsa/devices/dsa0/wq0.1/name",
	/* engine in the same group */
	"echo 1 > /sys/bus/dsa/devices/dsa0/engine0.1/group_id",
	/* bind the device, then the queue */
	"echo dsa0 > /sys/bus/dsa/drivers/idxd/bind",
	"echo wq0.1 > /sys/bus/dsa/drivers/user/bind",
};
981 
/* Device node expected once wq0.1 is configured and bound */
const char *dsaDeviceFile = "/dev/dsa/wq0.1";
/* sysfs attribute read to check whether PASID is enabled on dsa0 */
const char *dsaPasidEnable = "/sys/bus/dsa/devices/dsa0/pasid_enabled";
986 
987 /*
988  * DSA depends on kernel cmdline "intel_iommu=on,sm_on"
989  * return pasid_enabled (0: disable 1:enable)
990  */
991 int Check_DSA_Kernel_Setting(void)
992 {
993 	char command[256] = "";
994 	char buf[256] = "";
995 	char *ptr;
996 	int rv = -1;
997 
998 	snprintf(command, sizeof(command) - 1, "cat %s", dsaPasidEnable);
999 
1000 	FILE *cmd = popen(command, "r");
1001 
1002 	if (cmd) {
1003 		while (fgets(buf, sizeof(buf) - 1, cmd) != NULL);
1004 
1005 		pclose(cmd);
1006 		rv = strtol(buf, &ptr, 16);
1007 	}
1008 
1009 	return rv;
1010 }
1011 
1012 /*
1013  * Config DSA's sysfs files as shared DSA's WQ.
1014  * Generated a device file /dev/dsa/wq0.1
1015  * Return:  0 OK; 1 Failed; 3 Skip(SVA disabled).
1016  */
1017 int Dsa_Init_Sysfs(void)
1018 {
1019 	uint len = ARRAY_SIZE(dsa_configs);
1020 	const char **p = dsa_configs;
1021 
1022 	if (file_Exists(dsaDeviceFile) == 1)
1023 		return 0;
1024 
1025 	/* check the idxd driver */
1026 	if (file_Exists(dsaPasidEnable) != 1) {
1027 		printf("Please make sure idxd driver was loaded\n");
1028 		return 3;
1029 	}
1030 
1031 	/* Check SVA feature */
1032 	if (Check_DSA_Kernel_Setting() != 1) {
1033 		printf("Please enable SVA.(Add intel_iommu=on,sm_on in kernel cmdline)\n");
1034 		return 3;
1035 	}
1036 
1037 	/* Check the idxd device file on /dev/dsa/ */
1038 	for (int i = 0; i < len; i++) {
1039 		if (system(p[i]))
1040 			return 1;
1041 	}
1042 
1043 	/* After config, /dev/dsa/wq0.1 should be generated */
1044 	return (file_Exists(dsaDeviceFile) != 1);
1045 }
1046 
1047 /*
1048  * Open DSA device file, triger API: iommu_sva_alloc_pasid
1049  */
1050 void *allocate_dsa_pasid(void)
1051 {
1052 	int fd;
1053 	void *wq;
1054 
1055 	fd = open(dsaDeviceFile, O_RDWR);
1056 	if (fd < 0) {
1057 		perror("open");
1058 		return MAP_FAILED;
1059 	}
1060 
1061 	wq = mmap(NULL, 0x1000, PROT_WRITE,
1062 			   MAP_SHARED | MAP_POPULATE, fd, 0);
1063 	if (wq == MAP_FAILED)
1064 		perror("mmap");
1065 
1066 	return wq;
1067 }
1068 
1069 int set_force_svm(void)
1070 {
1071 	int ret = 0;
1072 
1073 	ret = syscall(SYS_arch_prctl, ARCH_FORCE_TAGGED_SVA);
1074 
1075 	return ret;
1076 }
1077 
1078 int handle_pasid(struct testcases *test)
1079 {
1080 	uint tmp = test->cmd;
1081 	uint runed = 0x0;
1082 	int ret = 0;
1083 	void *wq = NULL;
1084 
1085 	ret = Dsa_Init_Sysfs();
1086 	if (ret != 0)
1087 		return ret;
1088 
1089 	for (int i = 0; i < 3; i++) {
1090 		int err = 0;
1091 
1092 		if (tmp & 0x1) {
1093 			/* run set lam mode*/
1094 			if ((runed & 0x1) == 0)	{
1095 				err = set_lam(LAM_U57_BITS);
1096 				runed = runed | 0x1;
1097 			} else
1098 				err = 1;
1099 		} else if (tmp & 0x4) {
1100 			/* run force svm */
1101 			if ((runed & 0x4) == 0)	{
1102 				err = set_force_svm();
1103 				runed = runed | 0x4;
1104 			} else
1105 				err = 1;
1106 		} else if (tmp & 0x2) {
1107 			/* run allocate pasid */
1108 			if ((runed & 0x2) == 0) {
1109 				runed = runed | 0x2;
1110 				wq = allocate_dsa_pasid();
1111 				if (wq == MAP_FAILED)
1112 					err = 1;
1113 			} else
1114 				err = 1;
1115 		}
1116 
1117 		ret = ret + err;
1118 		if (ret > 0)
1119 			break;
1120 
1121 		tmp = tmp >> 4;
1122 	}
1123 
1124 	if (wq != MAP_FAILED && wq != NULL)
1125 		if (munmap(wq, 0x1000))
1126 			printf("munmap failed %d\n", errno);
1127 
1128 	if (runed != 0x7)
1129 		ret = 1;
1130 
1131 	return (ret != 0);
1132 }
1133 
1134 /*
1135  * Pasid test depends on idxd and SVA, kernel should enable iommu and sm.
1136  * command line(intel_iommu=on,sm_on)
1137  */
1138 static struct testcases pasid_cases[] = {
1139 	{
1140 		.expected = 1,
1141 		.cmd = PAS_CMD(LAM_CMD_BIT, PAS_CMD_BIT, SVA_CMD_BIT),
1142 		.test_func = handle_pasid,
1143 		.msg = "PASID: [Negative] Execute LAM, PASID, SVA in sequence\n",
1144 	},
1145 	{
1146 		.expected = 0,
1147 		.cmd = PAS_CMD(LAM_CMD_BIT, SVA_CMD_BIT, PAS_CMD_BIT),
1148 		.test_func = handle_pasid,
1149 		.msg = "PASID: Execute LAM, SVA, PASID in sequence\n",
1150 	},
1151 	{
1152 		.expected = 1,
1153 		.cmd = PAS_CMD(PAS_CMD_BIT, LAM_CMD_BIT, SVA_CMD_BIT),
1154 		.test_func = handle_pasid,
1155 		.msg = "PASID: [Negative] Execute PASID, LAM, SVA in sequence\n",
1156 	},
1157 	{
1158 		.expected = 0,
1159 		.cmd = PAS_CMD(PAS_CMD_BIT, SVA_CMD_BIT, LAM_CMD_BIT),
1160 		.test_func = handle_pasid,
1161 		.msg = "PASID: Execute PASID, SVA, LAM in sequence\n",
1162 	},
1163 	{
1164 		.expected = 0,
1165 		.cmd = PAS_CMD(SVA_CMD_BIT, LAM_CMD_BIT, PAS_CMD_BIT),
1166 		.test_func = handle_pasid,
1167 		.msg = "PASID: Execute SVA, LAM, PASID in sequence\n",
1168 	},
1169 	{
1170 		.expected = 0,
1171 		.cmd = PAS_CMD(SVA_CMD_BIT, PAS_CMD_BIT, LAM_CMD_BIT),
1172 		.test_func = handle_pasid,
1173 		.msg = "PASID: Execute SVA, PASID, LAM in sequence\n",
1174 	},
1175 };
1176 
1177 int main(int argc, char **argv)
1178 {
1179 	int c = 0;
1180 	unsigned int tests = TEST_MASK;
1181 
1182 	tests_cnt = 0;
1183 
1184 	if (!cpu_has_lam()) {
1185 		ksft_print_msg("Unsupported LAM feature!\n");
1186 		return KSFT_SKIP;
1187 	}
1188 
1189 	while ((c = getopt(argc, argv, "ht:")) != -1) {
1190 		switch (c) {
1191 		case 't':
1192 			tests = strtoul(optarg, NULL, 16);
1193 			if (tests && !(tests & TEST_MASK)) {
1194 				ksft_print_msg("Invalid argument!\n");
1195 				return -1;
1196 			}
1197 			break;
1198 		case 'h':
1199 			cmd_help();
1200 			return 0;
1201 		default:
1202 			ksft_print_msg("Invalid argument\n");
1203 			return -1;
1204 		}
1205 	}
1206 
1207 	/*
1208 	 * When tests is 0, it is not a real test case;
1209 	 * the option used by test case(execve) to check the lam mode in
1210 	 * process generated by execve, the process read back lam mode and
1211 	 * check with lam mode in parent process.
1212 	 */
1213 	if (!tests)
1214 		return (get_lam());
1215 
1216 	/* Run test cases */
1217 	if (tests & FUNC_MALLOC)
1218 		run_test(malloc_cases, ARRAY_SIZE(malloc_cases));
1219 
1220 	if (tests & FUNC_BITS)
1221 		run_test(bits_cases, ARRAY_SIZE(bits_cases));
1222 
1223 	if (tests & FUNC_MMAP)
1224 		run_test(mmap_cases, ARRAY_SIZE(mmap_cases));
1225 
1226 	if (tests & FUNC_SYSCALL)
1227 		run_test(syscall_cases, ARRAY_SIZE(syscall_cases));
1228 
1229 	if (tests & FUNC_URING)
1230 		run_test(uring_cases, ARRAY_SIZE(uring_cases));
1231 
1232 	if (tests & FUNC_INHERITE)
1233 		run_test(inheritance_cases, ARRAY_SIZE(inheritance_cases));
1234 
1235 	if (tests & FUNC_PASID)
1236 		run_test(pasid_cases, ARRAY_SIZE(pasid_cases));
1237 
1238 	ksft_set_plan(tests_cnt);
1239 
1240 	ksft_exit_pass();
1241 }
1242