1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Copyright (C) 2021 ARM Limited.
4 */
5 #include <errno.h>
6 #include <stdbool.h>
7 #include <stddef.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <sys/auxv.h>
13 #include <sys/prctl.h>
14 #include <sys/ptrace.h>
15 #include <sys/types.h>
16 #include <sys/uio.h>
17 #include <sys/wait.h>
18 #include <asm/sigcontext.h>
19 #include <asm/ptrace.h>
20
21 #include "../../kselftest.h"
22
/*
 * <linux/elf.h> and <sys/auxv.h> don't like each other, so rather than
 * including the former, provide the regset note types we need here when
 * the headers we do include have not already defined them.
 */
#ifndef NT_ARM_ZA
#define NT_ARM_ZA	0x40c
#endif
#ifndef NT_ARM_ZT
#define NT_ARM_ZT	0x40d
#endif

/* Number of test results reported to the kselftest plan */
#define EXPECTED_TESTS	3

/* SME vector length obtained from prctl(PR_SME_GET_VL) in main() */
static int sme_vl;
34
/*
 * Fill @buf with @size bytes of pseudo-random data taken from random().
 *
 * The index is a size_t so it has the same type as @size: an int index
 * compared against a size_t is a signed/unsigned comparison and would
 * overflow (undefined behaviour) for buffers larger than INT_MAX bytes.
 */
static void fill_buf(char *buf, size_t size)
{
	size_t i;

	for (i = 0; i < size; i++)
		buf[i] = random();
}
42
/*
 * Child side of the test: volunteer to be traced by the parent, then stop
 * ourselves with SIGSTOP so the parent can catch the stop and start poking
 * at our registers. Any failure here aborts the whole test run.
 */
