xref: /linux/tools/testing/selftests/arm64/fp/fp-ptrace.c (revision eb01fe7abbe2d0b38824d2a93fdb4cc3eaf2ccc1)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2023 ARM Limited.
4  * Original author: Mark Brown <broonie@kernel.org>
5  */
6 
7 #define _GNU_SOURCE
8 
#include <errno.h>
#include <signal.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <sys/auxv.h>
#include <sys/prctl.h>
#include <sys/ptrace.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <sys/wait.h>

#include <linux/kernel.h>

#include <asm/sigcontext.h>
#include <asm/sve_context.h>
#include <asm/ptrace.h>

#include "../../kselftest.h"

#include "fp-ptrace.h"
33 
/* <linux/elf.h> and <sys/auxv.h> don't like each other, so: */
/*
 * Fallback definitions of the regset note types passed to
 * PTRACE_GETREGSET/PTRACE_SETREGSET below, used only if the headers
 * included above did not already provide them.
 */
#ifndef NT_ARM_SVE
#define NT_ARM_SVE 0x405
#endif

#ifndef NT_ARM_SSVE
#define NT_ARM_SSVE 0x40b
#endif

#ifndef NT_ARM_ZA
#define NT_ARM_ZA 0x40c
#endif

#ifndef NT_ARM_ZT
#define NT_ARM_ZT 0x40d
#endif

/* Largest vq the statically sized register buffers below can hold */
#define ARCH_VQ_MAX 256

/* VL 128..2048 in powers of 2 */
#define MAX_NUM_VLS 5
55 
/* Number of V registers, each handled as a 128 bit value */
#define NUM_FPR 32

/*
 * Register images used by the tests. For each register set there are
 * three buffers:
 *
 *   *_in       - the initial values generated by set_initial_values()
 *   *_expected - the values the test expects to find afterwards
 *   *_out      - the values read back from the child's memory by
 *                read_child_regs()
 *
 * The buffers are not static; presumably they are also referenced by
 * the implementation of load_and_save() in another file - TODO confirm.
 */
__uint128_t v_in[NUM_FPR];
__uint128_t v_expected[NUM_FPR];
__uint128_t v_out[NUM_FPR];

/* SVE Z registers, sized for the largest supported vq */
char z_in[__SVE_ZREGS_SIZE(ARCH_VQ_MAX)];
char z_expected[__SVE_ZREGS_SIZE(ARCH_VQ_MAX)];
char z_out[__SVE_ZREGS_SIZE(ARCH_VQ_MAX)];

/* SVE P registers */
char p_in[__SVE_PREGS_SIZE(ARCH_VQ_MAX)];
char p_expected[__SVE_PREGS_SIZE(ARCH_VQ_MAX)];
char p_out[__SVE_PREGS_SIZE(ARCH_VQ_MAX)];

/* SVE FFR, the size of a single P register */
char ffr_in[__SVE_PREG_SIZE(ARCH_VQ_MAX)];
char ffr_expected[__SVE_PREG_SIZE(ARCH_VQ_MAX)];
char ffr_out[__SVE_PREG_SIZE(ARCH_VQ_MAX)];

/* SME ZA */
char za_in[ZA_SIG_REGS_SIZE(ARCH_VQ_MAX)];
char za_expected[ZA_SIG_REGS_SIZE(ARCH_VQ_MAX)];
char za_out[ZA_SIG_REGS_SIZE(ARCH_VQ_MAX)];

/* SME2 ZT */
char zt_in[ZT_SIG_REG_BYTES];
char zt_expected[ZT_SIG_REG_BYTES];
char zt_out[ZT_SIG_REG_BYTES];

/* Vector lengths and SVCR as read back from the child */
uint64_t sve_vl_out;
uint64_t sme_vl_out;
uint64_t svcr_in, svcr_expected, svcr_out;

/*
 * Defined outside this file. run_child() calls it to load the register
 * values and wait for the parent; the arguments are passed the results
 * of the sve/sme/sme2/fa64 feature checks.
 */
void load_and_save(int sve, int sme, int sme2, int fa64);
86 
/*
 * Set by handle_alarm() and polled by run_parent() while waiting for
 * the child to exit. Since it is written from a signal handler and
 * read from normal code it must be a volatile sig_atomic_t for the
 * access to be well defined.
 */
static volatile sig_atomic_t got_alarm;

/*
 * Signal handler which records that it has run by setting got_alarm.
 * Presumably installed for SIGALRM elsewhere in this file; run_parent()
 * arms the timer with alarm().
 *
 * @sig:     signal number, unused
 * @info:    signal information, unused
 * @context: interrupted context, unused
 */
static void handle_alarm(int sig, siginfo_t *info, void *context)
{
	(void)sig;
	(void)info;
	(void)context;

	got_alarm = 1;
}
93 
/*
 * Convert a 128 bit value from CPU byte order to little endian.
 *
 * CONFIG_CPU_BIG_ENDIAN is a kernel configuration symbol which is not
 * normally visible to userspace builds, so also key off the byte order
 * macros predefined by the compiler. The swap is done with the compiler
 * builtin since the kernel's u64/swab64() are not available here.
 *
 * @x: value in CPU byte order
 *
 * Returns @x with all 16 bytes reversed on big endian builds, @x
 * unchanged otherwise.
 */
