xref: /linux/tools/testing/selftests/arm64/fp/fp-ptrace.c (revision 364eeb79a213fcf9164208b53764223ad522d6b3)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2023 ARM Limited.
4  * Original author: Mark Brown <broonie@kernel.org>
5  */
6 
#define _GNU_SOURCE

#include <errno.h>
#include <inttypes.h>
#include <signal.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <sys/auxv.h>
#include <sys/prctl.h>
#include <sys/ptrace.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <sys/wait.h>

#include <linux/bits.h>
#include <linux/kernel.h>

#include <asm/sigcontext.h>
#include <asm/sve_context.h>
#include <asm/ptrace.h>

#include "../../kselftest.h"

#include "fp-ptrace.h"
35 
/* Field masks within FPMR */
#define FPMR_LSCALE2_MASK	GENMASK(37, 32)
#define FPMR_NSCALE_MASK	GENMASK(31, 24)
#define FPMR_LSCALE_MASK	GENMASK(22, 16)
#define FPMR_OSC_MASK		GENMASK(15, 15)
#define FPMR_OSM_MASK		GENMASK(14, 14)

/*
 * Regset note types.  <linux/elf.h> and <sys/auxv.h> can't be included
 * together, so provide fallbacks for anything not already defined.
 */
#ifndef NT_ARM_SVE
#define NT_ARM_SVE	0x405
#endif

#ifndef NT_ARM_SSVE
#define NT_ARM_SSVE	0x40b
#endif

#ifndef NT_ARM_ZA
#define NT_ARM_ZA	0x40c
#endif

#ifndef NT_ARM_ZT
#define NT_ARM_ZT	0x40d
#endif

#ifndef NT_ARM_FPMR
#define NT_ARM_FPMR	0x40e
#endif

/* VQ used to size the static register image buffers */
#define ARCH_VQ_MAX	256

/* Vector lengths 128, 256, 512, 1024 and 2048: five powers of two */
#define MAX_NUM_VLS	5

/*
 * The FPMR bits which may be given arbitrary values without first
 * checking features to see whether the values are valid.
 */
#define FPMR_SAFE_BITS	(FPMR_LSCALE2_MASK | FPMR_NSCALE_MASK |	\
			 FPMR_LSCALE_MASK | FPMR_OSC_MASK |	\
			 FPMR_OSM_MASK)
74 
75 #define NUM_FPR 32
76 __uint128_t v_in[NUM_FPR];
77 __uint128_t v_expected[NUM_FPR];
78 __uint128_t v_out[NUM_FPR];
79 
80 char z_in[__SVE_ZREGS_SIZE(ARCH_VQ_MAX)];
81 char z_expected[__SVE_ZREGS_SIZE(ARCH_VQ_MAX)];
82 char z_out[__SVE_ZREGS_SIZE(ARCH_VQ_MAX)];
83 
84 char p_in[__SVE_PREGS_SIZE(ARCH_VQ_MAX)];
85 char p_expected[__SVE_PREGS_SIZE(ARCH_VQ_MAX)];
86 char p_out[__SVE_PREGS_SIZE(ARCH_VQ_MAX)];
87 
88 char ffr_in[__SVE_PREG_SIZE(ARCH_VQ_MAX)];
89 char ffr_expected[__SVE_PREG_SIZE(ARCH_VQ_MAX)];
90 char ffr_out[__SVE_PREG_SIZE(ARCH_VQ_MAX)];
91 
92 char za_in[ZA_SIG_REGS_SIZE(ARCH_VQ_MAX)];
93 char za_expected[ZA_SIG_REGS_SIZE(ARCH_VQ_MAX)];
94 char za_out[ZA_SIG_REGS_SIZE(ARCH_VQ_MAX)];
95 
96 char zt_in[ZT_SIG_REG_BYTES];
97 char zt_expected[ZT_SIG_REG_BYTES];
98 char zt_out[ZT_SIG_REG_BYTES];
99 
100 uint64_t fpmr_in, fpmr_expected, fpmr_out;
101 
102 uint64_t sve_vl_out;
103 uint64_t sme_vl_out;
104 uint64_t svcr_in, svcr_expected, svcr_out;
105 
106 void load_and_save(int flags);
107 
108 static bool got_alarm;
109 
110 static void handle_alarm(int sig, siginfo_t *info, void *context)
111 {
112 	got_alarm = true;
113 }
114 
/*
 * Convert a native-endian 128-bit value to little-endian byte order.
 * fpsimd_to_sve() uses this when placing V register values into the low
 * 128 bits of each Z register image.
 *
 * CONFIG_CPU_BIG_ENDIAN is a kernel Kconfig symbol and is never defined
 * when building userspace, so select the variant from the compiler's
 * byte order instead.  The big-endian variant reverses all 16 bytes.
 */
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
static __uint128_t arm64_cpu_to_le128(__uint128_t x)
{
	uint64_t a = __builtin_bswap64((uint64_t)x);
	uint64_t b = __builtin_bswap64((uint64_t)(x >> 64));

	return ((__uint128_t)a << 64) | b;
}
#else
static __uint128_t arm64_cpu_to_le128(__uint128_t x)
{
	return x;
}
#endif

/* A byte swap is its own inverse, so the reverse conversion is identical */
#define arm64_le128_to_cpu(x) arm64_cpu_to_le128(x)
131 
132 static bool sve_supported(void)
133 {
134 	return getauxval(AT_HWCAP) & HWCAP_SVE;
135 }
136 
137 static bool sme_supported(void)
138 {
139 	return getauxval(AT_HWCAP2) & HWCAP2_SME;
140 }
141 
142 static bool sme2_supported(void)
143 {
144 	return getauxval(AT_HWCAP2) & HWCAP2_SME2;
145 }
146 
147 static bool fa64_supported(void)
148 {
149 	return getauxval(AT_HWCAP2) & HWCAP2_SME_FA64;
150 }
151 
152 static bool fpmr_supported(void)
153 {
154 	return getauxval(AT_HWCAP2) & HWCAP2_FPMR;
155 }
156 
/*
 * Compare @size bytes at @out against @expected.
 *
 * On a mismatch, log the name of the buffer.  If @out is also entirely
 * zero, say so, since that suggests nothing was written back at all.
 *
 * Returns true if the buffers are identical (always true for @size 0).
 */
static bool compare_buffer(const char *name, const void *out,
			   const void *expected, size_t size)
{
	const unsigned char *bytes = out;
	size_t i;

	if (memcmp(out, expected, size) == 0)
		return true;

	ksft_print_msg("Mismatch in %s\n", name);

	/* Did we just get zeros back?  Scan in place, no allocation needed. */
	for (i = 0; i < size; i++) {
		if (bytes[i])
			break;
	}
	if (i == size)
		ksft_print_msg("%s is zero\n", name);

	return false;
}
183 
/*
 * Per-test register configuration: the vector lengths and SVCR value the
 * child starts with (*_in) and the ones it is expected to finish with
 * (*_expected).
 */
struct test_config {
	int sve_vl_in, sve_vl_expected;
	int sme_vl_in, sme_vl_expected;
	int svcr_in, svcr_expected;
};

/*
 * A single test case.  modify_values, if set, is called by run_parent()
 * with the traced child stopped after its initial values were loaded.
 */