static int do_child(void)
{
	long rc;

	/* pid/addr/data are ignored for PTRACE_TRACEME (see ptrace(2)) */
	rc = ptrace(PTRACE_TRACEME, -1, NULL, NULL);
	if (rc != 0)
		ksft_exit_fail_msg("ptrace(PTRACE_TRACEME) failed: %s (%d)\n",
				   strerror(errno), errno);

	rc = raise(SIGSTOP);
	if (rc != 0)
		ksft_exit_fail_msg("raise(SIGSTOP) failed: %s (%d)\n",
				   strerror(errno), errno);

	return EXIT_SUCCESS;
}
55
get_za(pid_t pid,void ** buf,size_t * size)56 static struct user_za_header *get_za(pid_t pid, void **buf, size_t *size)
57 {
58 struct user_za_header *za;
59 void *p;
60 size_t sz = sizeof(*za);
61 struct iovec iov;
62
63 while (1) {
64 if (*size < sz) {
65 p = realloc(*buf, sz);
66 if (!p) {
67 errno = ENOMEM;
68 goto error;
69 }
70
71 *buf = p;
72 *size = sz;
73 }
74
75 iov.iov_base = *buf;
76 iov.iov_len = sz;
77 if (ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZA, &iov))
78 goto error;
79
80 za = *buf;
81 if (za->size <= sz)
82 break;
83
84 sz = za->size;
85 }
86
87 return za;
88
89 error:
90 return NULL;
91 }
92
set_za(pid_t pid,const struct user_za_header * za)93 static int set_za(pid_t pid, const struct user_za_header *za)
94 {
95 struct iovec iov;
96
97 iov.iov_base = (void *)za;
98 iov.iov_len = za->size;
99 return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZA, &iov);
100 }
101
get_zt(pid_t pid,char zt[ZT_SIG_REG_BYTES])102 static int get_zt(pid_t pid, char zt[ZT_SIG_REG_BYTES])
103 {
104 struct iovec iov;
105
106 iov.iov_base = zt;
107 iov.iov_len = ZT_SIG_REG_BYTES;
108 return ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZT, &iov);
109 }
110
set_zt(pid_t pid,const char zt[ZT_SIG_REG_BYTES])111 static int set_zt(pid_t pid, const char zt[ZT_SIG_REG_BYTES])
112 {
113 struct iovec iov;
114
115 iov.iov_base = (void *)zt;
116 iov.iov_len = ZT_SIG_REG_BYTES;
117 return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZT, &iov);
118 }
119
120 /* Reading with ZA disabled returns all zeros */
ptrace_za_disabled_read_zt(pid_t child)121 static void ptrace_za_disabled_read_zt(pid_t child)
122 {
123 struct user_za_header za;
124 char zt[ZT_SIG_REG_BYTES];
125 int ret, i;
126 bool fail = false;
127
128 /* Disable PSTATE.ZA using the ZA interface */
129 memset(&za, 0, sizeof(za));
130 za.vl = sme_vl;
131 za.size = sizeof(za);
132
133 ret = set_za(child, &za);
134 if (ret != 0) {
135 ksft_print_msg("Failed to disable ZA\n");
136 fail = true;
137 }
138
139 /* Read back ZT */
140 ret = get_zt(child, zt);
141 if (ret != 0) {
142 ksft_print_msg("Failed to read ZT\n");
143 fail = true;
144 }
145
146 for (i = 0; i < ARRAY_SIZE(zt); i++) {
147 if (zt[i]) {
148 ksft_print_msg("zt[%d]: 0x%x != 0\n", i, zt[i]);
149 fail = true;
150 }
151 }
152
153 ksft_test_result(!fail, "ptrace_za_disabled_read_zt\n");
154 }
155
156 /* Writing then reading ZT should return the data written */
ptrace_set_get_zt(pid_t child)157 static void ptrace_set_get_zt(pid_t child)
158 {
159 char zt_in[ZT_SIG_REG_BYTES];
160 char zt_out[ZT_SIG_REG_BYTES];
161 int ret, i;
162 bool fail = false;
163
164 fill_buf(zt_in, sizeof(zt_in));
165
166 ret = set_zt(child, zt_in);
167 if (ret != 0) {
168 ksft_print_msg("Failed to set ZT\n");
169 fail = true;
170 }
171
172 ret = get_zt(child, zt_out);
173 if (ret != 0) {
174 ksft_print_msg("Failed to read ZT\n");
175 fail = true;
176 }
177
178 for (i = 0; i < ARRAY_SIZE(zt_in); i++) {
179 if (zt_in[i] != zt_out[i]) {
180 ksft_print_msg("zt[%d]: 0x%x != 0x%x\n", i,
181 zt_in[i], zt_out[i]);
182 fail = true;
183 }
184 }
185
186 ksft_test_result(!fail, "ptrace_set_get_zt\n");
187 }
188
189 /* Writing ZT should set PSTATE.ZA */
ptrace_enable_za_via_zt(pid_t child)190 static void ptrace_enable_za_via_zt(pid_t child)
191 {
192 struct user_za_header za_in;
193 struct user_za_header *za_out;
194 char zt[ZT_SIG_REG_BYTES];
195 char *za_data;
196 size_t za_out_size;
197 int ret, i, vq;
198 bool fail = false;
199
200 /* Disable PSTATE.ZA using the ZA interface */
201 memset(&za_in, 0, sizeof(za_in));
202 za_in.vl = sme_vl;
203 za_in.size = sizeof(za_in);
204
205 ret = set_za(child, &za_in);
206 if (ret != 0) {
207 ksft_print_msg("Failed to disable ZA\n");
208 fail = true;
209 }
210
211 /* Write ZT */
212 fill_buf(zt, sizeof(zt));
213 ret = set_zt(child, zt);
214 if (ret != 0) {
215 ksft_print_msg("Failed to set ZT\n");
216 fail = true;
217 }
218
219 /* Read back ZA and check for register data */
220 za_out = NULL;
221 za_out_size = 0;
222 if (get_za(child, (void **)&za_out, &za_out_size)) {
223 /* Should have an unchanged VL */
224 if (za_out->vl != sme_vl) {
225 ksft_print_msg("VL changed from %d to %d\n",
226 sme_vl, za_out->vl);
227 fail = true;
228 }
229 vq = __sve_vq_from_vl(za_out->vl);
230 za_data = (char *)za_out + ZA_PT_ZA_OFFSET;
231
232 /* Should have register data */
233 if (za_out->size < ZA_PT_SIZE(vq)) {
234 ksft_print_msg("ZA data less than expected: %u < %u\n",
235 za_out->size, (unsigned int)ZA_PT_SIZE(vq));
236 fail = true;
237 vq = 0;
238 }
239
240 /* That register data should be non-zero */
241 for (i = 0; i < ZA_PT_ZA_SIZE(vq); i++) {
242 if (za_data[i]) {
243 ksft_print_msg("ZA byte %d is %x\n",
244 i, za_data[i]);
245 fail = true;
246 }
247 }
248 } else {
249 ksft_print_msg("Failed to read ZA\n");
250 fail = true;
251 }
252
253 ksft_test_result(!fail, "ptrace_enable_za_via_zt\n");
254 }
255
do_parent(pid_t child)256 static int do_parent(pid_t child)
257 {
258 int ret = EXIT_FAILURE;
259 pid_t pid;
260 int status;
261 siginfo_t si;
262
263 /* Attach to the child */
264 while (1) {
265 int sig;
266
267 pid = wait(&status);
268 if (pid == -1) {
269 perror("wait");
270 goto error;
271 }
272
273 /*
274 * This should never happen but it's hard to flag in
275 * the framework.
276 */
277 if (pid != child)
278 continue;
279
280 if (WIFEXITED(status) || WIFSIGNALED(status))
281 ksft_exit_fail_msg("Child died unexpectedly\n");
282
283 if (!WIFSTOPPED(status))
284 goto error;
285
286 sig = WSTOPSIG(status);
287
288 if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &si)) {
289 if (errno == ESRCH)
290 goto disappeared;
291
292 if (errno == EINVAL) {
293 sig = 0; /* bust group-stop */
294 goto cont;
295 }
296
297 ksft_test_result_fail("PTRACE_GETSIGINFO: %s\n",
298 strerror(errno));
299 goto error;
300 }
301
302 if (sig == SIGSTOP && si.si_code == SI_TKILL &&
303 si.si_pid == pid)
304 break;
305
306 cont:
307 if (ptrace(PTRACE_CONT, pid, NULL, sig)) {
308 if (errno == ESRCH)
309 goto disappeared;
310
311 ksft_test_result_fail("PTRACE_CONT: %s\n",
312 strerror(errno));
313 goto error;
314 }
315 }
316
317 ksft_print_msg("Parent is %d, child is %d\n", getpid(), child);
318
319 ptrace_za_disabled_read_zt(child);
320 ptrace_set_get_zt(child);
321 ptrace_enable_za_via_zt(child);
322
323 ret = EXIT_SUCCESS;
324
325 error:
326 kill(child, SIGKILL);
327
328 disappeared:
329 return ret;
330 }
331
main(void)332 int main(void)
333 {
334 int ret = EXIT_SUCCESS;
335 pid_t child;
336
337 srandom(getpid());
338
339 ksft_print_header();
340
341 if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2)) {
342 ksft_set_plan(1);
343 ksft_exit_skip("SME2 not available\n");
344 }
345
346 /* We need a valid SME VL to enable/disable ZA */
347 sme_vl = prctl(PR_SME_GET_VL);
348 if (sme_vl == -1) {
349 ksft_set_plan(1);
350 ksft_exit_skip("Failed to read SME VL: %d (%s)\n",
351 errno, strerror(errno));
352 }
353
354 ksft_set_plan(EXPECTED_TESTS);
355
356 child = fork();
357 if (!child)
358 return do_child();
359
360 if (do_parent(child))
361 ret = EXIT_FAILURE;
362
363 ksft_print_cnts();
364
365 return ret;
366 }
367