#if defined(CONFIG_CPU_BIG_ENDIAN) || \
	(defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
	 __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
static __uint128_t arm64_cpu_to_le128(__uint128_t x)
{
	/* Swap the bytes within each half, then swap the halves */
	uint64_t a = __builtin_bswap64((uint64_t)x);
	uint64_t b = __builtin_bswap64((uint64_t)(x >> 64));

	return ((__uint128_t)a << 64) | b;
}
#else
static __uint128_t arm64_cpu_to_le128(__uint128_t x)
{
	return x;
}
#endif

/* A byte swap is its own inverse so one helper converts both ways */
#define arm64_le128_to_cpu(x) arm64_cpu_to_le128(x)
110 
111 static bool sve_supported(void)
112 {
113 	return getauxval(AT_HWCAP) & HWCAP_SVE;
114 }
115 
116 static bool sme_supported(void)
117 {
118 	return getauxval(AT_HWCAP2) & HWCAP2_SME;
119 }
120 
121 static bool sme2_supported(void)
122 {
123 	return getauxval(AT_HWCAP2) & HWCAP2_SME2;
124 }
125 
126 static bool fa64_supported(void)
127 {
128 	return getauxval(AT_HWCAP2) & HWCAP2_SME_FA64;
129 }
130 
/*
 * Compare a block of register data against the expected contents.
 *
 * @name:     description of the data, used in the diagnostics
 * @out:      the data that was actually seen
 * @expected: the data we wanted to see
 * @size:     number of bytes to compare
 *
 * Returns true if the buffers are identical. Otherwise a mismatch is
 * logged, with an additional note if @out consists entirely of zeros,
 * and false is returned.
 */
static bool compare_buffer(const char *name, void *out,
			   void *expected, size_t size)
{
	const unsigned char *bytes = out;
	size_t i;

	if (memcmp(out, expected, size) == 0)
		return true;

	ksft_print_msg("Mismatch in %s\n", name);

	/*
	 * Did we just get zeros back? Scan the data in place rather
	 * than allocating a zeroed buffer to compare with, that way
	 * producing the diagnostic can't itself fail.
	 */
	for (i = 0; i < size; i++)
		if (bytes[i])
			return false;

	ksft_print_msg("%s is zero\n", name);

	return false;
}
157 
/*
 * One combination of initial and expected final state for a test. The
 * *_in fields describe the state the child is started in, the
 * *_expected fields the state that should be read back from the child
 * once the test has run.
 */
struct test_config {
	int sve_vl_in;		/* SVE VL the child requests with PR_SVE_SET_VL */
	int sve_vl_expected;	/* SVE VL the child should report back */
	int sme_vl_in;		/* SME VL the child requests with PR_SME_SET_VL */
	int sme_vl_expected;	/* SME VL the child should report back */
	int svcr_in;		/* Initial SVCR, tested for SVCR_SM and SVCR_ZA */
	int svcr_expected;	/* SVCR the child should report back */
};
166 
/* Description of a test, combined with a struct test_config to run it. */
struct test_definition {
	/* Human readable name, used when building the reported test name */
	const char *name;
	/* Not referenced in this part of the file - TODO document */
	bool sve_vl_change;
	/*
	 * Returns true if the test can be run for the configuration;
	 * presumably called by the test runner before running - TODO
	 * confirm against the caller.
	 */
	bool (*supported)(struct test_config *config);
	/*
	 * Updates the *_expected buffers for the configuration;
	 * presumably called before the child is started - TODO confirm.
	 */
	void (*set_expected_values)(struct test_config *config);
	/*
	 * Optional. Called by run_parent() once the initial values have
	 * been checked and before the child is resumed.
	 */
	void (*modify_values)(pid_t child, struct test_config *test_config);
};
174 
175 static int vl_in(struct test_config *config)
176 {
177 	int vl;
178 
179 	if (config->svcr_in & SVCR_SM)
180 		vl = config->sme_vl_in;
181 	else
182 		vl = config->sve_vl_in;
183 
184 	return vl;
185 }
186 
187 static int vl_expected(struct test_config *config)
188 {
189 	int vl;
190 
191 	if (config->svcr_expected & SVCR_SM)
192 		vl = config->sme_vl_expected;
193 	else
194 		vl = config->sve_vl_expected;
195 
196 	return vl;
197 }
198 
199 static void run_child(struct test_config *config)
200 {
201 	int ret;
202 
203 	/* Let the parent attach to us */
204 	ret = ptrace(PTRACE_TRACEME, 0, 0, 0);
205 	if (ret < 0)
206 		ksft_exit_fail_msg("PTRACE_TRACEME failed: %s (%d)\n",
207 				   strerror(errno), errno);
208 
209 	/* VL setup */
210 	if (sve_supported()) {
211 		ret = prctl(PR_SVE_SET_VL, config->sve_vl_in);
212 		if (ret != config->sve_vl_in) {
213 			ksft_print_msg("Failed to set SVE VL %d: %d\n",
214 				       config->sve_vl_in, ret);
215 		}
216 	}
217 
218 	if (sme_supported()) {
219 		ret = prctl(PR_SME_SET_VL, config->sme_vl_in);
220 		if (ret != config->sme_vl_in) {
221 			ksft_print_msg("Failed to set SME VL %d: %d\n",
222 				       config->sme_vl_in, ret);
223 		}
224 	}
225 
226 	/* Load values and wait for the parent */
227 	load_and_save(sve_supported(), sme_supported(),
228 		      sme2_supported(), fa64_supported());
229 
230 	exit(0);
231 }
232 
/*
 * Copy one region of the child's memory into this process.
 *
 * @child:      PID of the child to read from
 * @name:       description of the data, used in diagnostics
 * @iov_parent: where to store the data in this process
 * @iov_child:  where to read the data from in the child
 *
 * Failures and short reads are logged but not otherwise reported to
 * the caller.
 */
static void read_one_child_regs(pid_t child, char *name,
				struct iovec *iov_parent,
				struct iovec *iov_child)
{
	/*
	 * Keep the length and the result in the types used by the
	 * API, process_vm_readv() returns a ssize_t byte count.
	 */
	size_t len = iov_parent->iov_len;
	ssize_t ret;

	ret = process_vm_readv(child, iov_parent, 1, iov_child, 1, 0);
	if (ret == -1)
		ksft_print_msg("%s read failed: %s (%d)\n",
			       name, strerror(errno), errno);
	else if ((size_t)ret != len)
		ksft_print_msg("Short read of %s: %zd\n", name, ret);
}
247 
248 static void read_child_regs(pid_t child)
249 {
250 	struct iovec iov_parent, iov_child;
251 
252 	/*
253 	 * Since the child fork()ed from us the buffer addresses are
254 	 * the same in parent and child.
255 	 */
256 	iov_parent.iov_base = &v_out;
257 	iov_parent.iov_len = sizeof(v_out);
258 	iov_child.iov_base = &v_out;
259 	iov_child.iov_len = sizeof(v_out);
260 	read_one_child_regs(child, "FPSIMD", &iov_parent, &iov_child);
261 
262 	if (sve_supported() || sme_supported()) {
263 		iov_parent.iov_base = &sve_vl_out;
264 		iov_parent.iov_len = sizeof(sve_vl_out);
265 		iov_child.iov_base = &sve_vl_out;
266 		iov_child.iov_len = sizeof(sve_vl_out);
267 		read_one_child_regs(child, "SVE VL", &iov_parent, &iov_child);
268 
269 		iov_parent.iov_base = &z_out;
270 		iov_parent.iov_len = sizeof(z_out);
271 		iov_child.iov_base = &z_out;
272 		iov_child.iov_len = sizeof(z_out);
273 		read_one_child_regs(child, "Z", &iov_parent, &iov_child);
274 
275 		iov_parent.iov_base = &p_out;
276 		iov_parent.iov_len = sizeof(p_out);
277 		iov_child.iov_base = &p_out;
278 		iov_child.iov_len = sizeof(p_out);
279 		read_one_child_regs(child, "P", &iov_parent, &iov_child);
280 
281 		iov_parent.iov_base = &ffr_out;
282 		iov_parent.iov_len = sizeof(ffr_out);
283 		iov_child.iov_base = &ffr_out;
284 		iov_child.iov_len = sizeof(ffr_out);
285 		read_one_child_regs(child, "FFR", &iov_parent, &iov_child);
286 	}
287 
288 	if (sme_supported()) {
289 		iov_parent.iov_base = &sme_vl_out;
290 		iov_parent.iov_len = sizeof(sme_vl_out);
291 		iov_child.iov_base = &sme_vl_out;
292 		iov_child.iov_len = sizeof(sme_vl_out);
293 		read_one_child_regs(child, "SME VL", &iov_parent, &iov_child);
294 
295 		iov_parent.iov_base = &svcr_out;
296 		iov_parent.iov_len = sizeof(svcr_out);
297 		iov_child.iov_base = &svcr_out;
298 		iov_child.iov_len = sizeof(svcr_out);
299 		read_one_child_regs(child, "SVCR", &iov_parent, &iov_child);
300 
301 		iov_parent.iov_base = &za_out;
302 		iov_parent.iov_len = sizeof(za_out);
303 		iov_child.iov_base = &za_out;
304 		iov_child.iov_len = sizeof(za_out);
305 		read_one_child_regs(child, "ZA", &iov_parent, &iov_child);
306 	}
307 
308 	if (sme2_supported()) {
309 		iov_parent.iov_base = &zt_out;
310 		iov_parent.iov_len = sizeof(zt_out);
311 		iov_child.iov_base = &zt_out;
312 		iov_child.iov_len = sizeof(zt_out);
313 		read_one_child_regs(child, "ZT", &iov_parent, &iov_child);
314 	}
315 }
316 
317 static bool continue_breakpoint(pid_t child,
318 				enum __ptrace_request restart_type)
319 {
320 	struct user_pt_regs pt_regs;
321 	struct iovec iov;
322 	int ret;
323 
324 	/* Get PC */
325 	iov.iov_base = &pt_regs;
326 	iov.iov_len = sizeof(pt_regs);
327 	ret = ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov);
328 	if (ret < 0) {
329 		ksft_print_msg("Failed to get PC: %s (%d)\n",
330 			       strerror(errno), errno);
331 		return false;
332 	}
333 
334 	/* Skip over the BRK */
335 	pt_regs.pc += 4;
336 	ret = ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov);
337 	if (ret < 0) {
338 		ksft_print_msg("Failed to skip BRK: %s (%d)\n",
339 			       strerror(errno), errno);
340 		return false;
341 	}
342 
343 	/* Restart */
344 	ret = ptrace(restart_type, child, 0, 0);
345 	if (ret < 0) {
346 		ksft_print_msg("Failed to restart child: %s (%d)\n",
347 			       strerror(errno), errno);
348 		return false;
349 	}
350 
351 	return true;
352 }
353 
354 static bool check_ptrace_values_sve(pid_t child, struct test_config *config)
355 {
356 	struct user_sve_header *sve;
357 	struct user_fpsimd_state *fpsimd;
358 	struct iovec iov;
359 	int ret, vq;
360 	bool pass = true;
361 
362 	if (!sve_supported())
363 		return true;
364 
365 	vq = __sve_vq_from_vl(config->sve_vl_in);
366 
367 	iov.iov_len = SVE_PT_SVE_OFFSET + SVE_PT_SVE_SIZE(vq, SVE_PT_REGS_SVE);
368 	iov.iov_base = malloc(iov.iov_len);
369 	if (!iov.iov_base) {
370 		ksft_print_msg("OOM allocating %lu byte SVE buffer\n",
371 			       iov.iov_len);
372 		return false;
373 	}
374 
375 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_SVE, &iov);
376 	if (ret != 0) {
377 		ksft_print_msg("Failed to read initial SVE: %s (%d)\n",
378 			       strerror(errno), errno);
379 		pass = false;
380 		goto out;
381 	}
382 
383 	sve = iov.iov_base;
384 
385 	if (sve->vl != config->sve_vl_in) {
386 		ksft_print_msg("Mismatch in initial SVE VL: %d != %d\n",
387 			       sve->vl, config->sve_vl_in);
388 		pass = false;
389 	}
390 
391 	/* If we are in streaming mode we should just read FPSIMD */
392 	if ((config->svcr_in & SVCR_SM) && (sve->flags & SVE_PT_REGS_SVE)) {
393 		ksft_print_msg("NT_ARM_SVE reports SVE with PSTATE.SM\n");
394 		pass = false;
395 	}
396 
397 	if (sve->size != SVE_PT_SIZE(vq, sve->flags)) {
398 		ksft_print_msg("Mismatch in SVE header size: %d != %lu\n",
399 			       sve->size, SVE_PT_SIZE(vq, sve->flags));
400 		pass = false;
401 	}
402 
403 	/* The registers might be in completely different formats! */
404 	if (sve->flags & SVE_PT_REGS_SVE) {
405 		if (!compare_buffer("initial SVE Z",
406 				    iov.iov_base + SVE_PT_SVE_ZREG_OFFSET(vq, 0),
407 				    z_in, SVE_PT_SVE_ZREGS_SIZE(vq)))
408 			pass = false;
409 
410 		if (!compare_buffer("initial SVE P",
411 				    iov.iov_base + SVE_PT_SVE_PREG_OFFSET(vq, 0),
412 				    p_in, SVE_PT_SVE_PREGS_SIZE(vq)))
413 			pass = false;
414 
415 		if (!compare_buffer("initial SVE FFR",
416 				    iov.iov_base + SVE_PT_SVE_FFR_OFFSET(vq),
417 				    ffr_in, SVE_PT_SVE_PREG_SIZE(vq)))
418 			pass = false;
419 	} else {
420 		fpsimd = iov.iov_base + SVE_PT_FPSIMD_OFFSET;
421 		if (!compare_buffer("initial V via SVE", &fpsimd->vregs[0],
422 				    v_in, sizeof(v_in)))
423 			pass = false;
424 	}
425 
426 out:
427 	free(iov.iov_base);
428 	return pass;
429 }
430 
431 static bool check_ptrace_values_ssve(pid_t child, struct test_config *config)
432 {
433 	struct user_sve_header *sve;
434 	struct user_fpsimd_state *fpsimd;
435 	struct iovec iov;
436 	int ret, vq;
437 	bool pass = true;
438 
439 	if (!sme_supported())
440 		return true;
441 
442 	vq = __sve_vq_from_vl(config->sme_vl_in);
443 
444 	iov.iov_len = SVE_PT_SVE_OFFSET + SVE_PT_SVE_SIZE(vq, SVE_PT_REGS_SVE);
445 	iov.iov_base = malloc(iov.iov_len);
446 	if (!iov.iov_base) {
447 		ksft_print_msg("OOM allocating %lu byte SSVE buffer\n",
448 			       iov.iov_len);
449 		return false;
450 	}
451 
452 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_SSVE, &iov);
453 	if (ret != 0) {
454 		ksft_print_msg("Failed to read initial SSVE: %s (%d)\n",
455 			       strerror(errno), errno);
456 		pass = false;
457 		goto out;
458 	}
459 
460 	sve = iov.iov_base;
461 
462 	if (sve->vl != config->sme_vl_in) {
463 		ksft_print_msg("Mismatch in initial SSVE VL: %d != %d\n",
464 			       sve->vl, config->sme_vl_in);
465 		pass = false;
466 	}
467 
468 	if ((config->svcr_in & SVCR_SM) && !(sve->flags & SVE_PT_REGS_SVE)) {
469 		ksft_print_msg("NT_ARM_SSVE reports FPSIMD with PSTATE.SM\n");
470 		pass = false;
471 	}
472 
473 	if (sve->size != SVE_PT_SIZE(vq, sve->flags)) {
474 		ksft_print_msg("Mismatch in SSVE header size: %d != %lu\n",
475 			       sve->size, SVE_PT_SIZE(vq, sve->flags));
476 		pass = false;
477 	}
478 
479 	/* The registers might be in completely different formats! */
480 	if (sve->flags & SVE_PT_REGS_SVE) {
481 		if (!compare_buffer("initial SSVE Z",
482 				    iov.iov_base + SVE_PT_SVE_ZREG_OFFSET(vq, 0),
483 				    z_in, SVE_PT_SVE_ZREGS_SIZE(vq)))
484 			pass = false;
485 
486 		if (!compare_buffer("initial SSVE P",
487 				    iov.iov_base + SVE_PT_SVE_PREG_OFFSET(vq, 0),
488 				    p_in, SVE_PT_SVE_PREGS_SIZE(vq)))
489 			pass = false;
490 
491 		if (!compare_buffer("initial SSVE FFR",
492 				    iov.iov_base + SVE_PT_SVE_FFR_OFFSET(vq),
493 				    ffr_in, SVE_PT_SVE_PREG_SIZE(vq)))
494 			pass = false;
495 	} else {
496 		fpsimd = iov.iov_base + SVE_PT_FPSIMD_OFFSET;
497 		if (!compare_buffer("initial V via SSVE",
498 				    &fpsimd->vregs[0], v_in, sizeof(v_in)))
499 			pass = false;
500 	}
501 
502 out:
503 	free(iov.iov_base);
504 	return pass;
505 }
506 
507 static bool check_ptrace_values_za(pid_t child, struct test_config *config)
508 {
509 	struct user_za_header *za;
510 	struct iovec iov;
511 	int ret, vq;
512 	bool pass = true;
513 
514 	if (!sme_supported())
515 		return true;
516 
517 	vq = __sve_vq_from_vl(config->sme_vl_in);
518 
519 	iov.iov_len = ZA_SIG_CONTEXT_SIZE(vq);
520 	iov.iov_base = malloc(iov.iov_len);
521 	if (!iov.iov_base) {
522 		ksft_print_msg("OOM allocating %lu byte ZA buffer\n",
523 			       iov.iov_len);
524 		return false;
525 	}
526 
527 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_ZA, &iov);
528 	if (ret != 0) {
529 		ksft_print_msg("Failed to read initial ZA: %s (%d)\n",
530 			       strerror(errno), errno);
531 		pass = false;
532 		goto out;
533 	}
534 
535 	za = iov.iov_base;
536 
537 	if (za->vl != config->sme_vl_in) {
538 		ksft_print_msg("Mismatch in initial SME VL: %d != %d\n",
539 			       za->vl, config->sme_vl_in);
540 		pass = false;
541 	}
542 
543 	/* If PSTATE.ZA is not set we should just read the header */
544 	if (config->svcr_in & SVCR_ZA) {
545 		if (za->size != ZA_PT_SIZE(vq)) {
546 			ksft_print_msg("Unexpected ZA ptrace read size: %d != %lu\n",
547 				       za->size, ZA_PT_SIZE(vq));
548 			pass = false;
549 		}
550 
551 		if (!compare_buffer("initial ZA",
552 				    iov.iov_base + ZA_PT_ZA_OFFSET,
553 				    za_in, ZA_PT_ZA_SIZE(vq)))
554 			pass = false;
555 	} else {
556 		if (za->size != sizeof(*za)) {
557 			ksft_print_msg("Unexpected ZA ptrace read size: %d != %lu\n",
558 				       za->size, sizeof(*za));
559 			pass = false;
560 		}
561 	}
562 
563 out:
564 	free(iov.iov_base);
565 	return pass;
566 }
567 
568 static bool check_ptrace_values_zt(pid_t child, struct test_config *config)
569 {
570 	uint8_t buf[512];
571 	struct iovec iov;
572 	int ret;
573 
574 	if (!sme2_supported())
575 		return true;
576 
577 	iov.iov_base = &buf;
578 	iov.iov_len = ZT_SIG_REG_BYTES;
579 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_ZT, &iov);
580 	if (ret != 0) {
581 		ksft_print_msg("Failed to read initial ZT: %s (%d)\n",
582 			       strerror(errno), errno);
583 		return false;
584 	}
585 
586 	return compare_buffer("initial ZT", buf, zt_in, ZT_SIG_REG_BYTES);
587 }
588 
589 
590 static bool check_ptrace_values(pid_t child, struct test_config *config)
591 {
592 	bool pass = true;
593 	struct user_fpsimd_state fpsimd;
594 	struct iovec iov;
595 	int ret;
596 
597 	iov.iov_base = &fpsimd;
598 	iov.iov_len = sizeof(fpsimd);
599 	ret = ptrace(PTRACE_GETREGSET, child, NT_PRFPREG, &iov);
600 	if (ret == 0) {
601 		if (!compare_buffer("initial V", &fpsimd.vregs, v_in,
602 				    sizeof(v_in))) {
603 			pass = false;
604 		}
605 	} else {
606 		ksft_print_msg("Failed to read initial V: %s (%d)\n",
607 			       strerror(errno), errno);
608 		pass = false;
609 	}
610 
611 	if (!check_ptrace_values_sve(child, config))
612 		pass = false;
613 
614 	if (!check_ptrace_values_ssve(child, config))
615 		pass = false;
616 
617 	if (!check_ptrace_values_za(child, config))
618 		pass = false;
619 
620 	if (!check_ptrace_values_zt(child, config))
621 		pass = false;
622 
623 	return pass;
624 }
625 
626 static bool run_parent(pid_t child, struct test_definition *test,
627 		       struct test_config *config)
628 {
629 	int wait_status, ret;
630 	pid_t pid;
631 	bool pass;
632 
633 	/* Initial attach */
634 	while (1) {
635 		pid = waitpid(child, &wait_status, 0);
636 		if (pid < 0) {
637 			if (errno == EINTR)
638 				continue;
639 			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
640 					   strerror(errno), errno);
641 		}
642 
643 		if (pid == child)
644 			break;
645 	}
646 
647 	if (WIFEXITED(wait_status)) {
648 		ksft_print_msg("Child exited loading values with status %d\n",
649 			       WEXITSTATUS(wait_status));
650 		pass = false;
651 		goto out;
652 	}
653 
654 	if (WIFSIGNALED(wait_status)) {
655 		ksft_print_msg("Child died from signal %d loading values\n",
656 			       WTERMSIG(wait_status));
657 		pass = false;
658 		goto out;
659 	}
660 
661 	/* Read initial values via ptrace */
662 	pass = check_ptrace_values(child, config);
663 
664 	/* Do whatever writes we want to do */
665 	if (test->modify_values)
666 		test->modify_values(child, config);
667 
668 	if (!continue_breakpoint(child, PTRACE_CONT))
669 		goto cleanup;
670 
671 	while (1) {
672 		pid = waitpid(child, &wait_status, 0);
673 		if (pid < 0) {
674 			if (errno == EINTR)
675 				continue;
676 			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
677 					   strerror(errno), errno);
678 		}
679 
680 		if (pid == child)
681 			break;
682 	}
683 
684 	if (WIFEXITED(wait_status)) {
685 		ksft_print_msg("Child exited saving values with status %d\n",
686 			       WEXITSTATUS(wait_status));
687 		pass = false;
688 		goto out;
689 	}
690 
691 	if (WIFSIGNALED(wait_status)) {
692 		ksft_print_msg("Child died from signal %d saving values\n",
693 			       WTERMSIG(wait_status));
694 		pass = false;
695 		goto out;
696 	}
697 
698 	/* See what happened as a result */
699 	read_child_regs(child);
700 
701 	if (!continue_breakpoint(child, PTRACE_DETACH))
702 		goto cleanup;
703 
704 	/* The child should exit cleanly */
705 	got_alarm = false;
706 	alarm(1);
707 	while (1) {
708 		if (got_alarm) {
709 			ksft_print_msg("Wait for child timed out\n");
710 			goto cleanup;
711 		}
712 
713 		pid = waitpid(child, &wait_status, 0);
714 		if (pid < 0) {
715 			if (errno == EINTR)
716 				continue;
717 			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
718 					   strerror(errno), errno);
719 		}
720 
721 		if (pid == child)
722 			break;
723 	}
724 	alarm(0);
725 
726 	if (got_alarm) {
727 		ksft_print_msg("Timed out waiting for child\n");
728 		pass = false;
729 		goto cleanup;
730 	}
731 
732 	if (pid == child && WIFSIGNALED(wait_status)) {
733 		ksft_print_msg("Child died from signal %d cleaning up\n",
734 			       WTERMSIG(wait_status));
735 		pass = false;
736 		goto out;
737 	}
738 
739 	if (pid == child && WIFEXITED(wait_status)) {
740 		if (WEXITSTATUS(wait_status) != 0) {
741 			ksft_print_msg("Child exited with error %d\n",
742 				       WEXITSTATUS(wait_status));
743 			pass = false;
744 		}
745 	} else {
746 		ksft_print_msg("Child did not exit cleanly\n");
747 		pass = false;
748 		goto cleanup;
749 	}
750 
751 	goto out;
752 
753 cleanup:
754 	ret = kill(child, SIGKILL);
755 	if (ret != 0) {
756 		ksft_print_msg("kill() failed: %s (%d)\n",
757 			       strerror(errno), errno);
758 		return false;
759 	}
760 
761 	while (1) {
762 		pid = waitpid(child, &wait_status, 0);
763 		if (pid < 0) {
764 			if (errno == EINTR)
765 				continue;
766 			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
767 					   strerror(errno), errno);
768 		}
769 
770 		if (pid == child)
771 			break;
772 	}
773 
774 out:
775 	return pass;
776 }
777 
/*
 * Fill a buffer with pseudo-random data from random().
 *
 * @buf:  buffer to fill, no particular alignment is required
 * @size: number of bytes to fill
 *
 * One value from random() is used for each 32 bit word, copied in with
 * memcpy() so that byte buffers can be filled safely. If @size is not
 * a multiple of the word size the remaining bytes are taken from one
 * further value rather than being left untouched.
 */
