xref: /linux/tools/testing/selftests/drivers/net/psp_responder.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
1 // SPDX-License-Identifier: GPL-2.0
2 
#include <errno.h>
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <unistd.h>

#include <ynl.h>

#include "psp-user.h"
14 
/*
 * Print a "DEBUG: "-prefixed message to stderr, but only in verbose mode.
 * Expects a variable named `opts` (pointer with a `verbose` member) to be
 * in scope at the call site. The first argument must be a string literal,
 * since it is pasted onto the prefix.
 */
#define dbg(...)						\
	do {							\
		if (opts->verbose)				\
			fprintf(stderr, "DEBUG: " __VA_ARGS__);	\
	} while (0)
20 
/* Set by the "exit" command; makes run_responder() stop accepting sessions. */
static bool should_quit = false;
22 
/* Command line configuration, filled in by parse_cmd_opts(). */
struct opts {
	int port;	/* -p: TCP port the responder listens on */
	int devid;	/* -d: PSP device id, 0 when not given */
	bool verbose;	/* -v: enables dbg() output */
};
28 
/* How the next accepted data connection should be set up. */
enum accept_cfg {
	ACCEPT_CFG_NONE = 0,	/* no "conn ..." command pending */
	ACCEPT_CFG_CLEAR,	/* "conn clr": leave the socket as is */
	ACCEPT_CFG_PSP,		/* "conn psp": run conn_setup_psp() on it */
};
34 
/*
 * PSP versions used for the Tx and Rx associations. Filled by copying the
 * two bytes that follow a "conn psp" command, so member order matters:
 * first byte is Tx, second byte is Rx.
 */
