xref: /linux/tools/testing/selftests/arm64/fp/zt-ptrace.c (revision 7f71507851fc7764b36a3221839607d3a45c2025)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2021 ARM Limited.
4  */
5 #include <errno.h>
6 #include <stdbool.h>
7 #include <stddef.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <sys/auxv.h>
13 #include <sys/prctl.h>
14 #include <sys/ptrace.h>
15 #include <sys/types.h>
16 #include <sys/uio.h>
17 #include <sys/wait.h>
18 #include <asm/sigcontext.h>
19 #include <asm/ptrace.h>
20 
21 #include "../../kselftest.h"
22 
/* <linux/elf.h> and <sys/auxv.h> conflict when included together, so define the NT_ARM_* values we use here: */
24 #ifndef NT_ARM_ZA
25 #define NT_ARM_ZA 0x40c
26 #endif
27 #ifndef NT_ARM_ZT
28 #define NT_ARM_ZT 0x40d
29 #endif
30 
/* Number of test results the parent reports, one per ptrace_*() test */
#define EXPECTED_TESTS 3

/*
 * SME vector length, set once in main() from prctl(PR_SME_GET_VL) and
 * used by the tests when building ZA regset headers.
 */
static int sme_vl;
34 
/*
 * Fill @buf with @size pseudo-random bytes.
 *
 * Each byte is taken from a fresh random() value (truncated to char), so
 * the generated sequence depends on the seed installed with srandom().
 * A @size of zero leaves @buf untouched.
 */
static void fill_buf(char *buf, size_t size)
{
	size_t i;

	/* Index with size_t so the counter has the same type as @size */
	for (i = 0; i < size; i++)
		buf[i] = random();
}
42 
/*
 * Child side of the test: ask to be traced by the parent, then stop
 * ourselves with SIGSTOP so the parent can take over.
 *
 * A failure of either step is reported through ksft_exit_fail_msg().
 * Returns EXIT_SUCCESS once the child is resumed and falls through.
 */