static void fill_random(void *buf, size_t size)
{
	unsigned char *bytes = buf;
	uint32_t word;
	size_t chunk;

	while (size) {
		/* random() returns a 32 bit number regardless of the size of long */
		word = random();

		chunk = size < sizeof(word) ? size : sizeof(word);
		memcpy(bytes, &word, chunk);

		bytes += chunk;
		size -= chunk;
	}
}
787 
/*
 * Generate a random but valid FFR value.
 *
 * @buf: buffer of at least __SVE_FFR_SIZE(vq) bytes to fill
 * @vq:  vector length the value is for, in quadwords
 *
 * Only values with a continuous set of 0..n bits set are valid for
 * FFR, so clear the whole register and then set a random number of the
 * low bits. With a @vq of 0 there is nothing to fill and the function
 * returns without consuming a random number.
 */
static void fill_random_ffr(void *buf, size_t vq)
{
	uint8_t *lbuf = buf;
	size_t size = __SVE_FFR_SIZE(vq);
	size_t bits, full_bytes;

	memset(buf, 0, size);

	/* Avoid taking a modulus of zero below */
	if (!size)
		return;

	bits = random() % (size * 8);
	full_bytes = bits / 8;

	memset(lbuf, 0xff, full_bytes);

	/*
	 * bits is always less than the number of bits in the register
	 * so there is always a partially (possibly not at all) set
	 * byte after the full ones.
	 */
	lbuf[full_bytes] = (1 << (bits % 8)) - 1;
}
806 
807 static void fpsimd_to_sve(__uint128_t *v, char *z, int vl)
808 {
809 	int vq = __sve_vq_from_vl(vl);
810 	int i;
811 	__uint128_t *p;
812 
813 	if (!vl)
814 		return;
815 
816 	for (i = 0; i < __SVE_NUM_ZREGS; i++) {
817 		p = (__uint128_t *)&z[__SVE_ZREG_OFFSET(vq, i)];
818 		*p = arm64_cpu_to_le128(v[i]);
819 	}
820 }
821 
822 static void set_initial_values(struct test_config *config)
823 {
824 	int vq = __sve_vq_from_vl(vl_in(config));
825 	int sme_vq = __sve_vq_from_vl(config->sme_vl_in);
826 
827 	svcr_in = config->svcr_in;
828 	svcr_expected = config->svcr_expected;
829 	svcr_out = 0;
830 
831 	fill_random(&v_in, sizeof(v_in));
832 	memcpy(v_expected, v_in, sizeof(v_in));
833 	memset(v_out, 0, sizeof(v_out));
834 
835 	/* Changes will be handled in the test case */
836 	if (sve_supported() || (config->svcr_in & SVCR_SM)) {
837 		/* The low 128 bits of Z are shared with the V registers */
838 		fill_random(&z_in, __SVE_ZREGS_SIZE(vq));
839 		fpsimd_to_sve(v_in, z_in, vl_in(config));
840 		memcpy(z_expected, z_in, __SVE_ZREGS_SIZE(vq));
841 		memset(z_out, 0, sizeof(z_out));
842 
843 		fill_random(&p_in, __SVE_PREGS_SIZE(vq));
844 		memcpy(p_expected, p_in, __SVE_PREGS_SIZE(vq));
845 		memset(p_out, 0, sizeof(p_out));
846 
847 		if ((config->svcr_in & SVCR_SM) && !fa64_supported())
848 			memset(ffr_in, 0, __SVE_PREG_SIZE(vq));
849 		else
850 			fill_random_ffr(&ffr_in, vq);
851 		memcpy(ffr_expected, ffr_in, __SVE_PREG_SIZE(vq));
852 		memset(ffr_out, 0, __SVE_PREG_SIZE(vq));
853 	}
854 
855 	if (config->svcr_in & SVCR_ZA)
856 		fill_random(za_in, ZA_SIG_REGS_SIZE(sme_vq));
857 	else
858 		memset(za_in, 0, ZA_SIG_REGS_SIZE(sme_vq));
859 	if (config->svcr_expected & SVCR_ZA)
860 		memcpy(za_expected, za_in, ZA_SIG_REGS_SIZE(sme_vq));
861 	else
862 		memset(za_expected, 0, ZA_SIG_REGS_SIZE(sme_vq));
863 	if (sme_supported())
864 		memset(za_out, 0, sizeof(za_out));
865 
866 	if (sme2_supported()) {
867 		if (config->svcr_in & SVCR_ZA)
868 			fill_random(zt_in, ZT_SIG_REG_BYTES);
869 		else
870 			memset(zt_in, 0, ZT_SIG_REG_BYTES);
871 		if (config->svcr_expected & SVCR_ZA)
872 			memcpy(zt_expected, zt_in, ZT_SIG_REG_BYTES);
873 		else
874 			memset(zt_expected, 0, ZT_SIG_REG_BYTES);
875 		memset(zt_out, 0, sizeof(zt_out));
876 	}
877 }
878 
879 static bool check_memory_values(struct test_config *config)
880 {
881 	bool pass = true;
882 	int vq, sme_vq;
883 
884 	if (!compare_buffer("saved V", v_out, v_expected, sizeof(v_out)))
885 		pass = false;
886 
887 	vq = __sve_vq_from_vl(vl_expected(config));
888 	sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
889 
890 	if (svcr_out != svcr_expected) {
891 		ksft_print_msg("Mismatch in saved SVCR %lx != %lx\n",
892 			       svcr_out, svcr_expected);
893 		pass = false;
894 	}
895 
896 	if (sve_vl_out != config->sve_vl_expected) {
897 		ksft_print_msg("Mismatch in SVE VL: %ld != %d\n",
898 			       sve_vl_out, config->sve_vl_expected);
899 		pass = false;
900 	}
901 
902 	if (sme_vl_out != config->sme_vl_expected) {
903 		ksft_print_msg("Mismatch in SME VL: %ld != %d\n",
904 			       sme_vl_out, config->sme_vl_expected);
905 		pass = false;
906 	}
907 
908 	if (!compare_buffer("saved Z", z_out, z_expected,
909 			    __SVE_ZREGS_SIZE(vq)))
910 		pass = false;
911 
912 	if (!compare_buffer("saved P", p_out, p_expected,
913 			    __SVE_PREGS_SIZE(vq)))
914 		pass = false;
915 
916 	if (!compare_buffer("saved FFR", ffr_out, ffr_expected,
917 			    __SVE_PREG_SIZE(vq)))
918 		pass = false;
919 
920 	if (!compare_buffer("saved ZA", za_out, za_expected,
921 			    ZA_PT_ZA_SIZE(sme_vq)))
922 		pass = false;
923 
924 	if (!compare_buffer("saved ZT", zt_out, zt_expected, ZT_SIG_REG_BYTES))
925 		pass = false;
926 
927 	return pass;
928 }
929 
930 static bool sve_sme_same(struct test_config *config)
931 {
932 	if (config->sve_vl_in != config->sve_vl_expected)
933 		return false;
934 
935 	if (config->sme_vl_in != config->sme_vl_expected)
936 		return false;
937 
938 	if (config->svcr_in != config->svcr_expected)
939 		return false;
940 
941 	return true;
942 }
943 
944 static bool sve_write_supported(struct test_config *config)
945 {
946 	if (!sve_supported() && !sme_supported())
947 		return false;
948 
949 	if ((config->svcr_in & SVCR_ZA) != (config->svcr_expected & SVCR_ZA))
950 		return false;
951 
952 	if (config->svcr_expected & SVCR_SM) {
953 		if (config->sve_vl_in != config->sve_vl_expected) {
954 			return false;
955 		}
956 
957 		/* Changing the SME VL disables ZA */
958 		if ((config->svcr_expected & SVCR_ZA) &&
959 		    (config->sme_vl_in != config->sme_vl_expected)) {
960 			return false;
961 		}
962 	} else {
963 		if (config->sme_vl_in != config->sme_vl_expected) {
964 			return false;
965 		}
966 	}
967 
968 	return true;
969 }
970 
971 static void fpsimd_write_expected(struct test_config *config)
972 {
973 	int vl;
974 
975 	fill_random(&v_expected, sizeof(v_expected));
976 
977 	/* The SVE registers are flushed by a FPSIMD write */
978 	vl = vl_expected(config);
979 
980 	memset(z_expected, 0, __SVE_ZREGS_SIZE(__sve_vq_from_vl(vl)));
981 	memset(p_expected, 0, __SVE_PREGS_SIZE(__sve_vq_from_vl(vl)));
982 	memset(ffr_expected, 0, __SVE_PREG_SIZE(__sve_vq_from_vl(vl)));
983 
984 	fpsimd_to_sve(v_expected, z_expected, vl);
985 }
986 
987 static void fpsimd_write(pid_t child, struct test_config *test_config)
988 {
989 	struct user_fpsimd_state fpsimd;
990 	struct iovec iov;
991 	int ret;
992 
993 	memset(&fpsimd, 0, sizeof(fpsimd));
994 	memcpy(&fpsimd.vregs, v_expected, sizeof(v_expected));
995 
996 	iov.iov_base = &fpsimd;
997 	iov.iov_len = sizeof(fpsimd);
998 	ret = ptrace(PTRACE_SETREGSET, child, NT_PRFPREG, &iov);
999 	if (ret == -1)
1000 		ksft_print_msg("FPSIMD set failed: (%s) %d\n",
1001 			       strerror(errno), errno);
1002 }
1003 
1004 static void sve_write_expected(struct test_config *config)
1005 {
1006 	int vl = vl_expected(config);
1007 	int sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
1008 
1009 	fill_random(z_expected, __SVE_ZREGS_SIZE(__sve_vq_from_vl(vl)));
1010 	fill_random(p_expected, __SVE_PREGS_SIZE(__sve_vq_from_vl(vl)));
1011 
1012 	if ((svcr_expected & SVCR_SM) && !fa64_supported())
1013 		memset(ffr_expected, 0, __SVE_PREG_SIZE(sme_vq));
1014 	else
1015 		fill_random_ffr(ffr_expected, __sve_vq_from_vl(vl));
1016 
1017 	/* Share the low bits of Z with V */
1018 	fill_random(&v_expected, sizeof(v_expected));
1019 	fpsimd_to_sve(v_expected, z_expected, vl);
1020 
1021 	if (config->sme_vl_in != config->sme_vl_expected) {
1022 		memset(za_expected, 0, ZA_PT_ZA_SIZE(sme_vq));
1023 		memset(zt_expected, 0, sizeof(zt_expected));
1024 	}
1025 }
1026 
1027 static void sve_write(pid_t child, struct test_config *config)
1028 {
1029 	struct user_sve_header *sve;
1030 	struct iovec iov;
1031 	int ret, vl, vq, regset;
1032 
1033 	vl = vl_expected(config);
1034 	vq = __sve_vq_from_vl(vl);
1035 
1036 	iov.iov_len = SVE_PT_SVE_OFFSET + SVE_PT_SVE_SIZE(vq, SVE_PT_REGS_SVE);
1037 	iov.iov_base = malloc(iov.iov_len);
1038 	if (!iov.iov_base) {
1039 		ksft_print_msg("Failed allocating %lu byte SVE write buffer\n",
1040 			       iov.iov_len);
1041 		return;
1042 	}
1043 	memset(iov.iov_base, 0, iov.iov_len);
1044 
1045 	sve = iov.iov_base;
1046 	sve->size = iov.iov_len;
1047 	sve->flags = SVE_PT_REGS_SVE;
1048 	sve->vl = vl;
1049 
1050 	memcpy(iov.iov_base + SVE_PT_SVE_ZREG_OFFSET(vq, 0),
1051 	       z_expected, SVE_PT_SVE_ZREGS_SIZE(vq));
1052 	memcpy(iov.iov_base + SVE_PT_SVE_PREG_OFFSET(vq, 0),
1053 	       p_expected, SVE_PT_SVE_PREGS_SIZE(vq));
1054 	memcpy(iov.iov_base + SVE_PT_SVE_FFR_OFFSET(vq),
1055 	       ffr_expected, SVE_PT_SVE_PREG_SIZE(vq));
1056 
1057 	if (svcr_expected & SVCR_SM)
1058 		regset = NT_ARM_SSVE;
1059 	else
1060 		regset = NT_ARM_SVE;
1061 
1062 	ret = ptrace(PTRACE_SETREGSET, child, regset, &iov);
1063 	if (ret != 0)
1064 		ksft_print_msg("Failed to write SVE: %s (%d)\n",
1065 			       strerror(errno), errno);
1066 
1067 	free(iov.iov_base);
1068 }
1069 
1070 static bool za_write_supported(struct test_config *config)
1071 {
1072 	if (config->svcr_expected & SVCR_SM) {
1073 		if (!(config->svcr_in & SVCR_SM))
1074 			return false;
1075 
1076 		/* Changing the SME VL exits streaming mode */
1077 		if (config->sme_vl_in != config->sme_vl_expected) {
1078 			return false;
1079 		}
1080 	}
1081 
1082 	/* Can't disable SM outside a VL change */
1083 	if ((config->svcr_in & SVCR_SM) &&
1084 	    !(config->svcr_expected & SVCR_SM))
1085 		return false;
1086 
1087 	return true;
1088 }
1089 
1090 static void za_write_expected(struct test_config *config)
1091 {
1092 	int sme_vq, sve_vq;
1093 
1094 	sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
1095 
1096 	if (config->svcr_expected & SVCR_ZA) {
1097 		fill_random(za_expected, ZA_PT_ZA_SIZE(sme_vq));
1098 	} else {
1099 		memset(za_expected, 0, ZA_PT_ZA_SIZE(sme_vq));
1100 		memset(zt_expected, 0, sizeof(zt_expected));
1101 	}
1102 
1103 	/* Changing the SME VL flushes ZT, SVE state and exits SM */
1104 	if (config->sme_vl_in != config->sme_vl_expected) {
1105 		svcr_expected &= ~SVCR_SM;
1106 
1107 		sve_vq = __sve_vq_from_vl(vl_expected(config));
1108 		memset(z_expected, 0, __SVE_ZREGS_SIZE(sve_vq));
1109 		memset(p_expected, 0, __SVE_PREGS_SIZE(sve_vq));
1110 		memset(ffr_expected, 0, __SVE_PREG_SIZE(sve_vq));
1111 		memset(zt_expected, 0, sizeof(zt_expected));
1112 
1113 		fpsimd_to_sve(v_expected, z_expected, vl_expected(config));
1114 	}
1115 }
1116 
1117 static void za_write(pid_t child, struct test_config *config)
1118 {
1119 	struct user_za_header *za;
1120 	struct iovec iov;
1121 	int ret, vq;
1122 
1123 	vq = __sve_vq_from_vl(config->sme_vl_expected);
1124 
1125 	if (config->svcr_expected & SVCR_ZA)
1126 		iov.iov_len = ZA_PT_SIZE(vq);
1127 	else
1128 		iov.iov_len = sizeof(*za);
1129 	iov.iov_base = malloc(iov.iov_len);
1130 	if (!iov.iov_base) {
1131 		ksft_print_msg("Failed allocating %lu byte ZA write buffer\n",
1132 			       iov.iov_len);
1133 		return;
1134 	}
1135 	memset(iov.iov_base, 0, iov.iov_len);
1136 
1137 	za = iov.iov_base;
1138 	za->size = iov.iov_len;
1139 	za->vl = config->sme_vl_expected;
1140 	if (config->svcr_expected & SVCR_ZA)
1141 		memcpy(iov.iov_base + ZA_PT_ZA_OFFSET, za_expected,
1142 		       ZA_PT_ZA_SIZE(vq));
1143 
1144 	ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_ZA, &iov);
1145 	if (ret != 0)
1146 		ksft_print_msg("Failed to write ZA: %s (%d)\n",
1147 			       strerror(errno), errno);
1148 
1149 	free(iov.iov_base);
1150 }
1151 
1152 static bool zt_write_supported(struct test_config *config)
1153 {
1154 	if (!sme2_supported())
1155 		return false;
1156 	if (config->sme_vl_in != config->sme_vl_expected)
1157 		return false;
1158 	if (!(config->svcr_expected & SVCR_ZA))
1159 		return false;
1160 	if ((config->svcr_in & SVCR_SM) != (config->svcr_expected & SVCR_SM))
1161 		return false;
1162 
1163 	return true;
1164 }
1165 
1166 static void zt_write_expected(struct test_config *config)
1167 {
1168 	int sme_vq;
1169 
1170 	sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
1171 
1172 	if (config->svcr_expected & SVCR_ZA) {
1173 		fill_random(zt_expected, sizeof(zt_expected));
1174 	} else {
1175 		memset(za_expected, 0, ZA_PT_ZA_SIZE(sme_vq));
1176 		memset(zt_expected, 0, sizeof(zt_expected));
1177 	}
1178 }
1179 
1180 static void zt_write(pid_t child, struct test_config *config)
1181 {
1182 	struct iovec iov;
1183 	int ret;
1184 
1185 	iov.iov_len = ZT_SIG_REG_BYTES;
1186 	iov.iov_base = zt_expected;
1187 	ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_ZT, &iov);
1188 	if (ret != 0)
1189 		ksft_print_msg("Failed to write ZT: %s (%d)\n",
1190 			       strerror(errno), errno);
1191 }
1192 
1193 /* Actually run a test */
1194 static void run_test(struct test_definition *test, struct test_config *config)
1195 {
1196 	pid_t child;
1197 	char name[1024];
1198 	bool pass;
1199 
1200 	if (sve_supported() && sme_supported())
1201 		snprintf(name, sizeof(name), "%s, SVE %d->%d, SME %d/%x->%d/%x",
1202 			 test->name,
1203 			 config->sve_vl_in, config->sve_vl_expected,
1204 			 config->sme_vl_in, config->svcr_in,
1205 			 config->sme_vl_expected, config->svcr_expected);
1206 	else if (sve_supported())
1207 		snprintf(name, sizeof(name), "%s, SVE %d->%d", test->name,
1208 			 config->sve_vl_in, config->sve_vl_expected);
1209 	else if (sme_supported())
1210 		snprintf(name, sizeof(name), "%s, SME %d/%x->%d/%x",
1211 			 test->name,
1212 			 config->sme_vl_in, config->svcr_in,
1213 			 config->sme_vl_expected, config->svcr_expected);
1214 	else
1215 		snprintf(name, sizeof(name), "%s", test->name);
1216 
1217 	if (test->supported && !test->supported(config)) {
1218 		ksft_test_result_skip("%s\n", name);
1219 		return;
1220 	}
1221 
1222 	set_initial_values(config);
1223 
1224 	if (test->set_expected_values)
1225 		test->set_expected_values(config);
1226 
1227 	child = fork();
1228 	if (child < 0)
1229 		ksft_exit_fail_msg("fork() failed: %s (%d)\n",
1230 				   strerror(errno), errno);
1231 	/* run_child() never returns */
1232 	if (child == 0)
1233 		run_child(config);
1234 
1235 	pass = run_parent(child, test, config);
1236 	if (!check_memory_values(config))
1237 		pass = false;
1238 
1239 	ksft_test_result(pass, "%s\n", name);
1240 }
1241 
1242 static void run_tests(struct test_definition defs[], int count,
1243 		      struct test_config *config)
1244 {
1245 	int i;
1246 
1247 	for (i = 0; i < count; i++)
1248 		run_test(&defs[i], config);
1249 }
1250 
1251 static struct test_definition base_test_defs[] = {
1252 	{
1253 		.name = "No writes",
1254 		.supported = sve_sme_same,
1255 	},
1256 	{
1257 		.name = "FPSIMD write",
1258 		.supported = sve_sme_same,
1259 		.set_expected_values = fpsimd_write_expected,
1260 		.modify_values = fpsimd_write,
1261 	},
1262 };
1263 
1264 static struct test_definition sve_test_defs[] = {
1265 	{
1266 		.name = "SVE write",
1267 		.supported = sve_write_supported,
1268 		.set_expected_values = sve_write_expected,
1269 		.modify_values = sve_write,
1270 	},
1271 };
1272 
1273 static struct test_definition za_test_defs[] = {
1274 	{
1275 		.name = "ZA write",
1276 		.supported = za_write_supported,
1277 		.set_expected_values = za_write_expected,
1278 		.modify_values = za_write,
1279 	},
1280 };
1281 
1282 static struct test_definition zt_test_defs[] = {
1283 	{
1284 		.name = "ZT write",
1285 		.supported = zt_write_supported,
1286 		.set_expected_values = zt_write_expected,
1287 		.modify_values = zt_write,
1288 	},
1289 };
1290 
1291 static int sve_vls[MAX_NUM_VLS], sme_vls[MAX_NUM_VLS];
1292 static int sve_vl_count, sme_vl_count;
1293 
1294 static void probe_vls(const char *name, int vls[], int *vl_count, int set_vl)
1295 {
1296 	unsigned int vq;
1297 	int vl;
1298 
1299 	*vl_count = 0;
1300 
1301 	for (vq = ARCH_VQ_MAX; vq > 0; vq /= 2) {
1302 		vl = prctl(set_vl, vq * 16);
1303 		if (vl == -1)
1304 			ksft_exit_fail_msg("SET_VL failed: %s (%d)\n",
1305 					   strerror(errno), errno);
1306 
1307 		vl &= PR_SVE_VL_LEN_MASK;
1308 
1309 		if (*vl_count && (vl == vls[*vl_count - 1]))
1310 			break;
1311 
1312 		vq = sve_vq_from_vl(vl);
1313 
1314 		vls[*vl_count] = vl;
1315 		*vl_count += 1;
1316 	}
1317 
1318 	if (*vl_count > 2) {
1319 		/* Just use the minimum and maximum */
1320 		vls[1] = vls[*vl_count - 1];
1321 		ksft_print_msg("%d %s VLs, using %d and %d\n",
1322 			       *vl_count, name, vls[0], vls[1]);
1323 		*vl_count = 2;
1324 	} else {
1325 		ksft_print_msg("%d %s VLs\n", *vl_count, name);
1326 	}
1327 }
1328 
1329 static struct {
1330 	int svcr_in, svcr_expected;
1331 } svcr_combinations[] = {
1332 	{ .svcr_in = 0, .svcr_expected = 0, },
1333 	{ .svcr_in = 0, .svcr_expected = SVCR_SM, },
1334 	{ .svcr_in = 0, .svcr_expected = SVCR_ZA, },
1335 	/* Can't enable both SM and ZA with a single ptrace write */
1336 
1337 	{ .svcr_in = SVCR_SM, .svcr_expected = 0, },
1338 	{ .svcr_in = SVCR_SM, .svcr_expected = SVCR_SM, },
1339 	{ .svcr_in = SVCR_SM, .svcr_expected = SVCR_ZA, },
1340 	{ .svcr_in = SVCR_SM, .svcr_expected = SVCR_SM | SVCR_ZA, },
1341 
1342 	{ .svcr_in = SVCR_ZA, .svcr_expected = 0, },
1343 	{ .svcr_in = SVCR_ZA, .svcr_expected = SVCR_SM, },
1344 	{ .svcr_in = SVCR_ZA, .svcr_expected = SVCR_ZA, },
1345 	{ .svcr_in = SVCR_ZA, .svcr_expected = SVCR_SM | SVCR_ZA, },
1346 
1347 	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = 0, },
1348 	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = SVCR_SM, },
1349 	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = SVCR_ZA, },
1350 	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = SVCR_SM | SVCR_ZA, },
1351 };
1352 
1353 static void run_sve_tests(void)
1354 {
1355 	struct test_config test_config;
1356 	int i, j;
1357 
1358 	if (!sve_supported())
1359 		return;
1360 
1361 	test_config.sme_vl_in = sme_vls[0];
1362 	test_config.sme_vl_expected = sme_vls[0];
1363 	test_config.svcr_in = 0;
1364 	test_config.svcr_expected = 0;
1365 
1366 	for (i = 0; i < sve_vl_count; i++) {
1367 		test_config.sve_vl_in = sve_vls[i];
1368 
1369 		for (j = 0; j < sve_vl_count; j++) {
1370 			test_config.sve_vl_expected = sve_vls[j];
1371 
1372 			run_tests(base_test_defs,
1373 				  ARRAY_SIZE(base_test_defs),
1374 				  &test_config);
1375 			if (sve_supported())
1376 				run_tests(sve_test_defs,
1377 					  ARRAY_SIZE(sve_test_defs),
1378 					  &test_config);
1379 		}
1380 	}
1381 
1382 }
1383 
1384 static void run_sme_tests(void)
1385 {
1386 	struct test_config test_config;
1387 	int i, j, k;
1388 
1389 	if (!sme_supported())
1390 		return;
1391 
1392 	test_config.sve_vl_in = sve_vls[0];
1393 	test_config.sve_vl_expected = sve_vls[0];
1394 
1395 	/*
1396 	 * Every SME VL/SVCR combination
1397 	 */
1398 	for (i = 0; i < sme_vl_count; i++) {
1399 		test_config.sme_vl_in = sme_vls[i];
1400 
1401 		for (j = 0; j < sme_vl_count; j++) {
1402 			test_config.sme_vl_expected = sme_vls[j];
1403 
1404 			for (k = 0; k < ARRAY_SIZE(svcr_combinations); k++) {
1405 				test_config.svcr_in = svcr_combinations[k].svcr_in;
1406 				test_config.svcr_expected = svcr_combinations[k].svcr_expected;
1407 
1408 				run_tests(base_test_defs,
1409 					  ARRAY_SIZE(base_test_defs),
1410 					  &test_config);
1411 				run_tests(sve_test_defs,
1412 					  ARRAY_SIZE(sve_test_defs),
1413 					  &test_config);
1414 				run_tests(za_test_defs,
1415 					  ARRAY_SIZE(za_test_defs),
1416 					  &test_config);
1417 
1418 				if (sme2_supported())
1419 					run_tests(zt_test_defs,
1420 						  ARRAY_SIZE(zt_test_defs),
1421 						  &test_config);
1422 			}
1423 		}
1424 	}
1425 }
1426 
1427 int main(void)
1428 {
1429 	struct test_config test_config;
1430 	struct sigaction sa;
1431 	int tests, ret, tmp;
1432 
1433 	srandom(getpid());
1434 
1435 	ksft_print_header();
1436 
1437 	if (sve_supported()) {
1438 		probe_vls("SVE", sve_vls, &sve_vl_count, PR_SVE_SET_VL);
1439 
1440 		tests = ARRAY_SIZE(base_test_defs) +
1441 			ARRAY_SIZE(sve_test_defs);
1442 		tests *= sve_vl_count * sve_vl_count;
1443 	} else {
1444 		/* Only run the FPSIMD tests */
1445 		sve_vl_count = 1;
1446 		tests = ARRAY_SIZE(base_test_defs);
1447 	}
1448 
1449 	if (sme_supported()) {
1450 		probe_vls("SME", sme_vls, &sme_vl_count, PR_SME_SET_VL);
1451 
1452 		tmp = ARRAY_SIZE(base_test_defs) + ARRAY_SIZE(sve_test_defs)
1453 			+ ARRAY_SIZE(za_test_defs);
1454 
1455 		if (sme2_supported())
1456 			tmp += ARRAY_SIZE(zt_test_defs);
1457 
1458 		tmp *= sme_vl_count * sme_vl_count;
1459 		tmp *= ARRAY_SIZE(svcr_combinations);
1460 		tests += tmp;
1461 	} else {
1462 		sme_vl_count = 1;
1463 	}
1464 
1465 	if (sme2_supported())
1466 		ksft_print_msg("SME2 supported\n");
1467 
1468 	if (fa64_supported())
1469 		ksft_print_msg("FA64 supported\n");
1470 
1471 	ksft_set_plan(tests);
1472 
1473 	/* Get signal handers ready before we start any children */
1474 	memset(&sa, 0, sizeof(sa));
1475 	sa.sa_sigaction = handle_alarm;
1476 	sa.sa_flags = SA_RESTART | SA_SIGINFO;
1477 	sigemptyset(&sa.sa_mask);
1478 	ret = sigaction(SIGALRM, &sa, NULL);
1479 	if (ret < 0)
1480 		ksft_print_msg("Failed to install SIGALRM handler: %s (%d)\n",
1481 			       strerror(errno), errno);
1482 
1483 	/*
1484 	 * Run the test set if there is no SVE or SME, with those we
1485 	 * have to pick a VL for each run.
1486 	 */
1487 	if (!sve_supported()) {
1488 		test_config.sve_vl_in = 0;
1489 		test_config.sve_vl_expected = 0;
1490 		test_config.sme_vl_in = 0;
1491 		test_config.sme_vl_expected = 0;
1492 		test_config.svcr_in = 0;
1493 		test_config.svcr_expected = 0;
1494 
1495 		run_tests(base_test_defs, ARRAY_SIZE(base_test_defs),
1496 			  &test_config);
1497 	}
1498 
1499 	run_sve_tests();
1500 	run_sme_tests();
1501 
1502 	ksft_finished();
1503 }
1504