static struct {
	unsigned char tx;
	unsigned char rx;
} psp_vers;
39 
40 static int conn_setup_psp(struct ynl_sock *ys, struct opts *opts, int data_sock)
41 {
42 	struct psp_rx_assoc_rsp *rsp;
43 	struct psp_rx_assoc_req *req;
44 	struct psp_tx_assoc_rsp *tsp;
45 	struct psp_tx_assoc_req *teq;
46 	char info[300];
47 	int key_len;
48 	ssize_t sz;
49 	__u32 spi;
50 
51 	dbg("create PSP connection\n");
52 
53 	// Rx assoc alloc
54 	req = psp_rx_assoc_req_alloc();
55 
56 	psp_rx_assoc_req_set_sock_fd(req, data_sock);
57 	psp_rx_assoc_req_set_version(req, psp_vers.rx);
58 
59 	rsp = psp_rx_assoc(ys, req);
60 	psp_rx_assoc_req_free(req);
61 
62 	if (!rsp) {
63 		perror("ERROR: failed to Rx assoc");
64 		return -1;
65 	}
66 
67 	// SPI exchange
68 	key_len = rsp->rx_key._len.key;
69 	memcpy(info, &rsp->rx_key.spi, sizeof(spi));
70 	memcpy(&info[sizeof(spi)], rsp->rx_key.key, key_len);
71 	sz = sizeof(spi) + key_len;
72 
73 	send(data_sock, info, sz, MSG_WAITALL);
74 	psp_rx_assoc_rsp_free(rsp);
75 
76 	sz = recv(data_sock, info, sz, MSG_WAITALL);
77 	if (sz < 0) {
78 		perror("ERROR: failed to read PSP key from sock");
79 		return -1;
80 	}
81 	memcpy(&spi, info, sizeof(spi));
82 
83 	// Setup Tx assoc
84 	teq = psp_tx_assoc_req_alloc();
85 
86 	psp_tx_assoc_req_set_sock_fd(teq, data_sock);
87 	psp_tx_assoc_req_set_version(teq, psp_vers.tx);
88 	psp_tx_assoc_req_set_tx_key_spi(teq, spi);
89 	psp_tx_assoc_req_set_tx_key_key(teq, &info[sizeof(spi)], key_len);
90 
91 	tsp = psp_tx_assoc(ys, teq);
92 	psp_tx_assoc_req_free(teq);
93 	if (!tsp) {
94 		perror("ERROR: failed to Tx assoc");
95 		return -1;
96 	}
97 	psp_tx_assoc_rsp_free(tsp);
98 
99 	return 0;
100 }
101 
/* Reply "ack" (including the terminating NUL, 4 bytes) on @sock. */
static void send_ack(int sock)
{
	static const char reply[] = "ack";

	send(sock, reply, sizeof(reply), MSG_WAITALL);
}
106 
/* Reply "err" (including the terminating NUL, 4 bytes) on @sock. */
static void send_err(int sock)
{
	static const char reply[] = "err";

	send(sock, reply, sizeof(reply), MSG_WAITALL);
}
111 
/* Send @value on @sock as a NUL-terminated decimal string. */
static void send_str(int sock, int value)
{
	char text[128];
	int len = snprintf(text, sizeof(text), "%d", value);

	/* + 1 so the terminating NUL goes out on the wire as well */
	send(sock, text, len + 1, MSG_WAITALL);
}
120 
121 static void
122 run_session(struct ynl_sock *ys, struct opts *opts,
123 	    int server_sock, int comm_sock)
124 {
125 	enum accept_cfg accept_cfg = ACCEPT_CFG_NONE;
126 	struct pollfd pfds[3];
127 	size_t data_read = 0;
128 	int data_sock = -1;
129 
130 	while (true) {
131 		bool race_close = false;
132 		int nfds;
133 
134 		memset(pfds, 0, sizeof(pfds));
135 
136 		pfds[0].fd = server_sock;
137 		pfds[0].events = POLLIN;
138 
139 		pfds[1].fd = comm_sock;
140 		pfds[1].events = POLLIN;
141 
142 		nfds = 2;
143 		if (data_sock >= 0) {
144 			pfds[2].fd = data_sock;
145 			pfds[2].events = POLLIN;
146 			nfds++;
147 		}
148 
149 		dbg(" ...\n");
150 		if (poll(pfds, nfds, -1) < 0) {
151 			perror("poll");
152 			break;
153 		}
154 
155 		/* data sock */
156 		if (pfds[2].revents & POLLIN) {
157 			char buf[8192];
158 			ssize_t n;
159 
160 			n = recv(data_sock, buf, sizeof(buf), 0);
161 			if (n <= 0) {
162 				if (n < 0)
163 					perror("data read");
164 				close(data_sock);
165 				data_sock = -1;
166 				dbg("data sock closed\n");
167 			} else {
168 				data_read += n;
169 				dbg("data read %zd\n", data_read);
170 			}
171 		}
172 
173 		/* comm sock */
174 		if (pfds[1].revents & POLLIN) {
175 			static char buf[4096];
176 			static ssize_t off;
177 			bool consumed;
178 			ssize_t n;
179 
180 			n = recv(comm_sock, &buf[off], sizeof(buf) - off, 0);
181 			if (n <= 0) {
182 				if (n < 0)
183 					perror("comm read");
184 				return;
185 			}
186 
187 			off += n;
188 			n = off;
189 
190 #define __consume(sz)						\
191 		({						\
192 			if (n == (sz)) {			\
193 				off = 0;			\
194 			} else {				\
195 				off -= (sz);			\
196 				memmove(buf, &buf[(sz)], off);	\
197 			}					\
198 		})
199 
200 #define cmd(_name)							\
201 		({							\
202 			ssize_t sz = sizeof(_name);			\
203 			bool match = n >= sz &&	!memcmp(buf, _name, sz); \
204 									\
205 			if (match) {					\
206 				dbg("command: " _name "\n");		\
207 				__consume(sz);				\
208 			}						\
209 			consumed |= match;				\
210 			match;						\
211 		})
212 
213 			do {
214 				consumed = false;
215 
216 				if (cmd("read len"))
217 					send_str(comm_sock, data_read);
218 
219 				if (cmd("data echo")) {
220 					if (data_sock >= 0)
221 						send(data_sock, "echo", 5,
222 						     MSG_WAITALL);
223 					else
224 						fprintf(stderr, "WARN: echo but no data sock\n");
225 					send_ack(comm_sock);
226 				}
227 				if (cmd("data close")) {
228 					if (data_sock >= 0) {
229 						close(data_sock);
230 						data_sock = -1;
231 						send_ack(comm_sock);
232 					} else {
233 						race_close = true;
234 					}
235 				}
236 				if (cmd("conn psp")) {
237 					if (accept_cfg != ACCEPT_CFG_NONE)
238 						fprintf(stderr, "WARN: old conn config still set!\n");
239 					accept_cfg = ACCEPT_CFG_PSP;
240 					send_ack(comm_sock);
241 					/* next two bytes are versions */
242 					if (off >= 2) {
243 						memcpy(&psp_vers, buf, 2);
244 						__consume(2);
245 					} else {
246 						fprintf(stderr, "WARN: short conn psp command!\n");
247 					}
248 				}
249 				if (cmd("conn clr")) {
250 					if (accept_cfg != ACCEPT_CFG_NONE)
251 						fprintf(stderr, "WARN: old conn config still set!\n");
252 					accept_cfg = ACCEPT_CFG_CLEAR;
253 					send_ack(comm_sock);
254 				}
255 				if (cmd("exit"))
256 					should_quit = true;
257 #undef cmd
258 
259 				if (!consumed) {
260 					fprintf(stderr, "WARN: unknown cmd: [%zd] %s\n",
261 						off, buf);
262 				}
263 			} while (consumed && off);
264 		}
265 
266 		/* server sock */
267 		if (pfds[0].revents & POLLIN) {
268 			if (data_sock >= 0) {
269 				fprintf(stderr, "WARN: new data sock but old one still here\n");
270 				close(data_sock);
271 				data_sock = -1;
272 			}
273 			data_sock = accept(server_sock, NULL, NULL);
274 			if (data_sock < 0) {
275 				perror("accept");
276 				continue;
277 			}
278 			data_read = 0;
279 
280 			if (accept_cfg == ACCEPT_CFG_CLEAR) {
281 				dbg("new data sock: clear\n");
282 				/* nothing to do */
283 			} else if (accept_cfg == ACCEPT_CFG_PSP) {
284 				dbg("new data sock: psp\n");
285 				conn_setup_psp(ys, opts, data_sock);
286 			} else {
287 				fprintf(stderr, "WARN: new data sock but no config\n");
288 			}
289 			accept_cfg = ACCEPT_CFG_NONE;
290 		}
291 
292 		if (race_close) {
293 			if (data_sock >= 0) {
294 				/* indeed, ordering problem, handle the close */
295 				close(data_sock);
296 				data_sock = -1;
297 				send_ack(comm_sock);
298 			} else {
299 				fprintf(stderr, "WARN: close but no data sock\n");
300 				send_err(comm_sock);
301 			}
302 		}
303 	}
304 	dbg("session ending\n");
305 }
306 
307 static int spawn_server(struct opts *opts)
308 {
309 	struct sockaddr_in6 addr;
310 	int fd;
311 
312 	fd = socket(AF_INET6, SOCK_STREAM, 0);
313 	if (fd < 0) {
314 		perror("can't open socket");
315 		return -1;
316 	}
317 
318 	memset(&addr, 0, sizeof(addr));
319 
320 	addr.sin6_family = AF_INET6;
321 	addr.sin6_addr = in6addr_any;
322 	addr.sin6_port = htons(opts->port);
323 
324 	if (bind(fd, (struct sockaddr *)&addr, sizeof(addr))) {
325 		perror("can't bind socket");
326 		return -1;
327 	}
328 
329 	if (listen(fd, 5)) {
330 		perror("can't listen");
331 		return -1;
332 	}
333 
334 	return fd;
335 }
336 
337 static int run_responder(struct ynl_sock *ys, struct opts *opts)
338 {
339 	int server_sock, comm;
340 
341 	server_sock = spawn_server(opts);
342 	if (server_sock < 0)
343 		return 4;
344 
345 	while (!should_quit) {
346 		comm = accept(server_sock, NULL, NULL);
347 		if (comm < 0) {
348 			perror("accept failed");
349 		} else {
350 			run_session(ys, opts, server_sock, comm);
351 			close(comm);
352 		}
353 	}
354 
355 	return 0;
356 }
357 
/*
 * Print the command line synopsis to stderr and terminate the process with
 * EXIT_FAILURE. When @miss is non-NULL it names a missing argument, which
 * is reported first. Does not return.
 */
