1 // SPDX-License-Identifier: GPL-2.0 2 3 #include <stdio.h> 4 #include <string.h> 5 #include <sys/poll.h> 6 #include <sys/socket.h> 7 #include <sys/time.h> 8 #include <netinet/in.h> 9 #include <unistd.h> 10 11 #include <ynl.h> 12 13 #include "psp-user.h" 14 15 #define dbg(msg...) \ 16 do { \ 17 if (opts->verbose) \ 18 fprintf(stderr, "DEBUG: " msg); \ 19 } while (0) 20 21 static bool should_quit; 22 23 struct opts { 24 int port; 25 int devid; 26 bool verbose; 27 }; 28 29 enum accept_cfg { 30 ACCEPT_CFG_NONE = 0, 31 ACCEPT_CFG_CLEAR, 32 ACCEPT_CFG_PSP, 33 }; 34 35 static struct { 36 unsigned char tx; 37 unsigned char rx; 38 } psp_vers; 39 40 static int conn_setup_psp(struct ynl_sock *ys, struct opts *opts, int data_sock) 41 { 42 struct psp_rx_assoc_rsp *rsp; 43 struct psp_rx_assoc_req *req; 44 struct psp_tx_assoc_rsp *tsp; 45 struct psp_tx_assoc_req *teq; 46 char info[300]; 47 int key_len; 48 ssize_t sz; 49 __u32 spi; 50 51 dbg("create PSP connection\n"); 52 53 // Rx assoc alloc 54 req = psp_rx_assoc_req_alloc(); 55 56 psp_rx_assoc_req_set_sock_fd(req, data_sock); 57 psp_rx_assoc_req_set_version(req, psp_vers.rx); 58 59 rsp = psp_rx_assoc(ys, req); 60 psp_rx_assoc_req_free(req); 61 62 if (!rsp) { 63 perror("ERROR: failed to Rx assoc"); 64 return -1; 65 } 66 67 // SPI exchange 68 key_len = rsp->rx_key._len.key; 69 memcpy(info, &rsp->rx_key.spi, sizeof(spi)); 70 memcpy(&info[sizeof(spi)], rsp->rx_key.key, key_len); 71 sz = sizeof(spi) + key_len; 72 73 send(data_sock, info, sz, MSG_WAITALL); 74 psp_rx_assoc_rsp_free(rsp); 75 76 sz = recv(data_sock, info, sz, MSG_WAITALL); 77 if (sz < 0) { 78 perror("ERROR: failed to read PSP key from sock"); 79 return -1; 80 } 81 memcpy(&spi, info, sizeof(spi)); 82 83 // Setup Tx assoc 84 teq = psp_tx_assoc_req_alloc(); 85 86 psp_tx_assoc_req_set_sock_fd(teq, data_sock); 87 psp_tx_assoc_req_set_version(teq, psp_vers.tx); 88 psp_tx_assoc_req_set_tx_key_spi(teq, spi); 89 psp_tx_assoc_req_set_tx_key_key(teq, &info[sizeof(spi)], key_len); 90 91 tsp = psp_tx_assoc(ys, teq); 92 psp_tx_assoc_req_free(teq); 93 if (!tsp) { 94 perror("ERROR: failed to Tx assoc"); 95 return -1; 96 } 97 psp_tx_assoc_rsp_free(tsp); 98 99 return 0; 100 } 101 102 static void send_ack(int sock) 103 { 104 send(sock, "ack", 4, MSG_WAITALL); 105 } 106 107 static void send_err(int sock) 108 { 109 send(sock, "err", 4, MSG_WAITALL); 110 } 111 112 static void send_str(int sock, int value) 113 { 114 char buf[128]; 115 int ret; 116 117 ret = snprintf(buf, sizeof(buf), "%d", value); 118 send(sock, buf, ret + 1, MSG_WAITALL); 119 } 120 121 static void 122 run_session(struct ynl_sock *ys, struct opts *opts, 123 int server_sock, int comm_sock) 124 { 125 enum accept_cfg accept_cfg = ACCEPT_CFG_NONE; 126 struct pollfd pfds[3]; 127 size_t data_read = 0; 128 int data_sock = -1; 129 130 while (true) { 131 bool race_close = false; 132 int nfds; 133 134 memset(pfds, 0, sizeof(pfds)); 135 136 pfds[0].fd = server_sock; 137 pfds[0].events = POLLIN; 138 139 pfds[1].fd = comm_sock; 140 pfds[1].events = POLLIN; 141 142 nfds = 2; 143 if (data_sock >= 0) { 144 pfds[2].fd = data_sock; 145 pfds[2].events = POLLIN; 146 nfds++; 147 } 148 149 dbg(" ...\n"); 150 if (poll(pfds, nfds, -1) < 0) { 151 perror("poll"); 152 break; 153 } 154 155 /* data sock */ 156 if (pfds[2].revents & POLLIN) { 157 char buf[8192]; 158 ssize_t n; 159 160 n = recv(data_sock, buf, sizeof(buf), 0); 161 if (n <= 0) { 162 if (n < 0) 163 perror("data read"); 164 close(data_sock); 165 data_sock = -1; 166 dbg("data sock closed\n"); 167 } else { 168 data_read += n; 169 dbg("data read %zd\n", data_read); 170 } 171 } 172 173 /* comm sock */ 174 if (pfds[1].revents & POLLIN) { 175 static char buf[4096]; 176 static ssize_t off; 177 bool consumed; 178 ssize_t n; 179 180 n = recv(comm_sock, &buf[off], sizeof(buf) - off, 0); 181 if (n <= 0) { 182 if (n < 0) 183 perror("comm read"); 184 return; 185 } 186 187 off += n; 188 n = off; 189 190 #define __consume(sz) \ 191 ({ \ 192 if (n == (sz)) { \ 193 off = 0; \ 194 } else { \ 195 off -= (sz); \ 196 memmove(buf, &buf[(sz)], off); \ 197 } \ 198 }) 199 200 #define cmd(_name) \ 201 ({ \ 202 ssize_t sz = sizeof(_name); \ 203 bool match = n >= sz && !memcmp(buf, _name, sz); \ 204 \ 205 if (match) { \ 206 dbg("command: " _name "\n"); \ 207 __consume(sz); \ 208 } \ 209 consumed |= match; \ 210 match; \ 211 }) 212 213 do { 214 consumed = false; 215 216 if (cmd("read len")) 217 send_str(comm_sock, data_read); 218 219 if (cmd("data echo")) { 220 if (data_sock >= 0) 221 send(data_sock, "echo", 5, 222 MSG_WAITALL); 223 else 224 fprintf(stderr, "WARN: echo but no data sock\n"); 225 send_ack(comm_sock); 226 } 227 if (cmd("data close")) { 228 if (data_sock >= 0) { 229 close(data_sock); 230 data_sock = -1; 231 send_ack(comm_sock); 232 } else { 233 race_close = true; 234 } 235 } 236 if (cmd("conn psp")) { 237 if (accept_cfg != ACCEPT_CFG_NONE) 238 fprintf(stderr, "WARN: old conn config still set!\n"); 239 accept_cfg = ACCEPT_CFG_PSP; 240 send_ack(comm_sock); 241 /* next two bytes are versions */ 242 if (off >= 2) { 243 memcpy(&psp_vers, buf, 2); 244 __consume(2); 245 } else { 246 fprintf(stderr, "WARN: short conn psp command!\n"); 247 } 248 } 249 if (cmd("conn clr")) { 250 if (accept_cfg != ACCEPT_CFG_NONE) 251 fprintf(stderr, "WARN: old conn config still set!\n"); 252 accept_cfg = ACCEPT_CFG_CLEAR; 253 send_ack(comm_sock); 254 } 255 if (cmd("exit")) 256 should_quit = true; 257 #undef cmd 258 259 if (!consumed) { 260 fprintf(stderr, "WARN: unknown cmd: [%zd] %s\n", 261 off, buf); 262 } 263 } while (consumed && off); 264 } 265 266 /* server sock */ 267 if (pfds[0].revents & POLLIN) { 268 if (data_sock >= 0) { 269 fprintf(stderr, "WARN: new data sock but old one still here\n"); 270 close(data_sock); 271 data_sock = -1; 272 } 273 data_sock = accept(server_sock, NULL, NULL); 274 if (data_sock < 0) { 275 perror("accept"); 276 continue; 277 } 278 data_read = 0; 279 280 if (accept_cfg == ACCEPT_CFG_CLEAR) { 281 dbg("new data sock: clear\n"); 282 /* nothing to do */ 283 } else if (accept_cfg == ACCEPT_CFG_PSP) { 284 dbg("new data sock: psp\n"); 285 conn_setup_psp(ys, opts, data_sock); 286 } else { 287 fprintf(stderr, "WARN: new data sock but no config\n"); 288 } 289 accept_cfg = ACCEPT_CFG_NONE; 290 } 291 292 if (race_close) { 293 if (data_sock >= 0) { 294 /* indeed, ordering problem, handle the close */ 295 close(data_sock); 296 data_sock = -1; 297 send_ack(comm_sock); 298 } else { 299 fprintf(stderr, "WARN: close but no data sock\n"); 300 send_err(comm_sock); 301 } 302 } 303 } 304 dbg("session ending\n"); 305 } 306 307 static int spawn_server(struct opts *opts) 308 { 309 struct sockaddr_in6 addr; 310 int fd; 311 312 fd = socket(AF_INET6, SOCK_STREAM, 0); 313 if (fd < 0) { 314 perror("can't open socket"); 315 return -1; 316 } 317 318 memset(&addr, 0, sizeof(addr)); 319 320 addr.sin6_family = AF_INET6; 321 addr.sin6_addr = in6addr_any; 322 addr.sin6_port = htons(opts->port); 323 324 if (bind(fd, (struct sockaddr *)&addr, sizeof(addr))) { 325 perror("can't bind socket"); 326 return -1; 327 } 328 329 if (listen(fd, 5)) { 330 perror("can't listen"); 331 return -1; 332 } 333 334 return fd; 335 } 336 337 static int run_responder(struct ynl_sock *ys, struct opts *opts) 338 { 339 int server_sock, comm; 340 341 server_sock = spawn_server(opts); 342 if (server_sock < 0) 343 return 4; 344 345 while (!should_quit) { 346 comm = accept(server_sock, NULL, NULL); 347 if (comm < 0) { 348 perror("accept failed"); 349 } else { 350 run_session(ys, opts, server_sock, comm); 351 close(comm); 352 } 353 } 354 355 return 0; 356 } 357 358 static void usage(const char *name, const char *miss) 359 { 360 if (miss) 361 fprintf(stderr, "Missing argument: %s\n", miss); 362 363 fprintf(stderr, "Usage: %s -p port [-v] [-d psp-dev-id]\n", name); 364 exit(EXIT_FAILURE); 365 } 366 367 static void parse_cmd_opts(int argc, char **argv, struct opts *opts) 368 { 369 int opt; 370 371 while ((opt = getopt(argc, argv, "vp:d:")) != -1) { 372 switch (opt) { 373 case 'v': 374 opts->verbose = 1; 375 break; 376 case 'p': 377 opts->port = atoi(optarg); 378 break; 379 case 'd': 380 opts->devid = atoi(optarg); 381 break; 382 default: 383 usage(argv[0], NULL); 384 } 385 } 386 } 387 388 static int psp_dev_set_ena(struct ynl_sock *ys, __u32 dev_id, __u32 versions) 389 { 390 struct psp_dev_set_req *sreq; 391 struct psp_dev_set_rsp *srsp; 392 393 fprintf(stderr, "Set PSP enable on device %d to 0x%x\n", 394 dev_id, versions); 395 396 sreq = psp_dev_set_req_alloc(); 397 398 psp_dev_set_req_set_id(sreq, dev_id); 399 psp_dev_set_req_set_psp_versions_ena(sreq, versions); 400 401 srsp = psp_dev_set(ys, sreq); 402 psp_dev_set_req_free(sreq); 403 if (!srsp) 404 return 10; 405 406 psp_dev_set_rsp_free(srsp); 407 return 0; 408 } 409 410 int main(int argc, char **argv) 411 { 412 struct psp_dev_get_list *dev_list; 413 bool devid_found = false; 414 __u32 ver_ena, ver_cap; 415 struct opts opts = {}; 416 struct ynl_error yerr; 417 struct ynl_sock *ys; 418 int first_id = 0; 419 int ret; 420 421 parse_cmd_opts(argc, argv, &opts); 422 if (!opts.port) 423 usage(argv[0], "port"); // exits 424 425 ys = ynl_sock_create(&ynl_psp_family, &yerr); 426 if (!ys) { 427 fprintf(stderr, "YNL: %s\n", yerr.msg); 428 return 1; 429 } 430 431 dev_list = psp_dev_get_dump(ys); 432 if (ynl_dump_empty(dev_list)) { 433 if (ys->err.code) 434 goto err_close; 435 fprintf(stderr, "No PSP devices\n"); 436 goto err_close_silent; 437 } 438 439 ynl_dump_foreach(dev_list, d) { 440 if (opts.devid) { 441 devid_found = true; 442 ver_ena = d->psp_versions_ena; 443 ver_cap = d->psp_versions_cap; 444 } else if (!first_id) { 445 first_id = d->id; 446 ver_ena = d->psp_versions_ena; 447 ver_cap = d->psp_versions_cap; 448 } else { 449 fprintf(stderr, "Multiple PSP devices found\n"); 450 goto err_close_silent; 451 } 452 } 453 psp_dev_get_list_free(dev_list); 454 455 if (opts.devid && !devid_found) { 456 fprintf(stderr, "PSP device %d requested on cmdline, not found\n", 457 opts.devid); 458 goto err_close_silent; 459 } else if (!opts.devid) { 460 opts.devid = first_id; 461 } 462 463 if (ver_ena != ver_cap) { 464 ret = psp_dev_set_ena(ys, opts.devid, ver_cap); 465 if (ret) 466 goto err_close; 467 } 468 469 ret = run_responder(ys, &opts); 470 471 if (ver_ena != ver_cap && psp_dev_set_ena(ys, opts.devid, ver_ena)) 472 fprintf(stderr, "WARN: failed to set the PSP versions back\n"); 473 474 ynl_sock_destroy(ys); 475 476 return ret; 477 478 err_close: 479 fprintf(stderr, "YNL: %s\n", ys->err.msg); 480 err_close_silent: 481 ynl_sock_destroy(ys); 482 return 2; 483 } 484