struct test_definition {
	const char *name;
	bool sve_vl_change;
	bool (*supported)(struct test_config *config);
	void (*set_expected_values)(struct test_config *config);
	void (*modify_values)(pid_t child, struct test_config *test_config);
};
200 
201 static int vl_in(struct test_config *config)
202 {
203 	int vl;
204 
205 	if (config->svcr_in & SVCR_SM)
206 		vl = config->sme_vl_in;
207 	else
208 		vl = config->sve_vl_in;
209 
210 	return vl;
211 }
212 
213 static int vl_expected(struct test_config *config)
214 {
215 	int vl;
216 
217 	if (config->svcr_expected & SVCR_SM)
218 		vl = config->sme_vl_expected;
219 	else
220 		vl = config->sve_vl_expected;
221 
222 	return vl;
223 }
224 
225 static void run_child(struct test_config *config)
226 {
227 	int ret, flags;
228 
229 	/* Let the parent attach to us */
230 	ret = ptrace(PTRACE_TRACEME, 0, 0, 0);
231 	if (ret < 0)
232 		ksft_exit_fail_msg("PTRACE_TRACEME failed: %s (%d)\n",
233 				   strerror(errno), errno);
234 
235 	/* VL setup */
236 	if (sve_supported()) {
237 		ret = prctl(PR_SVE_SET_VL, config->sve_vl_in);
238 		if (ret != config->sve_vl_in) {
239 			ksft_print_msg("Failed to set SVE VL %d: %d\n",
240 				       config->sve_vl_in, ret);
241 		}
242 	}
243 
244 	if (sme_supported()) {
245 		ret = prctl(PR_SME_SET_VL, config->sme_vl_in);
246 		if (ret != config->sme_vl_in) {
247 			ksft_print_msg("Failed to set SME VL %d: %d\n",
248 				       config->sme_vl_in, ret);
249 		}
250 	}
251 
252 	/* Load values and wait for the parent */
253 	flags = 0;
254 	if (sve_supported())
255 		flags |= HAVE_SVE;
256 	if (sme_supported())
257 		flags |= HAVE_SME;
258 	if (sme2_supported())
259 		flags |= HAVE_SME2;
260 	if (fa64_supported())
261 		flags |= HAVE_FA64;
262 	if (fpmr_supported())
263 		flags |= HAVE_FPMR;
264 
265 	load_and_save(flags);
266 
267 	exit(0);
268 }
269 
270 static void read_one_child_regs(pid_t child, char *name,
271 				struct iovec *iov_parent,
272 				struct iovec *iov_child)
273 {
274 	int len = iov_parent->iov_len;
275 	int ret;
276 
277 	ret = process_vm_readv(child, iov_parent, 1, iov_child, 1, 0);
278 	if (ret == -1)
279 		ksft_print_msg("%s read failed: %s (%d)\n",
280 			       name, strerror(errno), errno);
281 	else if (ret != len)
282 		ksft_print_msg("Short read of %s: %d\n", name, ret);
283 }
284 
285 static void read_child_regs(pid_t child)
286 {
287 	struct iovec iov_parent, iov_child;
288 
289 	/*
290 	 * Since the child fork()ed from us the buffer addresses are
291 	 * the same in parent and child.
292 	 */
293 	iov_parent.iov_base = &v_out;
294 	iov_parent.iov_len = sizeof(v_out);
295 	iov_child.iov_base = &v_out;
296 	iov_child.iov_len = sizeof(v_out);
297 	read_one_child_regs(child, "FPSIMD", &iov_parent, &iov_child);
298 
299 	if (sve_supported() || sme_supported()) {
300 		iov_parent.iov_base = &sve_vl_out;
301 		iov_parent.iov_len = sizeof(sve_vl_out);
302 		iov_child.iov_base = &sve_vl_out;
303 		iov_child.iov_len = sizeof(sve_vl_out);
304 		read_one_child_regs(child, "SVE VL", &iov_parent, &iov_child);
305 
306 		iov_parent.iov_base = &z_out;
307 		iov_parent.iov_len = sizeof(z_out);
308 		iov_child.iov_base = &z_out;
309 		iov_child.iov_len = sizeof(z_out);
310 		read_one_child_regs(child, "Z", &iov_parent, &iov_child);
311 
312 		iov_parent.iov_base = &p_out;
313 		iov_parent.iov_len = sizeof(p_out);
314 		iov_child.iov_base = &p_out;
315 		iov_child.iov_len = sizeof(p_out);
316 		read_one_child_regs(child, "P", &iov_parent, &iov_child);
317 
318 		iov_parent.iov_base = &ffr_out;
319 		iov_parent.iov_len = sizeof(ffr_out);
320 		iov_child.iov_base = &ffr_out;
321 		iov_child.iov_len = sizeof(ffr_out);
322 		read_one_child_regs(child, "FFR", &iov_parent, &iov_child);
323 	}
324 
325 	if (sme_supported()) {
326 		iov_parent.iov_base = &sme_vl_out;
327 		iov_parent.iov_len = sizeof(sme_vl_out);
328 		iov_child.iov_base = &sme_vl_out;
329 		iov_child.iov_len = sizeof(sme_vl_out);
330 		read_one_child_regs(child, "SME VL", &iov_parent, &iov_child);
331 
332 		iov_parent.iov_base = &svcr_out;
333 		iov_parent.iov_len = sizeof(svcr_out);
334 		iov_child.iov_base = &svcr_out;
335 		iov_child.iov_len = sizeof(svcr_out);
336 		read_one_child_regs(child, "SVCR", &iov_parent, &iov_child);
337 
338 		iov_parent.iov_base = &za_out;
339 		iov_parent.iov_len = sizeof(za_out);
340 		iov_child.iov_base = &za_out;
341 		iov_child.iov_len = sizeof(za_out);
342 		read_one_child_regs(child, "ZA", &iov_parent, &iov_child);
343 	}
344 
345 	if (sme2_supported()) {
346 		iov_parent.iov_base = &zt_out;
347 		iov_parent.iov_len = sizeof(zt_out);
348 		iov_child.iov_base = &zt_out;
349 		iov_child.iov_len = sizeof(zt_out);
350 		read_one_child_regs(child, "ZT", &iov_parent, &iov_child);
351 	}
352 
353 	if (fpmr_supported()) {
354 		iov_parent.iov_base = &fpmr_out;
355 		iov_parent.iov_len = sizeof(fpmr_out);
356 		iov_child.iov_base = &fpmr_out;
357 		iov_child.iov_len = sizeof(fpmr_out);
358 		read_one_child_regs(child, "FPMR", &iov_parent, &iov_child);
359 	}
360 }
361 
362 static bool continue_breakpoint(pid_t child,
363 				enum __ptrace_request restart_type)
364 {
365 	struct user_pt_regs pt_regs;
366 	struct iovec iov;
367 	int ret;
368 
369 	/* Get PC */
370 	iov.iov_base = &pt_regs;
371 	iov.iov_len = sizeof(pt_regs);
372 	ret = ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov);
373 	if (ret < 0) {
374 		ksft_print_msg("Failed to get PC: %s (%d)\n",
375 			       strerror(errno), errno);
376 		return false;
377 	}
378 
379 	/* Skip over the BRK */
380 	pt_regs.pc += 4;
381 	ret = ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov);
382 	if (ret < 0) {
383 		ksft_print_msg("Failed to skip BRK: %s (%d)\n",
384 			       strerror(errno), errno);
385 		return false;
386 	}
387 
388 	/* Restart */
389 	ret = ptrace(restart_type, child, 0, 0);
390 	if (ret < 0) {
391 		ksft_print_msg("Failed to restart child: %s (%d)\n",
392 			       strerror(errno), errno);
393 		return false;
394 	}
395 
396 	return true;
397 }
398 
399 static bool check_ptrace_values_sve(pid_t child, struct test_config *config)
400 {
401 	struct user_sve_header *sve;
402 	struct user_fpsimd_state *fpsimd;
403 	struct iovec iov;
404 	int ret, vq;
405 	bool pass = true;
406 
407 	if (!sve_supported())
408 		return true;
409 
410 	vq = __sve_vq_from_vl(config->sve_vl_in);
411 
412 	iov.iov_len = SVE_PT_SVE_OFFSET + SVE_PT_SVE_SIZE(vq, SVE_PT_REGS_SVE);
413 	iov.iov_base = malloc(iov.iov_len);
414 	if (!iov.iov_base) {
415 		ksft_print_msg("OOM allocating %lu byte SVE buffer\n",
416 			       iov.iov_len);
417 		return false;
418 	}
419 
420 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_SVE, &iov);
421 	if (ret != 0) {
422 		ksft_print_msg("Failed to read initial SVE: %s (%d)\n",
423 			       strerror(errno), errno);
424 		pass = false;
425 		goto out;
426 	}
427 
428 	sve = iov.iov_base;
429 
430 	if (sve->vl != config->sve_vl_in) {
431 		ksft_print_msg("Mismatch in initial SVE VL: %d != %d\n",
432 			       sve->vl, config->sve_vl_in);
433 		pass = false;
434 	}
435 
436 	/* If we are in streaming mode we should just read FPSIMD */
437 	if ((config->svcr_in & SVCR_SM) && (sve->flags & SVE_PT_REGS_SVE)) {
438 		ksft_print_msg("NT_ARM_SVE reports SVE with PSTATE.SM\n");
439 		pass = false;
440 	}
441 
442 	if (sve->size != SVE_PT_SIZE(vq, sve->flags)) {
443 		ksft_print_msg("Mismatch in SVE header size: %d != %lu\n",
444 			       sve->size, SVE_PT_SIZE(vq, sve->flags));
445 		pass = false;
446 	}
447 
448 	/* The registers might be in completely different formats! */
449 	if (sve->flags & SVE_PT_REGS_SVE) {
450 		if (!compare_buffer("initial SVE Z",
451 				    iov.iov_base + SVE_PT_SVE_ZREG_OFFSET(vq, 0),
452 				    z_in, SVE_PT_SVE_ZREGS_SIZE(vq)))
453 			pass = false;
454 
455 		if (!compare_buffer("initial SVE P",
456 				    iov.iov_base + SVE_PT_SVE_PREG_OFFSET(vq, 0),
457 				    p_in, SVE_PT_SVE_PREGS_SIZE(vq)))
458 			pass = false;
459 
460 		if (!compare_buffer("initial SVE FFR",
461 				    iov.iov_base + SVE_PT_SVE_FFR_OFFSET(vq),
462 				    ffr_in, SVE_PT_SVE_PREG_SIZE(vq)))
463 			pass = false;
464 	} else {
465 		fpsimd = iov.iov_base + SVE_PT_FPSIMD_OFFSET;
466 		if (!compare_buffer("initial V via SVE", &fpsimd->vregs[0],
467 				    v_in, sizeof(v_in)))
468 			pass = false;
469 	}
470 
471 out:
472 	free(iov.iov_base);
473 	return pass;
474 }
475 
476 static bool check_ptrace_values_ssve(pid_t child, struct test_config *config)
477 {
478 	struct user_sve_header *sve;
479 	struct user_fpsimd_state *fpsimd;
480 	struct iovec iov;
481 	int ret, vq;
482 	bool pass = true;
483 
484 	if (!sme_supported())
485 		return true;
486 
487 	vq = __sve_vq_from_vl(config->sme_vl_in);
488 
489 	iov.iov_len = SVE_PT_SVE_OFFSET + SVE_PT_SVE_SIZE(vq, SVE_PT_REGS_SVE);
490 	iov.iov_base = malloc(iov.iov_len);
491 	if (!iov.iov_base) {
492 		ksft_print_msg("OOM allocating %lu byte SSVE buffer\n",
493 			       iov.iov_len);
494 		return false;
495 	}
496 
497 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_SSVE, &iov);
498 	if (ret != 0) {
499 		ksft_print_msg("Failed to read initial SSVE: %s (%d)\n",
500 			       strerror(errno), errno);
501 		pass = false;
502 		goto out;
503 	}
504 
505 	sve = iov.iov_base;
506 
507 	if (sve->vl != config->sme_vl_in) {
508 		ksft_print_msg("Mismatch in initial SSVE VL: %d != %d\n",
509 			       sve->vl, config->sme_vl_in);
510 		pass = false;
511 	}
512 
513 	if ((config->svcr_in & SVCR_SM) && !(sve->flags & SVE_PT_REGS_SVE)) {
514 		ksft_print_msg("NT_ARM_SSVE reports FPSIMD with PSTATE.SM\n");
515 		pass = false;
516 	}
517 
518 	if (sve->size != SVE_PT_SIZE(vq, sve->flags)) {
519 		ksft_print_msg("Mismatch in SSVE header size: %d != %lu\n",
520 			       sve->size, SVE_PT_SIZE(vq, sve->flags));
521 		pass = false;
522 	}
523 
524 	/* The registers might be in completely different formats! */
525 	if (sve->flags & SVE_PT_REGS_SVE) {
526 		if (!compare_buffer("initial SSVE Z",
527 				    iov.iov_base + SVE_PT_SVE_ZREG_OFFSET(vq, 0),
528 				    z_in, SVE_PT_SVE_ZREGS_SIZE(vq)))
529 			pass = false;
530 
531 		if (!compare_buffer("initial SSVE P",
532 				    iov.iov_base + SVE_PT_SVE_PREG_OFFSET(vq, 0),
533 				    p_in, SVE_PT_SVE_PREGS_SIZE(vq)))
534 			pass = false;
535 
536 		if (!compare_buffer("initial SSVE FFR",
537 				    iov.iov_base + SVE_PT_SVE_FFR_OFFSET(vq),
538 				    ffr_in, SVE_PT_SVE_PREG_SIZE(vq)))
539 			pass = false;
540 	} else {
541 		fpsimd = iov.iov_base + SVE_PT_FPSIMD_OFFSET;
542 		if (!compare_buffer("initial V via SSVE",
543 				    &fpsimd->vregs[0], v_in, sizeof(v_in)))
544 			pass = false;
545 	}
546 
547 out:
548 	free(iov.iov_base);
549 	return pass;
550 }
551 
552 static bool check_ptrace_values_za(pid_t child, struct test_config *config)
553 {
554 	struct user_za_header *za;
555 	struct iovec iov;
556 	int ret, vq;
557 	bool pass = true;
558 
559 	if (!sme_supported())
560 		return true;
561 
562 	vq = __sve_vq_from_vl(config->sme_vl_in);
563 
564 	iov.iov_len = ZA_SIG_CONTEXT_SIZE(vq);
565 	iov.iov_base = malloc(iov.iov_len);
566 	if (!iov.iov_base) {
567 		ksft_print_msg("OOM allocating %lu byte ZA buffer\n",
568 			       iov.iov_len);
569 		return false;
570 	}
571 
572 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_ZA, &iov);
573 	if (ret != 0) {
574 		ksft_print_msg("Failed to read initial ZA: %s (%d)\n",
575 			       strerror(errno), errno);
576 		pass = false;
577 		goto out;
578 	}
579 
580 	za = iov.iov_base;
581 
582 	if (za->vl != config->sme_vl_in) {
583 		ksft_print_msg("Mismatch in initial SME VL: %d != %d\n",
584 			       za->vl, config->sme_vl_in);
585 		pass = false;
586 	}
587 
588 	/* If PSTATE.ZA is not set we should just read the header */
589 	if (config->svcr_in & SVCR_ZA) {
590 		if (za->size != ZA_PT_SIZE(vq)) {
591 			ksft_print_msg("Unexpected ZA ptrace read size: %d != %lu\n",
592 				       za->size, ZA_PT_SIZE(vq));
593 			pass = false;
594 		}
595 
596 		if (!compare_buffer("initial ZA",
597 				    iov.iov_base + ZA_PT_ZA_OFFSET,
598 				    za_in, ZA_PT_ZA_SIZE(vq)))
599 			pass = false;
600 	} else {
601 		if (za->size != sizeof(*za)) {
602 			ksft_print_msg("Unexpected ZA ptrace read size: %d != %lu\n",
603 				       za->size, sizeof(*za));
604 			pass = false;
605 		}
606 	}
607 
608 out:
609 	free(iov.iov_base);
610 	return pass;
611 }
612 
613 static bool check_ptrace_values_zt(pid_t child, struct test_config *config)
614 {
615 	uint8_t buf[512];
616 	struct iovec iov;
617 	int ret;
618 
619 	if (!sme2_supported())
620 		return true;
621 
622 	iov.iov_base = &buf;
623 	iov.iov_len = ZT_SIG_REG_BYTES;
624 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_ZT, &iov);
625 	if (ret != 0) {
626 		ksft_print_msg("Failed to read initial ZT: %s (%d)\n",
627 			       strerror(errno), errno);
628 		return false;
629 	}
630 
631 	return compare_buffer("initial ZT", buf, zt_in, ZT_SIG_REG_BYTES);
632 }
633 
634 static bool check_ptrace_values_fpmr(pid_t child, struct test_config *config)
635 {
636 	uint64_t val;
637 	struct iovec iov;
638 	int ret;
639 
640 	if (!fpmr_supported())
641 		return true;
642 
643 	iov.iov_base = &val;
644 	iov.iov_len = sizeof(val);
645 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_FPMR, &iov);
646 	if (ret != 0) {
647 		ksft_print_msg("Failed to read initial FPMR: %s (%d)\n",
648 			       strerror(errno), errno);
649 		return false;
650 	}
651 
652 	return compare_buffer("initial FPMR", &val, &fpmr_in, sizeof(val));
653 }
654 
655 static bool check_ptrace_values(pid_t child, struct test_config *config)
656 {
657 	bool pass = true;
658 	struct user_fpsimd_state fpsimd;
659 	struct iovec iov;
660 	int ret;
661 
662 	iov.iov_base = &fpsimd;
663 	iov.iov_len = sizeof(fpsimd);
664 	ret = ptrace(PTRACE_GETREGSET, child, NT_PRFPREG, &iov);
665 	if (ret == 0) {
666 		if (!compare_buffer("initial V", &fpsimd.vregs, v_in,
667 				    sizeof(v_in))) {
668 			pass = false;
669 		}
670 	} else {
671 		ksft_print_msg("Failed to read initial V: %s (%d)\n",
672 			       strerror(errno), errno);
673 		pass = false;
674 	}
675 
676 	if (!check_ptrace_values_sve(child, config))
677 		pass = false;
678 
679 	if (!check_ptrace_values_ssve(child, config))
680 		pass = false;
681 
682 	if (!check_ptrace_values_za(child, config))
683 		pass = false;
684 
685 	if (!check_ptrace_values_zt(child, config))
686 		pass = false;
687 
688 	if (!check_ptrace_values_fpmr(child, config))
689 		pass = false;
690 
691 	return pass;
692 }
693 
694 static bool run_parent(pid_t child, struct test_definition *test,
695 		       struct test_config *config)
696 {
697 	int wait_status, ret;
698 	pid_t pid;
699 	bool pass;
700 
701 	/* Initial attach */
702 	while (1) {
703 		pid = waitpid(child, &wait_status, 0);
704 		if (pid < 0) {
705 			if (errno == EINTR)
706 				continue;
707 			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
708 					   strerror(errno), errno);
709 		}
710 
711 		if (pid == child)
712 			break;
713 	}
714 
715 	if (WIFEXITED(wait_status)) {
716 		ksft_print_msg("Child exited loading values with status %d\n",
717 			       WEXITSTATUS(wait_status));
718 		pass = false;
719 		goto out;
720 	}
721 
722 	if (WIFSIGNALED(wait_status)) {
723 		ksft_print_msg("Child died from signal %d loading values\n",
724 			       WTERMSIG(wait_status));
725 		pass = false;
726 		goto out;
727 	}
728 
729 	/* Read initial values via ptrace */
730 	pass = check_ptrace_values(child, config);
731 
732 	/* Do whatever writes we want to do */
733 	if (test->modify_values)
734 		test->modify_values(child, config);
735 
736 	if (!continue_breakpoint(child, PTRACE_CONT))
737 		goto cleanup;
738 
739 	while (1) {
740 		pid = waitpid(child, &wait_status, 0);
741 		if (pid < 0) {
742 			if (errno == EINTR)
743 				continue;
744 			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
745 					   strerror(errno), errno);
746 		}
747 
748 		if (pid == child)
749 			break;
750 	}
751 
752 	if (WIFEXITED(wait_status)) {
753 		ksft_print_msg("Child exited saving values with status %d\n",
754 			       WEXITSTATUS(wait_status));
755 		pass = false;
756 		goto out;
757 	}
758 
759 	if (WIFSIGNALED(wait_status)) {
760 		ksft_print_msg("Child died from signal %d saving values\n",
761 			       WTERMSIG(wait_status));
762 		pass = false;
763 		goto out;
764 	}
765 
766 	/* See what happened as a result */
767 	read_child_regs(child);
768 
769 	if (!continue_breakpoint(child, PTRACE_DETACH))
770 		goto cleanup;
771 
772 	/* The child should exit cleanly */
773 	got_alarm = false;
774 	alarm(1);
775 	while (1) {
776 		if (got_alarm) {
777 			ksft_print_msg("Wait for child timed out\n");
778 			goto cleanup;
779 		}
780 
781 		pid = waitpid(child, &wait_status, 0);
782 		if (pid < 0) {
783 			if (errno == EINTR)
784 				continue;
785 			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
786 					   strerror(errno), errno);
787 		}
788 
789 		if (pid == child)
790 			break;
791 	}
792 	alarm(0);
793 
794 	if (got_alarm) {
795 		ksft_print_msg("Timed out waiting for child\n");
796 		pass = false;
797 		goto cleanup;
798 	}
799 
800 	if (pid == child && WIFSIGNALED(wait_status)) {
801 		ksft_print_msg("Child died from signal %d cleaning up\n",
802 			       WTERMSIG(wait_status));
803 		pass = false;
804 		goto out;
805 	}
806 
807 	if (pid == child && WIFEXITED(wait_status)) {
808 		if (WEXITSTATUS(wait_status) != 0) {
809 			ksft_print_msg("Child exited with error %d\n",
810 				       WEXITSTATUS(wait_status));
811 			pass = false;
812 		}
813 	} else {
814 		ksft_print_msg("Child did not exit cleanly\n");
815 		pass = false;
816 		goto cleanup;
817 	}
818 
819 	goto out;
820 
821 cleanup:
822 	ret = kill(child, SIGKILL);
823 	if (ret != 0) {
824 		ksft_print_msg("kill() failed: %s (%d)\n",
825 			       strerror(errno), errno);
826 		return false;
827 	}
828 
829 	while (1) {
830 		pid = waitpid(child, &wait_status, 0);
831 		if (pid < 0) {
832 			if (errno == EINTR)
833 				continue;
834 			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
835 					   strerror(errno), errno);
836 		}
837 
838 		if (pid == child)
839 			break;
840 	}
841 
842 out:
843 	return pass;
844 }
845 
/*
 * Fill the first size / 4 32-bit words of @buf with random() output.
 * Any trailing bytes beyond a whole word are left unmodified.
 */
