xref: /linux/tools/testing/selftests/arm64/fp/zt-ptrace.c (revision feafee284579d29537a5a56ba8f23894f0463f3d)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2021 ARM Limited.
4  */
5 #include <errno.h>
6 #include <stdbool.h>
7 #include <stddef.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <sys/auxv.h>
13 #include <sys/prctl.h>
14 #include <sys/ptrace.h>
15 #include <sys/types.h>
16 #include <sys/uio.h>
17 #include <sys/wait.h>
18 #include <asm/sigcontext.h>
19 #include <asm/ptrace.h>
20 
21 #include "../../kselftest.h"
22 
/*
 * <linux/elf.h> and <sys/auxv.h> do not coexist, so rather than pulling in
 * <linux/elf.h> provide the regset note types used with
 * PTRACE_GETREGSET/PTRACE_SETREGSET here whenever the headers that were
 * included have not already defined them.
 */
#if !defined(NT_ARM_ZA)
#define NT_ARM_ZA 0x40c
#endif
#if !defined(NT_ARM_ZT)
#define NT_ARM_ZT 0x40d
#endif

/* Number of results announced to the framework with ksft_set_plan() */
#define EXPECTED_TESTS 3

/* SME vector length of this process, read once in main() via prctl() */
static int sme_vl;
34 
/*
 * Fill the first @size bytes of @buf with pseudo-random data.
 *
 * random() is called once per byte and its result is narrowed to a char
 * by the assignment.  Nothing is written when @size is 0.
 */
static void fill_buf(char *buf, size_t size)
{
	/* Same type as @size so the bound check is not a signed/unsigned mix */
	size_t i;

	for (i = 0; i < size; i++)
		buf[i] = random();
}
42 
/*
 * Child half of the test.
 *
 * Requests tracing with PTRACE_TRACEME and then raises SIGSTOP on itself,
 * which is the stop do_parent() waits for.  A failure of either step is
 * reported through ksft_exit_fail_msg().
 *
 * Returns EXIT_SUCCESS.
 */
