// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2021 ARM Limited.
 */

#include <errno.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/auxv.h>
#include <sys/prctl.h>
#include <asm/hwcap.h>
#include <asm/sigcontext.h>
#include <asm/unistd.h>

#include "../../kselftest.h"

#include "syscall-abi.h"

/*
 * The kernel defines a much larger SVE_VQ_MAX than is expressible in
 * the architecture; this creates a *lot* of overhead filling the
 * buffers (especially ZA) on emulated platforms so use the actual
 * architectural maximum instead.
 */
#define ARCH_SVE_VQ_MAX 16

static int default_sme_vl;

static int sve_vl_count;
static unsigned int sve_vls[ARCH_SVE_VQ_MAX];
static int sme_vl_count;
static unsigned int sme_vls[ARCH_SVE_VQ_MAX];

extern void do_syscall(int sve_vl, int sme_vl);

static void fill_random(void *buf, size_t size)
{
	int i;
	uint32_t *lbuf = buf;

	/* random() returns a 32 bit number regardless of the size of long */
	for (i = 0; i < size / sizeof(uint32_t); i++)
		lbuf[i] = random();
}

/*
 * We also repeat the test for several syscalls to try to expose different
 * behaviour.
 */
static struct syscall_cfg {
	int syscall_nr;
	const char *name;
} syscalls[] = {
	{ __NR_getpid, "getpid()" },
	{ __NR_sched_yield, "sched_yield()" },
};

#define NUM_GPR 31
uint64_t gpr_in[NUM_GPR];
uint64_t gpr_out[NUM_GPR];

static void setup_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	fill_random(gpr_in, sizeof(gpr_in));
	gpr_in[8] = cfg->syscall_nr;
	memset(gpr_out, 0, sizeof(gpr_out));
}

static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t svcr)
{
	int errors = 0;
	int i;

	/*
	 * GPR x0-x8 may be clobbered, and all others should be preserved.
	 */
	for (i = 9; i < ARRAY_SIZE(gpr_in); i++) {
		if (gpr_in[i] != gpr_out[i]) {
			ksft_print_msg("%s SVE VL %d mismatch in GPR %d: %llx != %llx\n",
				       cfg->name, sve_vl, i,
				       gpr_in[i], gpr_out[i]);
			errors++;
		}
	}

	return errors;
}

#define NUM_FPR 32
uint64_t fpr_in[NUM_FPR * 2];
uint64_t fpr_out[NUM_FPR * 2];
uint64_t fpr_zero[NUM_FPR * 2];

static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	fill_random(fpr_in, sizeof(fpr_in));
	memset(fpr_out, 0, sizeof(fpr_out));
}

static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		     uint64_t svcr)
{
	int errors = 0;
	int i;

	if (!sve_vl && !(svcr & SVCR_SM_MASK)) {
		for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
			if (fpr_in[i] != fpr_out[i]) {
				ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
					       cfg->name,
					       i / 2, i % 2,
					       fpr_in[i], fpr_out[i]);
				errors++;
			}
		}
	}

	/*
	 * In streaming mode the whole register set should be cleared
	 * by the transition out of streaming mode.
	 */
	if (svcr & SVCR_SM_MASK) {
		if (memcmp(fpr_zero, fpr_out, sizeof(fpr_out)) != 0) {
			ksft_print_msg("%s FPSIMD registers non-zero exiting SM\n",
				       cfg->name);
			errors++;
		}
	}

	return errors;
}

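/*
 * The low 128 bits of each Z register alias the corresponding FPSIMD
 * V register, so outside streaming mode only these shared bytes are
 * expected to survive a syscall; the test expects any wider SVE state
 * to come back zeroed.
 */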
#define SVE_Z_SHARED_BYTES (128 / 8)

static uint8_t z_zero[__SVE_ZREG_SIZE(ARCH_SVE_VQ_MAX)];
uint8_t z_in[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(ARCH_SVE_VQ_MAX)];
uint8_t z_out[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(ARCH_SVE_VQ_MAX)];

static void setup_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	fill_random(z_in, sizeof(z_in));
	fill_random(z_out, sizeof(z_out));
}

static int check_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		   uint64_t svcr)
{
	size_t reg_size = sve_vl;
	int errors = 0;
	int i;

	if (!sve_vl)
		return 0;

	for (i = 0; i < SVE_NUM_ZREGS; i++) {
		uint8_t *in = &z_in[reg_size * i];
		uint8_t *out = &z_out[reg_size * i];

		if (svcr & SVCR_SM_MASK) {
			/*
			 * In streaming mode the whole register should
			 * be cleared by the transition out of
			 * streaming mode.
			 */
			if (memcmp(z_zero, out, reg_size) != 0) {
				ksft_print_msg("%s SVE VL %d Z%d non-zero\n",
					       cfg->name, sve_vl, i);
				errors++;
			}
		} else {
			/*
			 * For standard SVE the low 128 bits should be
			 * preserved and any additional bits cleared.
			 */
			if (memcmp(in, out, SVE_Z_SHARED_BYTES) != 0) {
				ksft_print_msg("%s SVE VL %d Z%d low 128 bits changed\n",
					       cfg->name, sve_vl, i);
				errors++;
			}

			if (reg_size > SVE_Z_SHARED_BYTES &&
			    (memcmp(z_zero, out + SVE_Z_SHARED_BYTES,
				    reg_size - SVE_Z_SHARED_BYTES) != 0)) {
				ksft_print_msg("%s SVE VL %d Z%d high bits non-zero\n",
					       cfg->name, sve_vl, i);
				errors++;
			}
		}
	}

	return errors;
}

uint8_t p_in[SVE_NUM_PREGS * __SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
uint8_t p_out[SVE_NUM_PREGS * __SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];

static void setup_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	fill_random(p_in, sizeof(p_in));
	fill_random(p_out, sizeof(p_out));
}

static int check_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		   uint64_t svcr)
{
	size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */

	int errors = 0;
	int i;

	if (!sve_vl)
		return 0;

	/* After a syscall the P registers should be zeroed */
	for (i = 0; i < SVE_NUM_PREGS * reg_size; i++)
		if (p_out[i])
			errors++;
	if (errors)
		ksft_print_msg("%s SVE VL %d predicate registers non-zero\n",
			       cfg->name, sve_vl);

	return errors;
}

uint8_t ffr_in[__SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
uint8_t ffr_out[__SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];

static void setup_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	/*
	 * If we are in streaming mode and do not have FA64 then FFR
	 * is unavailable.
	 */
	if ((svcr & SVCR_SM_MASK) &&
	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)) {
		memset(ffr_in, 0, sizeof(ffr_in));
		return;
	}

	/*
	 * It is only valid to set a contiguous set of bits starting
	 * at 0. For now since we're expecting this to be cleared by
	 * a syscall just set all bits.
	 */
	memset(ffr_in, 0xff, sizeof(ffr_in));
	fill_random(ffr_out, sizeof(ffr_out));
}

static int check_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		     uint64_t svcr)
{
	size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
	int errors = 0;
	int i;

	if (!sve_vl)
		return 0;

	if ((svcr & SVCR_SM_MASK) &&
	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64))
		return 0;

	/* After a syscall FFR should be zeroed */
	for (i = 0; i < reg_size; i++)
		if (ffr_out[i])
			errors++;
	if (errors)
		ksft_print_msg("%s SVE VL %d FFR non-zero\n",
			       cfg->name, sve_vl);

	return errors;
}

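/*
 * SVCR.SM (SVCR_SM_MASK) controls streaming mode and SVCR.ZA
 * (SVCR_ZA_MASK) enables the ZA storage.  A syscall is expected to
 * exit streaming mode but to leave PSTATE.ZA untouched, which is what
 * check_svcr() verifies.
 */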
uint64_t svcr_in, svcr_out;

static void setup_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		       uint64_t svcr)
{
	svcr_in = svcr;
}

static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	int errors = 0;

	if (svcr_out & SVCR_SM_MASK) {
		ksft_print_msg("%s Still in SM, SVCR %llx\n",
			       cfg->name, svcr_out);
		errors++;
	}

	if ((svcr_in & SVCR_ZA_MASK) != (svcr_out & SVCR_ZA_MASK)) {
		ksft_print_msg("%s PSTATE.ZA changed, SVCR %llx != %llx\n",
			       cfg->name, svcr_in, svcr_out);
		errors++;
	}

	return errors;
}

uint8_t za_in[ZA_SIG_REGS_SIZE(ARCH_SVE_VQ_MAX)];
uint8_t za_out[ZA_SIG_REGS_SIZE(ARCH_SVE_VQ_MAX)];

static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		     uint64_t svcr)
{
	fill_random(za_in, sizeof(za_in));
	memset(za_out, 0, sizeof(za_out));
}

static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	size_t reg_size = sme_vl * sme_vl;
	int errors = 0;

	if (!(svcr & SVCR_ZA_MASK))
		return 0;

	if (memcmp(za_in, za_out, reg_size) != 0) {
		ksft_print_msg("SME VL %d ZA does not match\n", sme_vl);
		errors++;
	}

	return errors;
}

uint8_t zt_in[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));
uint8_t zt_out[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));

static void setup_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		     uint64_t svcr)
{
	fill_random(zt_in, sizeof(zt_in));
	memset(zt_out, 0, sizeof(zt_out));
}

static int check_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	int errors = 0;

	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2))
		return 0;

	if (!(svcr & SVCR_ZA_MASK))
		return 0;

	if (memcmp(zt_in, zt_out, sizeof(zt_in)) != 0) {
		ksft_print_msg("SME VL %d ZT does not match\n", sme_vl);
		errors++;
	}

	return errors;
}

typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
			 uint64_t svcr);
typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
			uint64_t svcr);

/*
 * Each set of registers has a setup function which is called before
 * the syscall to fill values in a global variable for loading by the
 * test code and a check function which validates that the results are
 * as expected.  Vector lengths are passed everywhere; a vector length
 * of 0 should be treated as "do not test".
 */
static struct {
	setup_fn setup;
	check_fn check;
} regset[] = {
	{ setup_gpr, check_gpr },
	{ setup_fpr, check_fpr },
	{ setup_z, check_z },
	{ setup_p, check_p },
	{ setup_ffr, check_ffr },
	{ setup_svcr, check_svcr },
	{ setup_za, check_za },
	{ setup_zt, check_zt },
};

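/*
 * Run a single test case: seed every tracked register set, then make
 * the syscall via the assembly helper (which is expected to load
 * register state from the *_in buffers, apply the requested SVCR and
 * capture the post-syscall state into the *_out buffers), then total
 * up the mismatches reported by each check function.
 */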
static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	int errors = 0;
	int i;

	for (i = 0; i < ARRAY_SIZE(regset); i++)
		regset[i].setup(cfg, sve_vl, sme_vl, svcr);

	do_syscall(sve_vl, sme_vl);

	for (i = 0; i < ARRAY_SIZE(regset); i++)
		errors += regset[i].check(cfg, sve_vl, sme_vl, svcr);

	return errors == 0;
}

static void test_one_syscall(struct syscall_cfg *cfg)
{
	int sve, sme;
	int ret;

	/* FPSIMD only case */
	ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
			 "%s FPSIMD\n", cfg->name);

	for (sve = 0; sve < sve_vl_count; sve++) {
		ret = prctl(PR_SVE_SET_VL, sve_vls[sve]);
		if (ret == -1)
			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
					   strerror(errno), errno);

		ksft_test_result(do_test(cfg, sve_vls[sve], default_sme_vl, 0),
				 "%s SVE VL %d\n", cfg->name, sve_vls[sve]);

		for (sme = 0; sme < sme_vl_count; sme++) {
			ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
			if (ret == -1)
				ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
						   strerror(errno), errno);

			ksft_test_result(do_test(cfg, sve_vls[sve],
						 sme_vls[sme],
						 SVCR_ZA_MASK | SVCR_SM_MASK),
					 "%s SVE VL %d/SME VL %d SM+ZA\n",
					 cfg->name, sve_vls[sve],
					 sme_vls[sme]);
			ksft_test_result(do_test(cfg, sve_vls[sve],
						 sme_vls[sme], SVCR_SM_MASK),
					 "%s SVE VL %d/SME VL %d SM\n",
					 cfg->name, sve_vls[sve],
					 sme_vls[sme]);
			ksft_test_result(do_test(cfg, sve_vls[sve],
						 sme_vls[sme], SVCR_ZA_MASK),
					 "%s SVE VL %d/SME VL %d ZA\n",
					 cfg->name, sve_vls[sve],
					 sme_vls[sme]);
		}
	}

	for (sme = 0; sme < sme_vl_count; sme++) {
		ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
		if (ret == -1)
			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
					   strerror(errno), errno);

		ksft_test_result(do_test(cfg, 0, sme_vls[sme],
					 SVCR_ZA_MASK | SVCR_SM_MASK),
				 "%s SME VL %d SM+ZA\n",
				 cfg->name, sme_vls[sme]);
		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_SM_MASK),
				 "%s SME VL %d SM\n",
				 cfg->name, sme_vls[sme]);
		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_ZA_MASK),
				 "%s SME VL %d ZA\n",
				 cfg->name, sme_vls[sme]);
	}
}

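/*
 * Enumerate supported vector lengths by requesting power-of-two VQs in
 * turn: the PR_SVE_SET_VL/PR_SME_SET_VL prctl()s report the vector
 * length actually configured, with unsupported requests rounded down
 * to a supported length, so each iteration records a strictly smaller
 * VL than the last.
 */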
void sve_count_vls(void)
{
	unsigned int vq;
	int vl;

	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
		return;

	/*
	 * Enumerate up to ARCH_SVE_VQ_MAX vector lengths
	 */
	for (vq = ARCH_SVE_VQ_MAX; vq > 0; vq /= 2) {
		vl = prctl(PR_SVE_SET_VL, vq * 16);
		if (vl == -1)
			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
					   strerror(errno), errno);

		vl &= PR_SVE_VL_LEN_MASK;

		if (vq != sve_vq_from_vl(vl))
			vq = sve_vq_from_vl(vl);

		sve_vls[sve_vl_count++] = vl;
	}
}

void sme_count_vls(void)
{
	unsigned int vq;
	int vl;

	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
		return;

	/*
	 * Enumerate up to ARCH_SVE_VQ_MAX vector lengths
	 */
	for (vq = ARCH_SVE_VQ_MAX; vq > 0; vq /= 2) {
		vl = prctl(PR_SME_SET_VL, vq * 16);
		if (vl == -1)
			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
					   strerror(errno), errno);

		vl &= PR_SME_VL_LEN_MASK;

		/* Found lowest VL */
		if (sve_vq_from_vl(vl) > vq)
			break;

		if (vq != sve_vq_from_vl(vl))
			vq = sve_vq_from_vl(vl);

		sme_vls[sme_vl_count++] = vl;
	}

	/* Ensure we configure an SME VL, used to flag if SVCR is set */
	default_sme_vl = sme_vls[0];
}

int main(void)
{
	int i;
	int tests = 1;	/* FPSIMD */
	int sme_ver;

	srandom(getpid());

	ksft_print_header();

	sve_count_vls();
	sme_count_vls();

	tests += sve_vl_count;
	tests += sme_vl_count * 3;
	tests += (sve_vl_count * sme_vl_count) * 3;
	ksft_set_plan(ARRAY_SIZE(syscalls) * tests);

	if (getauxval(AT_HWCAP2) & HWCAP2_SME2)
		sme_ver = 2;
	else
		sme_ver = 1;

	if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
		ksft_print_msg("SME%d with FA64\n", sme_ver);
	else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
		ksft_print_msg("SME%d without FA64\n", sme_ver);

	for (i = 0; i < ARRAY_SIZE(syscalls); i++)
		test_one_syscall(&syscalls[i]);

	ksft_print_cnts();

	return 0;
}