static void fill_random(void *buf, size_t size)
{
	uint32_t *lbuf = buf;
	size_t i;

	/* random() returns a 32 bit number regardless of the size of long */
	for (i = 0; i < size / sizeof(uint32_t); i++)
		lbuf[i] = random();
}
855 
/*
 * Fill @buf with a random FFR value for vector quadword count @vq.
 *
 * Only values where bits 0..n-1 are set and all higher bits are clear
 * are valid for FFR.  n is chosen from the inclusive range
 * 0..FFR bits, so that both all-clear and all-set values can occur.  The
 * previous modulus excluded the all-set value, which also made the
 * partial-byte guard below unreachable.
 */
static void fill_random_ffr(void *buf, size_t vq)
{
	uint8_t *lbuf = buf;
	size_t ffr_bytes = __SVE_FFR_SIZE(vq);
	size_t bits, i;

	memset(buf, 0, ffr_bytes);

	bits = (size_t)random() % (ffr_bytes * 8 + 1);
	for (i = 0; i < bits / 8; i++)
		lbuf[i] = 0xff;

	/* Partial top byte, unless every byte has already been filled */
	if (bits / 8 != ffr_bytes)
		lbuf[i] = (1 << (bits % 8)) - 1;
}
874 
875 static void fpsimd_to_sve(__uint128_t *v, char *z, int vl)
876 {
877 	int vq = __sve_vq_from_vl(vl);
878 	int i;
879 	__uint128_t *p;
880 
881 	if (!vl)
882 		return;
883 
884 	for (i = 0; i < __SVE_NUM_ZREGS; i++) {
885 		p = (__uint128_t *)&z[__SVE_ZREG_OFFSET(vq, i)];
886 		*p = arm64_cpu_to_le128(v[i]);
887 	}
888 }
889 
890 static void set_initial_values(struct test_config *config)
891 {
892 	int vq = __sve_vq_from_vl(vl_in(config));
893 	int sme_vq = __sve_vq_from_vl(config->sme_vl_in);
894 	bool sm_change;
895 
896 	svcr_in = config->svcr_in;
897 	svcr_expected = config->svcr_expected;
898 	svcr_out = 0;
899 
900 	if (sme_supported() &&
901 	    (svcr_in & SVCR_SM) != (svcr_expected & SVCR_SM))
902 		sm_change = true;
903 	else
904 		sm_change = false;
905 
906 	fill_random(&v_in, sizeof(v_in));
907 	memcpy(v_expected, v_in, sizeof(v_in));
908 	memset(v_out, 0, sizeof(v_out));
909 
910 	/* Changes will be handled in the test case */
911 	if (sve_supported() || (config->svcr_in & SVCR_SM)) {
912 		/* The low 128 bits of Z are shared with the V registers */
913 		fill_random(&z_in, __SVE_ZREGS_SIZE(vq));
914 		fpsimd_to_sve(v_in, z_in, vl_in(config));
915 		memcpy(z_expected, z_in, __SVE_ZREGS_SIZE(vq));
916 		memset(z_out, 0, sizeof(z_out));
917 
918 		fill_random(&p_in, __SVE_PREGS_SIZE(vq));
919 		memcpy(p_expected, p_in, __SVE_PREGS_SIZE(vq));
920 		memset(p_out, 0, sizeof(p_out));
921 
922 		if ((config->svcr_in & SVCR_SM) && !fa64_supported())
923 			memset(ffr_in, 0, __SVE_PREG_SIZE(vq));
924 		else
925 			fill_random_ffr(&ffr_in, vq);
926 		memcpy(ffr_expected, ffr_in, __SVE_PREG_SIZE(vq));
927 		memset(ffr_out, 0, __SVE_PREG_SIZE(vq));
928 	}
929 
930 	if (config->svcr_in & SVCR_ZA)
931 		fill_random(za_in, ZA_SIG_REGS_SIZE(sme_vq));
932 	else
933 		memset(za_in, 0, ZA_SIG_REGS_SIZE(sme_vq));
934 	if (config->svcr_expected & SVCR_ZA)
935 		memcpy(za_expected, za_in, ZA_SIG_REGS_SIZE(sme_vq));
936 	else
937 		memset(za_expected, 0, ZA_SIG_REGS_SIZE(sme_vq));
938 	if (sme_supported())
939 		memset(za_out, 0, sizeof(za_out));
940 
941 	if (sme2_supported()) {
942 		if (config->svcr_in & SVCR_ZA)
943 			fill_random(zt_in, ZT_SIG_REG_BYTES);
944 		else
945 			memset(zt_in, 0, ZT_SIG_REG_BYTES);
946 		if (config->svcr_expected & SVCR_ZA)
947 			memcpy(zt_expected, zt_in, ZT_SIG_REG_BYTES);
948 		else
949 			memset(zt_expected, 0, ZT_SIG_REG_BYTES);
950 		memset(zt_out, 0, sizeof(zt_out));
951 	}
952 
953 	if (fpmr_supported()) {
954 		fill_random(&fpmr_in, sizeof(fpmr_in));
955 		fpmr_in &= FPMR_SAFE_BITS;
956 
957 		/* Entering or exiting streaming mode clears FPMR */
958 		if (sm_change)
959 			fpmr_expected = 0;
960 		else
961 			fpmr_expected = fpmr_in;
962 	} else {
963 		fpmr_in = 0;
964 		fpmr_expected = 0;
965 		fpmr_out = 0;
966 	}
967 }
968 
969 static bool check_memory_values(struct test_config *config)
970 {
971 	bool pass = true;
972 	int vq, sme_vq;
973 
974 	if (!compare_buffer("saved V", v_out, v_expected, sizeof(v_out)))
975 		pass = false;
976 
977 	vq = __sve_vq_from_vl(vl_expected(config));
978 	sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
979 
980 	if (svcr_out != svcr_expected) {
981 		ksft_print_msg("Mismatch in saved SVCR %lx != %lx\n",
982 			       svcr_out, svcr_expected);
983 		pass = false;
984 	}
985 
986 	if (sve_vl_out != config->sve_vl_expected) {
987 		ksft_print_msg("Mismatch in SVE VL: %ld != %d\n",
988 			       sve_vl_out, config->sve_vl_expected);
989 		pass = false;
990 	}
991 
992 	if (sme_vl_out != config->sme_vl_expected) {
993 		ksft_print_msg("Mismatch in SME VL: %ld != %d\n",
994 			       sme_vl_out, config->sme_vl_expected);
995 		pass = false;
996 	}
997 
998 	if (!compare_buffer("saved Z", z_out, z_expected,
999 			    __SVE_ZREGS_SIZE(vq)))
1000 		pass = false;
1001 
1002 	if (!compare_buffer("saved P", p_out, p_expected,
1003 			    __SVE_PREGS_SIZE(vq)))
1004 		pass = false;
1005 
1006 	if (!compare_buffer("saved FFR", ffr_out, ffr_expected,
1007 			    __SVE_PREG_SIZE(vq)))
1008 		pass = false;
1009 
1010 	if (!compare_buffer("saved ZA", za_out, za_expected,
1011 			    ZA_PT_ZA_SIZE(sme_vq)))
1012 		pass = false;
1013 
1014 	if (!compare_buffer("saved ZT", zt_out, zt_expected, ZT_SIG_REG_BYTES))
1015 		pass = false;
1016 
1017 	if (fpmr_out != fpmr_expected) {
1018 		ksft_print_msg("Mismatch in saved FPMR: %lx != %lx\n",
1019 			       fpmr_out, fpmr_expected);
1020 		pass = false;
1021 	}
1022 
1023 	return pass;
1024 }
1025 
1026 static bool sve_sme_same(struct test_config *config)
1027 {
1028 	if (config->sve_vl_in != config->sve_vl_expected)
1029 		return false;
1030 
1031 	if (config->sme_vl_in != config->sme_vl_expected)
1032 		return false;
1033 
1034 	if (config->svcr_in != config->svcr_expected)
1035 		return false;
1036 
1037 	return true;
1038 }
1039 
1040 static bool sve_write_supported(struct test_config *config)
1041 {
1042 	if (!sve_supported() && !sme_supported())
1043 		return false;
1044 
1045 	if ((config->svcr_in & SVCR_ZA) != (config->svcr_expected & SVCR_ZA))
1046 		return false;
1047 
1048 	if (config->svcr_expected & SVCR_SM) {
1049 		if (config->sve_vl_in != config->sve_vl_expected) {
1050 			return false;
1051 		}
1052 
1053 		/* Changing the SME VL disables ZA */
1054 		if ((config->svcr_expected & SVCR_ZA) &&
1055 		    (config->sme_vl_in != config->sme_vl_expected)) {
1056 			return false;
1057 		}
1058 	} else {
1059 		if (config->sme_vl_in != config->sme_vl_expected) {
1060 			return false;
1061 		}
1062 	}
1063 
1064 	return true;
1065 }
1066 
1067 static void fpsimd_write_expected(struct test_config *config)
1068 {
1069 	int vl;
1070 
1071 	fill_random(&v_expected, sizeof(v_expected));
1072 
1073 	/* The SVE registers are flushed by a FPSIMD write */
1074 	vl = vl_expected(config);
1075 
1076 	memset(z_expected, 0, __SVE_ZREGS_SIZE(__sve_vq_from_vl(vl)));
1077 	memset(p_expected, 0, __SVE_PREGS_SIZE(__sve_vq_from_vl(vl)));
1078 	memset(ffr_expected, 0, __SVE_PREG_SIZE(__sve_vq_from_vl(vl)));
1079 
1080 	fpsimd_to_sve(v_expected, z_expected, vl);
1081 }
1082 
1083 static void fpsimd_write(pid_t child, struct test_config *test_config)
1084 {
1085 	struct user_fpsimd_state fpsimd;
1086 	struct iovec iov;
1087 	int ret;
1088 
1089 	memset(&fpsimd, 0, sizeof(fpsimd));
1090 	memcpy(&fpsimd.vregs, v_expected, sizeof(v_expected));
1091 
1092 	iov.iov_base = &fpsimd;
1093 	iov.iov_len = sizeof(fpsimd);
1094 	ret = ptrace(PTRACE_SETREGSET, child, NT_PRFPREG, &iov);
1095 	if (ret == -1)
1096 		ksft_print_msg("FPSIMD set failed: (%s) %d\n",
1097 			       strerror(errno), errno);
1098 }
1099 
/* An FPMR write test needs FPMR and no VL or SVCR change. */
static bool fpmr_write_supported(struct test_config *config)
{
	return fpmr_supported() && sve_sme_same(config);
}
1110 
1111 static void fpmr_write_expected(struct test_config *config)
1112 {
1113 	fill_random(&fpmr_expected, sizeof(fpmr_expected));
1114 	fpmr_expected &= FPMR_SAFE_BITS;
1115 }
1116 
1117 static void fpmr_write(pid_t child, struct test_config *config)
1118 {
1119 	struct iovec iov;
1120 	int ret;
1121 
1122 	iov.iov_len = sizeof(fpmr_expected);
1123 	iov.iov_base = &fpmr_expected;
1124 	ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_FPMR, &iov);
1125 	if (ret != 0)
1126 		ksft_print_msg("Failed to write FPMR: %s (%d)\n",
1127 			       strerror(errno), errno);
1128 }
1129 
1130 static void sve_write_expected(struct test_config *config)
1131 {
1132 	int vl = vl_expected(config);
1133 	int sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
1134 
1135 	fill_random(z_expected, __SVE_ZREGS_SIZE(__sve_vq_from_vl(vl)));
1136 	fill_random(p_expected, __SVE_PREGS_SIZE(__sve_vq_from_vl(vl)));
1137 
1138 	if ((svcr_expected & SVCR_SM) && !fa64_supported())
1139 		memset(ffr_expected, 0, __SVE_PREG_SIZE(sme_vq));
1140 	else
1141 		fill_random_ffr(ffr_expected, __sve_vq_from_vl(vl));
1142 
1143 	/* Share the low bits of Z with V */
1144 	fill_random(&v_expected, sizeof(v_expected));
1145 	fpsimd_to_sve(v_expected, z_expected, vl);
1146 
1147 	if (config->sme_vl_in != config->sme_vl_expected) {
1148 		memset(za_expected, 0, ZA_PT_ZA_SIZE(sme_vq));
1149 		memset(zt_expected, 0, sizeof(zt_expected));
1150 	}
1151 }
1152 
1153 static void sve_write(pid_t child, struct test_config *config)
1154 {
1155 	struct user_sve_header *sve;
1156 	struct iovec iov;
1157 	int ret, vl, vq, regset;
1158 
1159 	vl = vl_expected(config);
1160 	vq = __sve_vq_from_vl(vl);
1161 
1162 	iov.iov_len = SVE_PT_SVE_OFFSET + SVE_PT_SVE_SIZE(vq, SVE_PT_REGS_SVE);
1163 	iov.iov_base = malloc(iov.iov_len);
1164 	if (!iov.iov_base) {
1165 		ksft_print_msg("Failed allocating %lu byte SVE write buffer\n",
1166 			       iov.iov_len);
1167 		return;
1168 	}
1169 	memset(iov.iov_base, 0, iov.iov_len);
1170 
1171 	sve = iov.iov_base;
1172 	sve->size = iov.iov_len;
1173 	sve->flags = SVE_PT_REGS_SVE;
1174 	sve->vl = vl;
1175 
1176 	memcpy(iov.iov_base + SVE_PT_SVE_ZREG_OFFSET(vq, 0),
1177 	       z_expected, SVE_PT_SVE_ZREGS_SIZE(vq));
1178 	memcpy(iov.iov_base + SVE_PT_SVE_PREG_OFFSET(vq, 0),
1179 	       p_expected, SVE_PT_SVE_PREGS_SIZE(vq));
1180 	memcpy(iov.iov_base + SVE_PT_SVE_FFR_OFFSET(vq),
1181 	       ffr_expected, SVE_PT_SVE_PREG_SIZE(vq));
1182 
1183 	if (svcr_expected & SVCR_SM)
1184 		regset = NT_ARM_SSVE;
1185 	else
1186 		regset = NT_ARM_SVE;
1187 
1188 	ret = ptrace(PTRACE_SETREGSET, child, regset, &iov);
1189 	if (ret != 0)
1190 		ksft_print_msg("Failed to write SVE: %s (%d)\n",
1191 			       strerror(errno), errno);
1192 
1193 	free(iov.iov_base);
1194 }
1195 
1196 static bool za_write_supported(struct test_config *config)
1197 {
1198 	if (config->sme_vl_in != config->sme_vl_expected) {
1199 		/* Changing the SME VL exits streaming mode. */
1200 		if (config->svcr_expected & SVCR_SM) {
1201 			return false;
1202 		}
1203 	} else {
1204 		/* Otherwise we can't change streaming mode */
1205 		if ((config->svcr_in & SVCR_SM) !=
1206 		    (config->svcr_expected & SVCR_SM)) {
1207 			return false;
1208 		}
1209 	}
1210 
1211 	return true;
1212 }
1213 
/*
 * Set up the values expected after a ZA write.
 *
 * If ZA is expected to be enabled, ZA gets random contents, which
 * za_write() copies into the regset it writes. Otherwise both ZA and ZT
 * are expected to be zero. If the SME VL changes, streaming mode is
 * cleared in the expected SVCR and the Z/P/FFR state and ZT are expected
 * to be zero, apart from the V values that fpsimd_to_sve() places in Z.
 *
 * NOTE(review): this clears SM in the global svcr_expected before
 * calling vl_expected(); keep that order in case vl_expected() depends
 * on it. za_write_supported() already rejects configs that expect SM
 * together with a VL change.
 */