static int do_child(void)
{
	long traced = ptrace(PTRACE_TRACEME, -1, NULL, NULL);

	if (traced != 0) {
		ksft_exit_fail_msg("ptrace(PTRACE_TRACEME) failed: %s (%d)\n",
				   strerror(errno), errno);
	}

	if (raise(SIGSTOP) != 0) {
		ksft_exit_fail_msg("raise(SIGSTOP) failed: %s (%d)\n",
				   strerror(errno), errno);
	}

	return EXIT_SUCCESS;
}
55 
56 static struct user_za_header *get_za(pid_t pid, void **buf, size_t *size)
57 {
58 	struct user_za_header *za;
59 	void *p;
60 	size_t sz = sizeof(*za);
61 	struct iovec iov;
62 
63 	while (1) {
64 		if (*size < sz) {
65 			p = realloc(*buf, sz);
66 			if (!p) {
67 				errno = ENOMEM;
68 				goto error;
69 			}
70 
71 			*buf = p;
72 			*size = sz;
73 		}
74 
75 		iov.iov_base = *buf;
76 		iov.iov_len = sz;
77 		if (ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZA, &iov))
78 			goto error;
79 
80 		za = *buf;
81 		if (za->size <= sz)
82 			break;
83 
84 		sz = za->size;
85 	}
86 
87 	return za;
88 
89 error:
90 	return NULL;
91 }
92 
93 static int set_za(pid_t pid, const struct user_za_header *za)
94 {
95 	struct iovec iov;
96 
97 	iov.iov_base = (void *)za;
98 	iov.iov_len = za->size;
99 	return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZA, &iov);
100 }
101 
102 static int get_zt(pid_t pid, char zt[ZT_SIG_REG_BYTES])
103 {
104 	struct iovec iov;
105 
106 	iov.iov_base = zt;
107 	iov.iov_len = ZT_SIG_REG_BYTES;
108 	return ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZT, &iov);
109 }
110 
111 
112 static int set_zt(pid_t pid, const char zt[ZT_SIG_REG_BYTES])
113 {
114 	struct iovec iov;
115 
116 	iov.iov_base = (void *)zt;
117 	iov.iov_len = ZT_SIG_REG_BYTES;
118 	return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZT, &iov);
119 }
120 
121 /* Reading with ZA disabled returns all zeros */
122 static void ptrace_za_disabled_read_zt(pid_t child)
123 {
124 	struct user_za_header za;
125 	char zt[ZT_SIG_REG_BYTES];
126 	int ret, i;
127 	bool fail = false;
128 
129 	/* Disable PSTATE.ZA using the ZA interface */
130 	memset(&za, 0, sizeof(za));
131 	za.vl = sme_vl;
132 	za.size = sizeof(za);
133 
134 	ret = set_za(child, &za);
135 	if (ret != 0) {
136 		ksft_print_msg("Failed to disable ZA\n");
137 		fail = true;
138 	}
139 
140 	/* Read back ZT */
141 	ret = get_zt(child, zt);
142 	if (ret != 0) {
143 		ksft_print_msg("Failed to read ZT\n");
144 		fail = true;
145 	}
146 
147 	for (i = 0; i < ARRAY_SIZE(zt); i++) {
148 		if (zt[i]) {
149 			ksft_print_msg("zt[%d]: 0x%x != 0\n", i, zt[i]);
150 			fail = true;
151 		}
152 	}
153 
154 	ksft_test_result(!fail, "ptrace_za_disabled_read_zt\n");
155 }
156 
157 /* Writing then reading ZT should return the data written */
158 static void ptrace_set_get_zt(pid_t child)
159 {
160 	char zt_in[ZT_SIG_REG_BYTES];
161 	char zt_out[ZT_SIG_REG_BYTES];
162 	int ret, i;
163 	bool fail = false;
164 
165 	fill_buf(zt_in, sizeof(zt_in));
166 
167 	ret = set_zt(child, zt_in);
168 	if (ret != 0) {
169 		ksft_print_msg("Failed to set ZT\n");
170 		fail = true;
171 	}
172 
173 	ret = get_zt(child, zt_out);
174 	if (ret != 0) {
175 		ksft_print_msg("Failed to read ZT\n");
176 		fail = true;
177 	}
178 
179 	for (i = 0; i < ARRAY_SIZE(zt_in); i++) {
180 		if (zt_in[i] != zt_out[i]) {
181 			ksft_print_msg("zt[%d]: 0x%x != 0x%x\n", i,
182 				       zt_in[i], zt_out[i]);
183 			fail = true;
184 		}
185 	}
186 
187 	ksft_test_result(!fail, "ptrace_set_get_zt\n");
188 }
189 
190 /* Writing ZT should set PSTATE.ZA */
191 static void ptrace_enable_za_via_zt(pid_t child)
192 {
193 	struct user_za_header za_in;
194 	struct user_za_header *za_out;
195 	char zt[ZT_SIG_REG_BYTES];
196 	char *za_data;
197 	size_t za_out_size;
198 	int ret, i, vq;
199 	bool fail = false;
200 
201 	/* Disable PSTATE.ZA using the ZA interface */
202 	memset(&za_in, 0, sizeof(za_in));
203 	za_in.vl = sme_vl;
204 	za_in.size = sizeof(za_in);
205 
206 	ret = set_za(child, &za_in);
207 	if (ret != 0) {
208 		ksft_print_msg("Failed to disable ZA\n");
209 		fail = true;
210 	}
211 
212 	/* Write ZT */
213 	fill_buf(zt, sizeof(zt));
214 	ret = set_zt(child, zt);
215 	if (ret != 0) {
216 		ksft_print_msg("Failed to set ZT\n");
217 		fail = true;
218 	}
219 
220 	/* Read back ZA and check for register data */
221 	za_out = NULL;
222 	za_out_size = 0;
223 	if (get_za(child, (void **)&za_out, &za_out_size)) {
224 		/* Should have an unchanged VL */
225 		if (za_out->vl != sme_vl) {
226 			ksft_print_msg("VL changed from %d to %d\n",
227 				       sme_vl, za_out->vl);
228 			fail = true;
229 		}
230 		vq = __sve_vq_from_vl(za_out->vl);
231 		za_data = (char *)za_out + ZA_PT_ZA_OFFSET;
232 
233 		/* Should have register data */
234 		if (za_out->size < ZA_PT_SIZE(vq)) {
235 			ksft_print_msg("ZA data less than expected: %u < %u\n",
236 				       za_out->size, (unsigned int)ZA_PT_SIZE(vq));
237 			fail = true;
238 			vq = 0;
239 		}
240 
241 		/* That register data should be non-zero */
242 		for (i = 0; i < ZA_PT_ZA_SIZE(vq); i++) {
243 			if (za_data[i]) {
244 				ksft_print_msg("ZA byte %d is %x\n",
245 					       i, za_data[i]);
246 				fail = true;
247 			}
248 		}
249 	} else {
250 		ksft_print_msg("Failed to read ZA\n");
251 		fail = true;
252 	}
253 
254 	ksft_test_result(!fail, "ptrace_enable_za_via_zt\n");
255 }
256 
257 static int do_parent(pid_t child)
258 {
259 	int ret = EXIT_FAILURE;
260 	pid_t pid;
261 	int status;
262 	siginfo_t si;
263 
264 	/* Attach to the child */
265 	while (1) {
266 		int sig;
267 
268 		pid = wait(&status);
269 		if (pid == -1) {
270 			perror("wait");
271 			goto error;
272 		}
273 
274 		/*
275 		 * This should never happen but it's hard to flag in
276 		 * the framework.
277 		 */
278 		if (pid != child)
279 			continue;
280 
281 		if (WIFEXITED(status) || WIFSIGNALED(status))
282 			ksft_exit_fail_msg("Child died unexpectedly\n");
283 
284 		if (!WIFSTOPPED(status))
285 			goto error;
286 
287 		sig = WSTOPSIG(status);
288 
289 		if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &si)) {
290 			if (errno == ESRCH)
291 				goto disappeared;
292 
293 			if (errno == EINVAL) {
294 				sig = 0; /* bust group-stop */
295 				goto cont;
296 			}
297 
298 			ksft_test_result_fail("PTRACE_GETSIGINFO: %s\n",
299 					      strerror(errno));
300 			goto error;
301 		}
302 
303 		if (sig == SIGSTOP && si.si_code == SI_TKILL &&
304 		    si.si_pid == pid)
305 			break;
306 
307 	cont:
308 		if (ptrace(PTRACE_CONT, pid, NULL, sig)) {
309 			if (errno == ESRCH)
310 				goto disappeared;
311 
312 			ksft_test_result_fail("PTRACE_CONT: %s\n",
313 					      strerror(errno));
314 			goto error;
315 		}
316 	}
317 
318 	ksft_print_msg("Parent is %d, child is %d\n", getpid(), child);
319 
320 	ptrace_za_disabled_read_zt(child);
321 	ptrace_set_get_zt(child);
322 	ptrace_enable_za_via_zt(child);
323 
324 	ret = EXIT_SUCCESS;
325 
326 error:
327 	kill(child, SIGKILL);
328 
329 disappeared:
330 	return ret;
331 }
332 
333 int main(void)
334 {
335 	int ret = EXIT_SUCCESS;
336 	pid_t child;
337 
338 	srandom(getpid());
339 
340 	ksft_print_header();
341 
342 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2)) {
343 		ksft_set_plan(1);
344 		ksft_exit_skip("SME2 not available\n");
345 	}
346 
347 	/* We need a valid SME VL to enable/disable ZA */
348 	sme_vl = prctl(PR_SME_GET_VL);
349 	if (sme_vl == -1) {
350 		ksft_set_plan(1);
351 		ksft_exit_skip("Failed to read SME VL: %d (%s)\n",
352 			       errno, strerror(errno));
353 	}
354 
355 	ksft_set_plan(EXPECTED_TESTS);
356 
357 	child = fork();
358 	if (!child)
359 		return do_child();
360 
361 	if (do_parent(child))
362 		ret = EXIT_FAILURE;
363 
364 	ksft_print_cnts();
365 
366 	return ret;
367 }
368