static void usage(const char *name, const char *miss)
{
	if (miss != NULL)
		fprintf(stderr, "Missing argument: %s\n", miss);
	fprintf(stderr, "Usage: %s -p port [-v] [-d psp-dev-id]\n", name);

	exit(EXIT_FAILURE);
}
366 
367 static void parse_cmd_opts(int argc, char **argv, struct opts *opts)
368 {
369 	int opt;
370 
371 	while ((opt = getopt(argc, argv, "vp:d:")) != -1) {
372 		switch (opt) {
373 		case 'v':
374 			opts->verbose = 1;
375 			break;
376 		case 'p':
377 			opts->port = atoi(optarg);
378 			break;
379 		case 'd':
380 			opts->devid = atoi(optarg);
381 			break;
382 		default:
383 			usage(argv[0], NULL);
384 		}
385 	}
386 }
387 
388 static int psp_dev_set_ena(struct ynl_sock *ys, __u32 dev_id, __u32 versions)
389 {
390 	struct psp_dev_set_req *sreq;
391 	struct psp_dev_set_rsp *srsp;
392 
393 	fprintf(stderr, "Set PSP enable on device %d to 0x%x\n",
394 		dev_id, versions);
395 
396 	sreq = psp_dev_set_req_alloc();
397 
398 	psp_dev_set_req_set_id(sreq, dev_id);
399 	psp_dev_set_req_set_psp_versions_ena(sreq, versions);
400 
401 	srsp = psp_dev_set(ys, sreq);
402 	psp_dev_set_req_free(sreq);
403 	if (!srsp)
404 		return 10;
405 
406 	psp_dev_set_rsp_free(srsp);
407 	return 0;
408 }
409 
410 int main(int argc, char **argv)
411 {
412 	struct psp_dev_get_list *dev_list;
413 	bool devid_found = false;
414 	__u32 ver_ena, ver_cap;
415 	struct opts opts = {};
416 	struct ynl_error yerr;
417 	struct ynl_sock *ys;
418 	int first_id = 0;
419 	int ret;
420 
421 	parse_cmd_opts(argc, argv, &opts);
422 	if (!opts.port)
423 		usage(argv[0], "port"); // exits
424 
425 	ys = ynl_sock_create(&ynl_psp_family, &yerr);
426 	if (!ys) {
427 		fprintf(stderr, "YNL: %s\n", yerr.msg);
428 		return 1;
429 	}
430 
431 	dev_list = psp_dev_get_dump(ys);
432 	if (ynl_dump_empty(dev_list)) {
433 		if (ys->err.code)
434 			goto err_close;
435 		fprintf(stderr, "No PSP devices\n");
436 		goto err_close_silent;
437 	}
438 
439 	ynl_dump_foreach(dev_list, d) {
440 		if (opts.devid) {
441 			devid_found = true;
442 			ver_ena = d->psp_versions_ena;
443 			ver_cap = d->psp_versions_cap;
444 		} else if (!first_id) {
445 			first_id = d->id;
446 			ver_ena = d->psp_versions_ena;
447 			ver_cap = d->psp_versions_cap;
448 		} else {
449 			fprintf(stderr, "Multiple PSP devices found\n");
450 			goto err_close_silent;
451 		}
452 	}
453 	psp_dev_get_list_free(dev_list);
454 
455 	if (opts.devid && !devid_found) {
456 		fprintf(stderr, "PSP device %d requested on cmdline, not found\n",
457 			opts.devid);
458 		goto err_close_silent;
459 	} else if (!opts.devid) {
460 		opts.devid = first_id;
461 	}
462 
463 	if (ver_ena != ver_cap) {
464 		ret = psp_dev_set_ena(ys, opts.devid, ver_cap);
465 		if (ret)
466 			goto err_close;
467 	}
468 
469 	ret = run_responder(ys, &opts);
470 
471 	if (ver_ena != ver_cap && psp_dev_set_ena(ys, opts.devid, ver_ena))
472 		fprintf(stderr, "WARN: failed to set the PSP versions back\n");
473 
474 	ynl_sock_destroy(ys);
475 
476 	return ret;
477 
478 err_close:
479 	fprintf(stderr, "YNL: %s\n", ys->err.msg);
480 err_close_silent:
481 	ynl_sock_destroy(ys);
482 	return 2;
483 }
484