static void za_write_expected(struct test_config *config)
{
	int sme_vq, sve_vq;

	sme_vq = __sve_vq_from_vl(config->sme_vl_expected);

	if (config->svcr_expected & SVCR_ZA) {
		fill_random(za_expected, ZA_PT_ZA_SIZE(sme_vq));
	} else {
		/* ZA disabled: both ZA and ZT are expected to read as zero */
		memset(za_expected, 0, ZA_PT_ZA_SIZE(sme_vq));
		memset(zt_expected, 0, sizeof(zt_expected));
	}

	/* Changing the SME VL flushes ZT, SVE state and exits SM */
	if (config->sme_vl_in != config->sme_vl_expected) {
		svcr_expected &= ~SVCR_SM;

		sve_vq = __sve_vq_from_vl(vl_expected(config));
		memset(z_expected, 0, __SVE_ZREGS_SIZE(sve_vq));
		memset(p_expected, 0, __SVE_PREGS_SIZE(sve_vq));
		memset(ffr_expected, 0, __SVE_PREG_SIZE(sve_vq));
		memset(zt_expected, 0, sizeof(zt_expected));

		/* The V registers still appear in the low bits of Z */
		fpsimd_to_sve(v_expected, z_expected, vl_expected(config));
	}
}
1240 
1241 static void za_write(pid_t child, struct test_config *config)
1242 {
1243 	struct user_za_header *za;
1244 	struct iovec iov;
1245 	int ret, vq;
1246 
1247 	vq = __sve_vq_from_vl(config->sme_vl_expected);
1248 
1249 	if (config->svcr_expected & SVCR_ZA)
1250 		iov.iov_len = ZA_PT_SIZE(vq);
1251 	else
1252 		iov.iov_len = sizeof(*za);
1253 	iov.iov_base = malloc(iov.iov_len);
1254 	if (!iov.iov_base) {
1255 		ksft_print_msg("Failed allocating %lu byte ZA write buffer\n",
1256 			       iov.iov_len);
1257 		return;
1258 	}
1259 	memset(iov.iov_base, 0, iov.iov_len);
1260 
1261 	za = iov.iov_base;
1262 	za->size = iov.iov_len;
1263 	za->vl = config->sme_vl_expected;
1264 	if (config->svcr_expected & SVCR_ZA)
1265 		memcpy(iov.iov_base + ZA_PT_ZA_OFFSET, za_expected,
1266 		       ZA_PT_ZA_SIZE(vq));
1267 
1268 	ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_ZA, &iov);
1269 	if (ret != 0)
1270 		ksft_print_msg("Failed to write ZA: %s (%d)\n",
1271 			       strerror(errno), errno);
1272 
1273 	free(iov.iov_base);
1274 }
1275 
1276 static bool zt_write_supported(struct test_config *config)
1277 {
1278 	if (!sme2_supported())
1279 		return false;
1280 	if (config->sme_vl_in != config->sme_vl_expected)
1281 		return false;
1282 	if (!(config->svcr_expected & SVCR_ZA))
1283 		return false;
1284 	if ((config->svcr_in & SVCR_SM) != (config->svcr_expected & SVCR_SM))
1285 		return false;
1286 
1287 	return true;
1288 }
1289 
1290 static void zt_write_expected(struct test_config *config)
1291 {
1292 	int sme_vq;
1293 
1294 	sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
1295 
1296 	if (config->svcr_expected & SVCR_ZA) {
1297 		fill_random(zt_expected, sizeof(zt_expected));
1298 	} else {
1299 		memset(za_expected, 0, ZA_PT_ZA_SIZE(sme_vq));
1300 		memset(zt_expected, 0, sizeof(zt_expected));
1301 	}
1302 }
1303 
1304 static void zt_write(pid_t child, struct test_config *config)
1305 {
1306 	struct iovec iov;
1307 	int ret;
1308 
1309 	iov.iov_len = ZT_SIG_REG_BYTES;
1310 	iov.iov_base = zt_expected;
1311 	ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_ZT, &iov);
1312 	if (ret != 0)
1313 		ksft_print_msg("Failed to write ZT: %s (%d)\n",
1314 			       strerror(errno), errno);
1315 }
1316 
1317 /* Actually run a test */
1318 static void run_test(struct test_definition *test, struct test_config *config)
1319 {
1320 	pid_t child;
1321 	char name[1024];
1322 	bool pass;
1323 
1324 	if (sve_supported() && sme_supported())
1325 		snprintf(name, sizeof(name), "%s, SVE %d->%d, SME %d/%x->%d/%x",
1326 			 test->name,
1327 			 config->sve_vl_in, config->sve_vl_expected,
1328 			 config->sme_vl_in, config->svcr_in,
1329 			 config->sme_vl_expected, config->svcr_expected);
1330 	else if (sve_supported())
1331 		snprintf(name, sizeof(name), "%s, SVE %d->%d", test->name,
1332 			 config->sve_vl_in, config->sve_vl_expected);
1333 	else if (sme_supported())
1334 		snprintf(name, sizeof(name), "%s, SME %d/%x->%d/%x",
1335 			 test->name,
1336 			 config->sme_vl_in, config->svcr_in,
1337 			 config->sme_vl_expected, config->svcr_expected);
1338 	else
1339 		snprintf(name, sizeof(name), "%s", test->name);
1340 
1341 	if (test->supported && !test->supported(config)) {
1342 		ksft_test_result_skip("%s\n", name);
1343 		return;
1344 	}
1345 
1346 	set_initial_values(config);
1347 
1348 	if (test->set_expected_values)
1349 		test->set_expected_values(config);
1350 
1351 	child = fork();
1352 	if (child < 0)
1353 		ksft_exit_fail_msg("fork() failed: %s (%d)\n",
1354 				   strerror(errno), errno);
1355 	/* run_child() never returns */
1356 	if (child == 0)
1357 		run_child(config);
1358 
1359 	pass = run_parent(child, test, config);
1360 	if (!check_memory_values(config))
1361 		pass = false;
1362 
1363 	ksft_test_result(pass, "%s\n", name);
1364 }
1365 
1366 static void run_tests(struct test_definition defs[], int count,
1367 		      struct test_config *config)
1368 {
1369 	int i;
1370 
1371 	for (i = 0; i < count; i++)
1372 		run_test(&defs[i], config);
1373 }
1374 
/*
 * Tests run in every SVE and SME configuration (and once on systems
 * without SVE). run_test() skips a test when .supported returns false,
 * and calls .set_expected_values before forking the child.
 * NOTE(review): .modify_values is presumably invoked by run_parent() to
 * perform the ptrace write; confirm there.
 */