static int do_child(void)
{
	long traceme = ptrace(PTRACE_TRACEME, -1, NULL, NULL);

	if (traceme != 0)
		ksft_exit_fail_msg("ptrace(PTRACE_TRACEME) failed: %s (%d)\n",
				   strerror(errno), errno);

	if (raise(SIGSTOP) != 0)
		ksft_exit_fail_msg("raise(SIGSTOP) failed: %s (%d)\n",
				   strerror(errno), errno);

	return EXIT_SUCCESS;
}
55 
get_za(pid_t pid,void ** buf,size_t * size)56 static struct user_za_header *get_za(pid_t pid, void **buf, size_t *size)
57 {
58 	struct user_za_header *za;
59 	void *p;
60 	size_t sz = sizeof(*za);
61 	struct iovec iov;
62 
63 	while (1) {
64 		if (*size < sz) {
65 			p = realloc(*buf, sz);
66 			if (!p) {
67 				errno = ENOMEM;
68 				goto error;
69 			}
70 
71 			*buf = p;
72 			*size = sz;
73 		}
74 
75 		iov.iov_base = *buf;
76 		iov.iov_len = sz;
77 		if (ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZA, &iov))
78 			goto error;
79 
80 		za = *buf;
81 		if (za->size <= sz)
82 			break;
83 
84 		sz = za->size;
85 	}
86 
87 	return za;
88 
89 error:
90 	return NULL;
91 }
92 
set_za(pid_t pid,const struct user_za_header * za)93 static int set_za(pid_t pid, const struct user_za_header *za)
94 {
95 	struct iovec iov;
96 
97 	iov.iov_base = (void *)za;
98 	iov.iov_len = za->size;
99 	return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZA, &iov);
100 }
101 
get_zt(pid_t pid,char zt[ZT_SIG_REG_BYTES])102 static int get_zt(pid_t pid, char zt[ZT_SIG_REG_BYTES])
103 {
104 	struct iovec iov;
105 
106 	iov.iov_base = zt;
107 	iov.iov_len = ZT_SIG_REG_BYTES;
108 	return ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZT, &iov);
109 }
110 
set_zt(pid_t pid,const char zt[ZT_SIG_REG_BYTES])111 static int set_zt(pid_t pid, const char zt[ZT_SIG_REG_BYTES])
112 {
113 	struct iovec iov;
114 
115 	iov.iov_base = (void *)zt;
116 	iov.iov_len = ZT_SIG_REG_BYTES;
117 	return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZT, &iov);
118 }
119 
120 /* Reading with ZA disabled returns all zeros */
ptrace_za_disabled_read_zt(pid_t child)121 static void ptrace_za_disabled_read_zt(pid_t child)
122 {
123 	struct user_za_header za;
124 	char zt[ZT_SIG_REG_BYTES];
125 	int ret, i;
126 	bool fail = false;
127 
128 	/* Disable PSTATE.ZA using the ZA interface */
129 	memset(&za, 0, sizeof(za));
130 	za.vl = sme_vl;
131 	za.size = sizeof(za);
132 
133 	ret = set_za(child, &za);
134 	if (ret != 0) {
135 		ksft_print_msg("Failed to disable ZA\n");
136 		fail = true;
137 	}
138 
139 	/* Read back ZT */
140 	ret = get_zt(child, zt);
141 	if (ret != 0) {
142 		ksft_print_msg("Failed to read ZT\n");
143 		fail = true;
144 	}
145 
146 	for (i = 0; i < ARRAY_SIZE(zt); i++) {
147 		if (zt[i]) {
148 			ksft_print_msg("zt[%d]: 0x%x != 0\n", i, zt[i]);
149 			fail = true;
150 		}
151 	}
152 
153 	ksft_test_result(!fail, "ptrace_za_disabled_read_zt\n");
154 }
155 
156 /* Writing then reading ZT should return the data written */
ptrace_set_get_zt(pid_t child)157 static void ptrace_set_get_zt(pid_t child)
158 {
159 	char zt_in[ZT_SIG_REG_BYTES];
160 	char zt_out[ZT_SIG_REG_BYTES];
161 	int ret, i;
162 	bool fail = false;
163 
164 	fill_buf(zt_in, sizeof(zt_in));
165 
166 	ret = set_zt(child, zt_in);
167 	if (ret != 0) {
168 		ksft_print_msg("Failed to set ZT\n");
169 		fail = true;
170 	}
171 
172 	ret = get_zt(child, zt_out);
173 	if (ret != 0) {
174 		ksft_print_msg("Failed to read ZT\n");
175 		fail = true;
176 	}
177 
178 	for (i = 0; i < ARRAY_SIZE(zt_in); i++) {
179 		if (zt_in[i] != zt_out[i]) {
180 			ksft_print_msg("zt[%d]: 0x%x != 0x%x\n", i,
181 				       zt_in[i], zt_out[i]);
182 			fail = true;
183 		}
184 	}
185 
186 	ksft_test_result(!fail, "ptrace_set_get_zt\n");
187 }
188 
189 /* Writing ZT should set PSTATE.ZA */
ptrace_enable_za_via_zt(pid_t child)190 static void ptrace_enable_za_via_zt(pid_t child)
191 {
192 	struct user_za_header za_in;
193 	struct user_za_header *za_out;
194 	char zt[ZT_SIG_REG_BYTES];
195 	char *za_data;
196 	size_t za_out_size;
197 	int ret, i, vq;
198 	bool fail = false;
199 
200 	/* Disable PSTATE.ZA using the ZA interface */
201 	memset(&za_in, 0, sizeof(za_in));
202 	za_in.vl = sme_vl;
203 	za_in.size = sizeof(za_in);
204 
205 	ret = set_za(child, &za_in);
206 	if (ret != 0) {
207 		ksft_print_msg("Failed to disable ZA\n");
208 		fail = true;
209 	}
210 
211 	/* Write ZT */
212 	fill_buf(zt, sizeof(zt));
213 	ret = set_zt(child, zt);
214 	if (ret != 0) {
215 		ksft_print_msg("Failed to set ZT\n");
216 		fail = true;
217 	}
218 
219 	/* Read back ZA and check for register data */
220 	za_out = NULL;
221 	za_out_size = 0;
222 	if (get_za(child, (void **)&za_out, &za_out_size)) {
223 		/* Should have an unchanged VL */
224 		if (za_out->vl != sme_vl) {
225 			ksft_print_msg("VL changed from %d to %d\n",
226 				       sme_vl, za_out->vl);
227 			fail = true;
228 		}
229 		vq = __sve_vq_from_vl(za_out->vl);
230 		za_data = (char *)za_out + ZA_PT_ZA_OFFSET;
231 
232 		/* Should have register data */
233 		if (za_out->size < ZA_PT_SIZE(vq)) {
234 			ksft_print_msg("ZA data less than expected: %u < %u\n",
235 				       za_out->size, (unsigned int)ZA_PT_SIZE(vq));
236 			fail = true;
237 			vq = 0;
238 		}
239 
240 		/* That register data should be non-zero */
241 		for (i = 0; i < ZA_PT_ZA_SIZE(vq); i++) {
242 			if (za_data[i]) {
243 				ksft_print_msg("ZA byte %d is %x\n",
244 					       i, za_data[i]);
245 				fail = true;
246 			}
247 		}
248 	} else {
249 		ksft_print_msg("Failed to read ZA\n");
250 		fail = true;
251 	}
252 
253 	ksft_test_result(!fail, "ptrace_enable_za_via_zt\n");
254 }
255 
do_parent(pid_t child)256 static int do_parent(pid_t child)
257 {
258 	int ret = EXIT_FAILURE;
259 	pid_t pid;
260 	int status;
261 	siginfo_t si;
262 
263 	/* Attach to the child */
264 	while (1) {
265 		int sig;
266 
267 		pid = wait(&status);
268 		if (pid == -1) {
269 			perror("wait");
270 			goto error;
271 		}
272 
273 		/*
274 		 * This should never happen but it's hard to flag in
275 		 * the framework.
276 		 */
277 		if (pid != child)
278 			continue;
279 
280 		if (WIFEXITED(status) || WIFSIGNALED(status))
281 			ksft_exit_fail_msg("Child died unexpectedly\n");
282 
283 		if (!WIFSTOPPED(status))
284 			goto error;
285 
286 		sig = WSTOPSIG(status);
287 
288 		if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &si)) {
289 			if (errno == ESRCH)
290 				goto disappeared;
291 
292 			if (errno == EINVAL) {
293 				sig = 0; /* bust group-stop */
294 				goto cont;
295 			}
296 
297 			ksft_test_result_fail("PTRACE_GETSIGINFO: %s\n",
298 					      strerror(errno));
299 			goto error;
300 		}
301 
302 		if (sig == SIGSTOP && si.si_code == SI_TKILL &&
303 		    si.si_pid == pid)
304 			break;
305 
306 	cont:
307 		if (ptrace(PTRACE_CONT, pid, NULL, sig)) {
308 			if (errno == ESRCH)
309 				goto disappeared;
310 
311 			ksft_test_result_fail("PTRACE_CONT: %s\n",
312 					      strerror(errno));
313 			goto error;
314 		}
315 	}
316 
317 	ksft_print_msg("Parent is %d, child is %d\n", getpid(), child);
318 
319 	ptrace_za_disabled_read_zt(child);
320 	ptrace_set_get_zt(child);
321 	ptrace_enable_za_via_zt(child);
322 
323 	ret = EXIT_SUCCESS;
324 
325 error:
326 	kill(child, SIGKILL);
327 
328 disappeared:
329 	return ret;
330 }
331 
main(void)332 int main(void)
333 {
334 	int ret = EXIT_SUCCESS;
335 	pid_t child;
336 
337 	srandom(getpid());
338 
339 	ksft_print_header();
340 
341 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2)) {
342 		ksft_set_plan(1);
343 		ksft_exit_skip("SME2 not available\n");
344 	}
345 
346 	/* We need a valid SME VL to enable/disable ZA */
347 	sme_vl = prctl(PR_SME_GET_VL);
348 	if (sme_vl == -1) {
349 		ksft_set_plan(1);
350 		ksft_exit_skip("Failed to read SME VL: %d (%s)\n",
351 			       errno, strerror(errno));
352 	}
353 
354 	ksft_set_plan(EXPECTED_TESTS);
355 
356 	child = fork();
357 	if (!child)
358 		return do_child();
359 
360 	if (do_parent(child))
361 		ret = EXIT_FAILURE;
362 
363 	ksft_print_cnts();
364 
365 	return ret;
366 }
367