static struct test_definition base_test_defs[] = {
	{
		/* No .modify_values: presumably checks that state is preserved */
		.name = "No writes",
		.supported = sve_sme_same,
	},
	{
		.name = "FPSIMD write",
		.supported = sve_sme_same,
		.set_expected_values = fpsimd_write_expected,
		.modify_values = fpsimd_write,
	},
	{
		.name = "FPMR write",
		.supported = fpmr_write_supported,
		.set_expected_values = fpmr_write_expected,
		.modify_values = fpmr_write,
	},
};
1393 
/* Tests run in every SVE configuration and every SME configuration. */
static struct test_definition sve_test_defs[] = {
	{
		.name = "SVE write",
		.supported = sve_write_supported,
		.set_expected_values = sve_write_expected,
		.modify_values = sve_write,
	},
};
1402 
/* Tests run only in SME configurations (see run_sme_tests()). */
static struct test_definition za_test_defs[] = {
	{
		.name = "ZA write",
		.supported = za_write_supported,
		.set_expected_values = za_write_expected,
		.modify_values = za_write,
	},
};
1411 
/* Tests run only in SME configurations, and only when SME2 is present. */
static struct test_definition zt_test_defs[] = {
	{
		.name = "ZT write",
		.supported = zt_write_supported,
		.set_expected_values = zt_write_expected,
		.modify_values = zt_write,
	},
};
1420 
/*
 * Vector lengths found by probe_vls(). After probing, at most two are
 * used: the largest in [0] and the smallest in [1]. main() sets the
 * count to 1 (leaving the zero-initialised VL) when the extension is
 * absent.
 */
static int sve_vls[MAX_NUM_VLS], sme_vls[MAX_NUM_VLS];
static int sve_vl_count, sme_vl_count;
1423 
1424 static void probe_vls(const char *name, int vls[], int *vl_count, int set_vl)
1425 {
1426 	unsigned int vq;
1427 	int vl;
1428 
1429 	*vl_count = 0;
1430 
1431 	for (vq = ARCH_VQ_MAX; vq > 0; vq /= 2) {
1432 		vl = prctl(set_vl, vq * 16);
1433 		if (vl == -1)
1434 			ksft_exit_fail_msg("SET_VL failed: %s (%d)\n",
1435 					   strerror(errno), errno);
1436 
1437 		vl &= PR_SVE_VL_LEN_MASK;
1438 
1439 		if (*vl_count && (vl == vls[*vl_count - 1]))
1440 			break;
1441 
1442 		vq = sve_vq_from_vl(vl);
1443 
1444 		vls[*vl_count] = vl;
1445 		*vl_count += 1;
1446 	}
1447 
1448 	if (*vl_count > 2) {
1449 		/* Just use the minimum and maximum */
1450 		vls[1] = vls[*vl_count - 1];
1451 		ksft_print_msg("%d %s VLs, using %d and %d\n",
1452 			       *vl_count, name, vls[0], vls[1]);
1453 		*vl_count = 2;
1454 	} else {
1455 		ksft_print_msg("%d %s VLs\n", *vl_count, name);
1456 	}
1457 }
1458 
/*
 * SVCR transitions used by run_sme_tests(): svcr_in is the state the
 * test starts from and svcr_expected the state expected afterwards.
 * Each entry is run for every SME VL pair, so the order of this table
 * sets the test order.
 */
static struct {
	int svcr_in, svcr_expected;
} svcr_combinations[] = {
	{ .svcr_in = 0, .svcr_expected = 0, },
	{ .svcr_in = 0, .svcr_expected = SVCR_SM, },
	{ .svcr_in = 0, .svcr_expected = SVCR_ZA, },
	/* Can't enable both SM and ZA with a single ptrace write */

	{ .svcr_in = SVCR_SM, .svcr_expected = 0, },
	{ .svcr_in = SVCR_SM, .svcr_expected = SVCR_SM, },
	{ .svcr_in = SVCR_SM, .svcr_expected = SVCR_ZA, },
	{ .svcr_in = SVCR_SM, .svcr_expected = SVCR_SM | SVCR_ZA, },

	{ .svcr_in = SVCR_ZA, .svcr_expected = 0, },
	{ .svcr_in = SVCR_ZA, .svcr_expected = SVCR_SM, },
	{ .svcr_in = SVCR_ZA, .svcr_expected = SVCR_ZA, },
	{ .svcr_in = SVCR_ZA, .svcr_expected = SVCR_SM | SVCR_ZA, },

	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = 0, },
	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = SVCR_SM, },
	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = SVCR_ZA, },
	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = SVCR_SM | SVCR_ZA, },
};
1482 
1483 static void run_sve_tests(void)
1484 {
1485 	struct test_config test_config;
1486 	int i, j;
1487 
1488 	if (!sve_supported())
1489 		return;
1490 
1491 	test_config.sme_vl_in = sme_vls[0];
1492 	test_config.sme_vl_expected = sme_vls[0];
1493 	test_config.svcr_in = 0;
1494 	test_config.svcr_expected = 0;
1495 
1496 	for (i = 0; i < sve_vl_count; i++) {
1497 		test_config.sve_vl_in = sve_vls[i];
1498 
1499 		for (j = 0; j < sve_vl_count; j++) {
1500 			test_config.sve_vl_expected = sve_vls[j];
1501 
1502 			run_tests(base_test_defs,
1503 				  ARRAY_SIZE(base_test_defs),
1504 				  &test_config);
1505 			if (sve_supported())
1506 				run_tests(sve_test_defs,
1507 					  ARRAY_SIZE(sve_test_defs),
1508 					  &test_config);
1509 		}
1510 	}
1511 
1512 }
1513 
1514 static void run_sme_tests(void)
1515 {
1516 	struct test_config test_config;
1517 	int i, j, k;
1518 
1519 	if (!sme_supported())
1520 		return;
1521 
1522 	test_config.sve_vl_in = sve_vls[0];
1523 	test_config.sve_vl_expected = sve_vls[0];
1524 
1525 	/*
1526 	 * Every SME VL/SVCR combination
1527 	 */
1528 	for (i = 0; i < sme_vl_count; i++) {
1529 		test_config.sme_vl_in = sme_vls[i];
1530 
1531 		for (j = 0; j < sme_vl_count; j++) {
1532 			test_config.sme_vl_expected = sme_vls[j];
1533 
1534 			for (k = 0; k < ARRAY_SIZE(svcr_combinations); k++) {
1535 				test_config.svcr_in = svcr_combinations[k].svcr_in;
1536 				test_config.svcr_expected = svcr_combinations[k].svcr_expected;
1537 
1538 				run_tests(base_test_defs,
1539 					  ARRAY_SIZE(base_test_defs),
1540 					  &test_config);
1541 				run_tests(sve_test_defs,
1542 					  ARRAY_SIZE(sve_test_defs),
1543 					  &test_config);
1544 				run_tests(za_test_defs,
1545 					  ARRAY_SIZE(za_test_defs),
1546 					  &test_config);
1547 
1548 				if (sme2_supported())
1549 					run_tests(zt_test_defs,
1550 						  ARRAY_SIZE(zt_test_defs),
1551 						  &test_config);
1552 			}
1553 		}
1554 	}
1555 }
1556 
1557 int main(void)
1558 {
1559 	struct test_config test_config;
1560 	struct sigaction sa;
1561 	int tests, ret, tmp;
1562 
1563 	srandom(getpid());
1564 
1565 	ksft_print_header();
1566 
1567 	if (sve_supported()) {
1568 		probe_vls("SVE", sve_vls, &sve_vl_count, PR_SVE_SET_VL);
1569 
1570 		tests = ARRAY_SIZE(base_test_defs) +
1571 			ARRAY_SIZE(sve_test_defs);
1572 		tests *= sve_vl_count * sve_vl_count;
1573 	} else {
1574 		/* Only run the FPSIMD tests */
1575 		sve_vl_count = 1;
1576 		tests = ARRAY_SIZE(base_test_defs);
1577 	}
1578 
1579 	if (sme_supported()) {
1580 		probe_vls("SME", sme_vls, &sme_vl_count, PR_SME_SET_VL);
1581 
1582 		tmp = ARRAY_SIZE(base_test_defs) + ARRAY_SIZE(sve_test_defs)
1583 			+ ARRAY_SIZE(za_test_defs);
1584 
1585 		if (sme2_supported())
1586 			tmp += ARRAY_SIZE(zt_test_defs);
1587 
1588 		tmp *= sme_vl_count * sme_vl_count;
1589 		tmp *= ARRAY_SIZE(svcr_combinations);
1590 		tests += tmp;
1591 	} else {
1592 		sme_vl_count = 1;
1593 	}
1594 
1595 	if (sme2_supported())
1596 		ksft_print_msg("SME2 supported\n");
1597 
1598 	if (fa64_supported())
1599 		ksft_print_msg("FA64 supported\n");
1600 
1601 	if (fpmr_supported())
1602 		ksft_print_msg("FPMR supported\n");
1603 
1604 	ksft_set_plan(tests);
1605 
1606 	/* Get signal handers ready before we start any children */
1607 	memset(&sa, 0, sizeof(sa));
1608 	sa.sa_sigaction = handle_alarm;
1609 	sa.sa_flags = SA_RESTART | SA_SIGINFO;
1610 	sigemptyset(&sa.sa_mask);
1611 	ret = sigaction(SIGALRM, &sa, NULL);
1612 	if (ret < 0)
1613 		ksft_print_msg("Failed to install SIGALRM handler: %s (%d)\n",
1614 			       strerror(errno), errno);
1615 
1616 	/*
1617 	 * Run the test set if there is no SVE or SME, with those we
1618 	 * have to pick a VL for each run.
1619 	 */
1620 	if (!sve_supported()) {
1621 		test_config.sve_vl_in = 0;
1622 		test_config.sve_vl_expected = 0;
1623 		test_config.sme_vl_in = 0;
1624 		test_config.sme_vl_expected = 0;
1625 		test_config.svcr_in = 0;
1626 		test_config.svcr_expected = 0;
1627 
1628 		run_tests(base_test_defs, ARRAY_SIZE(base_test_defs),
1629 			  &test_config);
1630 	}
1631 
1632 	run_sve_tests();
1633 	run_sme_tests();
1634 
1635 	ksft_finished();
1636 }
1637