xref: /linux/tools/testing/selftests/net/af_unix/scm_pidfd.c (revision 22c55fb9eb92395d999b8404d73e58540d11bdd8)
1 // SPDX-License-Identifier: GPL-2.0 OR MIT
2 #define _GNU_SOURCE
3 #include <error.h>
4 #include <limits.h>
5 #include <stddef.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <sys/socket.h>
9 #include <linux/socket.h>
10 #include <unistd.h>
11 #include <string.h>
12 #include <errno.h>
13 #include <sys/un.h>
14 #include <sys/signal.h>
15 #include <sys/types.h>
16 #include <sys/wait.h>
17 
18 #include "../../pidfd/pidfd.h"
19 #include "../../kselftest_harness.h"
20 
/* Text for the current errno: "None" when errno is 0, strerror() otherwise. */
#define clean_errno() (errno == 0 ? "None" : strerror(errno))
/*
 * Print a diagnostic line to stderr, prefixed with the source file, line
 * number and the current errno description. MSG is pasted into the format
 * string, so it has to be a string literal; extra arguments are forwarded
 * to fprintf().
 */
#define log_err(MSG, ...)                                                   \
	fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", __FILE__, __LINE__, \
		clean_errno(), ##__VA_ARGS__)
25 
/* Fallback definition for when the included headers do not provide SCM_PIDFD. */
#ifndef SCM_PIDFD
#define SCM_PIDFD 0x04
#endif
29 
/*
 * Exit status the forked client uses to report success. It is deliberately
 * non-zero: the parent compares it against both the waitpid() status and the
 * exit code read back through ioctl(PIDFD_GET_INFO).
 */
#define CHILD_EXIT_CODE_OK 123
31 
/*
 * Terminate the calling process with exit status 1.
 *
 * Used by the forked client on any failure; the status differs from
 * CHILD_EXIT_CODE_OK so the parent can tell failure from success.
 *
 * Declared with (void) so this is a real prototype: before C23 an empty
 * parameter list leaves the arguments unchecked.
 */
static void child_die(void)
{
	exit(1);
}
36 
/*
 * Convert the whole of @numstr to an int.
 *
 * The base is auto-detected by strtol() (decimal, 0x-prefixed hex, 0-prefixed
 * octal). On success the value is stored in @converted and 0 is returned.
 * Returns -EINVAL when no digits were consumed or trailing characters remain,
 * and -ERANGE when the value does not fit in a long or in an int. @converted
 * is left untouched on failure.
 */
static int safe_int(const char *numstr, int *converted)
{
	char *end = NULL;
	long value;

	errno = 0;
	value = strtol(numstr, &end, 0);

	if (errno == ERANGE) {
		if (value == LONG_MAX || value == LONG_MIN)
			return -ERANGE;
	}

	if (value == 0 && errno != 0)
		return -EINVAL;

	/* nothing parsed at all, or garbage after the number */
	if (end == numstr || *end != '\0')
		return -EINVAL;

	if (value < INT_MIN || value > INT_MAX)
		return -ERANGE;

	*converted = (int)value;
	return 0;
}
59 
/*
 * Return the offset of the first byte in @buffer[0..@len) that is neither a
 * space nor a tab. If every byte is blank (or @len is 0) the result is 0.
 */
static int char_left_gc(const char *buffer, size_t len)
{
	size_t pos = 0;

	while (pos < len && (buffer[pos] == ' ' || buffer[pos] == '\t'))
		pos++;

	/* ran off the end without finding a non-blank byte */
	if (pos >= len)
		return 0;

	return pos;
}
73 
/*
 * Return the length @buffer[0..@len) would have once trailing spaces, tabs,
 * newlines and NUL bytes are dropped, i.e. one past the index of the last
 * byte that is none of those. Returns 0 when no such byte exists.
 */
static int char_right_gc(const char *buffer, size_t len)
{
	int idx = len - 1;

	while (idx >= 0) {
		switch (buffer[idx]) {
		case ' ':
		case '\t':
		case '\n':
		case '\0':
			/* trailing filler, keep scanning leftwards */
			idx--;
			continue;
		default:
			return idx + 1;
		}
	}

	return 0;
}
88 
/*
 * Strip leading blanks and trailing whitespace from the NUL-terminated string
 * in @buffer. The string is modified in place (a terminator is written after
 * the last kept byte) and a pointer to the first kept byte, which lies inside
 * @buffer, is returned.
 */
static char *trim_whitespace_in_place(char *buffer)
{
	char *start = buffer + char_left_gc(buffer, strlen(buffer));
	int kept = char_right_gc(start, strlen(start));

	start[kept] = '\0';
	return start;
}
95 
/* borrowed (with all helpers) from pidfd/pidfd_open_test.c */
/*
 * Look up an integer field in /proc/self/fdinfo/<pidfd>.
 *
 * @pidfd:  file descriptor whose fdinfo file is read
 * @key:    field name to search for at the start of a line (e.g. "Pid:")
 * @keylen: number of bytes of @key to compare
 *
 * Returns the value parsed from the first line that starts with @key, or -1
 * if the file cannot be opened, the key is not present, or the value is not
 * a valid integer.
 */
static pid_t get_pid_from_fdinfo_file(int pidfd, const char *key, size_t keylen)
{
	char path[512];
	FILE *f;
	size_t n = 0;
	pid_t result = -1;
	char *line = NULL;

	snprintf(path, sizeof(path), "/proc/self/fdinfo/%d", pidfd);

	f = fopen(path, "re");
	if (!f)
		return -1;

	while (getline(&line, &n, f) != -1) {
		char *numstr;

		if (strncmp(line, key, keylen))
			continue;

		/*
		 * Skip exactly the matched key. This used to be a hard-coded
		 * offset of 4, which was only right for 4-character keys.
		 */
		numstr = trim_whitespace_in_place(line + keylen);

		/* safe_int() only stores on success; be explicit about failure */
		if (safe_int(numstr, &result) < 0)
			result = -1;

		break;
	}

	/* getline() may have allocated a buffer even if no line matched */
	free(line);
	fclose(f);
	return result;
}
131 
/*
 * Result of parse_cmsg(): pointers into the control buffer of the parsed
 * message. They point at the payload of the respective control message and
 * are only valid while that buffer is.
 */
struct cmsg_data {
	struct ucred *ucred;	/* payload of the SCM_CREDENTIALS message */
	int *pidfd;		/* payload of the SCM_PIDFD message */
};
136 
137 static int parse_cmsg(struct msghdr *msg, struct cmsg_data *res)
138 {
139 	struct cmsghdr *cmsg;
140 	int data = 0;
141 
142 	if (msg->msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
143 		log_err("recvmsg: truncated");
144 		return 1;
145 	}
146 
147 	for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
148 	     cmsg = CMSG_NXTHDR(msg, cmsg)) {
149 		if (cmsg->cmsg_level == SOL_SOCKET &&
150 		    cmsg->cmsg_type == SCM_PIDFD) {
151 			if (cmsg->cmsg_len < sizeof(*res->pidfd)) {
152 				log_err("CMSG parse: SCM_PIDFD wrong len");
153 				return 1;
154 			}
155 
156 			res->pidfd = (void *)CMSG_DATA(cmsg);
157 		}
158 
159 		if (cmsg->cmsg_level == SOL_SOCKET &&
160 		    cmsg->cmsg_type == SCM_CREDENTIALS) {
161 			if (cmsg->cmsg_len < sizeof(*res->ucred)) {
162 				log_err("CMSG parse: SCM_CREDENTIALS wrong len");
163 				return 1;
164 			}
165 
166 			res->ucred = (void *)CMSG_DATA(cmsg);
167 		}
168 	}
169 
170 	if (!res->pidfd) {
171 		log_err("CMSG parse: SCM_PIDFD not found");
172 		return 1;
173 	}
174 
175 	if (!res->ucred) {
176 		log_err("CMSG parse: SCM_CREDENTIALS not found");
177 		return 1;
178 	}
179 
180 	return 0;
181 }
182 
183 static int cmsg_check(int fd)
184 {
185 	struct msghdr msg = { 0 };
186 	struct cmsg_data res;
187 	struct iovec iov;
188 	int data = 0;
189 	char control[CMSG_SPACE(sizeof(struct ucred)) +
190 		     CMSG_SPACE(sizeof(int))] = { 0 };
191 	pid_t parent_pid;
192 	int err;
193 
194 	iov.iov_base = &data;
195 	iov.iov_len = sizeof(data);
196 
197 	msg.msg_iov = &iov;
198 	msg.msg_iovlen = 1;
199 	msg.msg_control = control;
200 	msg.msg_controllen = sizeof(control);
201 
202 	err = recvmsg(fd, &msg, 0);
203 	if (err < 0) {
204 		log_err("recvmsg");
205 		return 1;
206 	}
207 
208 	if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
209 		log_err("recvmsg: truncated");
210 		return 1;
211 	}
212 
213 	/* send(pfd, "x", sizeof(char), 0) */
214 	if (data != 'x') {
215 		log_err("recvmsg: data corruption");
216 		return 1;
217 	}
218 
219 	if (parse_cmsg(&msg, &res)) {
220 		log_err("CMSG parse: parse_cmsg() failed");
221 		return 1;
222 	}
223 
224 	/* pidfd from SCM_PIDFD should point to the parent process PID */
225 	parent_pid =
226 		get_pid_from_fdinfo_file(*res.pidfd, "Pid:", sizeof("Pid:") - 1);
227 	if (parent_pid != getppid()) {
228 		log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid());
229 		close(*res.pidfd);
230 		return 1;
231 	}
232 
233 	close(*res.pidfd);
234 	return 0;
235 }
236 
237 static int cmsg_check_dead(int fd, int expected_pid)
238 {
239 	int err;
240 	struct msghdr msg = { 0 };
241 	struct cmsg_data res;
242 	struct iovec iov;
243 	int data = 0;
244 	char control[CMSG_SPACE(sizeof(struct ucred)) +
245 		     CMSG_SPACE(sizeof(int))] = { 0 };
246 	pid_t client_pid;
247 	struct pidfd_info info = {
248 		.mask = PIDFD_INFO_EXIT,
249 	};
250 
251 	iov.iov_base = &data;
252 	iov.iov_len = sizeof(data);
253 
254 	msg.msg_iov = &iov;
255 	msg.msg_iovlen = 1;
256 	msg.msg_control = control;
257 	msg.msg_controllen = sizeof(control);
258 
259 	err = recvmsg(fd, &msg, 0);
260 	if (err < 0) {
261 		log_err("recvmsg");
262 		return 1;
263 	}
264 
265 	if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
266 		log_err("recvmsg: truncated");
267 		return 1;
268 	}
269 
270 	/* send(cfd, "y", sizeof(char), 0) */
271 	if (data != 'y') {
272 		log_err("recvmsg: data corruption");
273 		return 1;
274 	}
275 
276 	if (parse_cmsg(&msg, &res)) {
277 		log_err("CMSG parse: parse_cmsg() failed");
278 		return 1;
279 	}
280 
281 	/*
282 	 * pidfd from SCM_PIDFD should point to the client_pid.
283 	 * Let's read exit information and check if it's what
284 	 * we expect to see.
285 	 */
286 	if (ioctl(*res.pidfd, PIDFD_GET_INFO, &info)) {
287 		log_err("%s: ioctl(PIDFD_GET_INFO) failed", __func__);
288 		close(*res.pidfd);
289 		return 1;
290 	}
291 
292 	if (!(info.mask & PIDFD_INFO_EXIT)) {
293 		log_err("%s: No exit information from ioctl(PIDFD_GET_INFO)", __func__);
294 		close(*res.pidfd);
295 		return 1;
296 	}
297 
298 	err = WIFEXITED(info.exit_code) ? WEXITSTATUS(info.exit_code) : 1;
299 	if (err != CHILD_EXIT_CODE_OK) {
300 		log_err("%s: wrong exit_code %d != %d", __func__, err, CHILD_EXIT_CODE_OK);
301 		close(*res.pidfd);
302 		return 1;
303 	}
304 
305 	close(*res.pidfd);
306 	return 0;
307 }
308 
/* A unix socket address together with the name it was built from. */
struct sock_addr {
	char sock_name[32];		/* socket name as text, without any abstract-namespace prefix */
	struct sockaddr_un listen_addr;	/* address passed to bind()/connect()/sendto() */
	socklen_t addrlen;		/* number of bytes of listen_addr that are in use */
};
314 
/* Per-test state shared between setup, the test body and teardown. */
FIXTURE(scm_pidfd)
{
	int server;			/* server socket created by the parent */
	pid_t client_pid;		/* result of fork(): the client's pid in the parent */
	int startup_pipe[2];		/* parent blocks on the read end until the client closes the write end */
	struct sock_addr server_addr;	/* address the server socket is bound to */
	struct sock_addr *client_addr;	/* lives in a shared mapping: filled by the client, read by the parent */
};
323 
/* Parameters that distinguish the test variants. */
FIXTURE_VARIANT(scm_pidfd)
{
	int type;	/* socket type passed to socket(): SOCK_STREAM or SOCK_DGRAM */
	bool abstract;	/* true: abstract socket address; false: pathname address */
};
329 
/* Stream socket bound to a filesystem path. */
FIXTURE_VARIANT_ADD(scm_pidfd, stream_pathname)
{
	.type = SOCK_STREAM,
	.abstract = 0,
};

/* Stream socket bound to an abstract address. */
FIXTURE_VARIANT_ADD(scm_pidfd, stream_abstract)
{
	.type = SOCK_STREAM,
	.abstract = 1,
};

/* Datagram socket bound to a filesystem path. */
FIXTURE_VARIANT_ADD(scm_pidfd, dgram_pathname)
{
	.type = SOCK_DGRAM,
	.abstract = 0,
};

/* Datagram socket bound to an abstract address. */
FIXTURE_VARIANT_ADD(scm_pidfd, dgram_abstract)
{
	.type = SOCK_DGRAM,
	.abstract = 1,
};
353 
/*
 * Allocate client_addr in an anonymous MAP_SHARED mapping so that the address
 * the forked client writes into it is visible to the parent as well.
 */
FIXTURE_SETUP(scm_pidfd)
{
	self->client_addr = mmap(NULL, sizeof(*self->client_addr), PROT_READ | PROT_WRITE,
				 MAP_SHARED | MAP_ANONYMOUS, -1, 0);
	ASSERT_NE(MAP_FAILED, self->client_addr);
}
360 
361 FIXTURE_TEARDOWN(scm_pidfd)
362 {
363 	close(self->server);
364 
365 	kill(self->client_pid, SIGKILL);
366 	waitpid(self->client_pid, NULL, 0);
367 
368 	if (!variant->abstract) {
369 		unlink(self->server_addr.sock_name);
370 		unlink(self->client_addr->sock_name);
371 	}
372 }
373 
374 static void fill_sockaddr(struct sock_addr *addr, bool abstract)
375 {
376 	char *sun_path_buf = (char *)&addr->listen_addr.sun_path;
377 
378 	addr->listen_addr.sun_family = AF_UNIX;
379 	addr->addrlen = offsetof(struct sockaddr_un, sun_path);
380 	snprintf(addr->sock_name, sizeof(addr->sock_name), "scm_pidfd_%d", getpid());
381 	addr->addrlen += strlen(addr->sock_name);
382 	if (abstract) {
383 		*sun_path_buf = '\0';
384 		addr->addrlen++;
385 		sun_path_buf++;
386 	} else {
387 		unlink(addr->sock_name);
388 	}
389 	memcpy(sun_path_buf, addr->sock_name, strlen(addr->sock_name));
390 }
391 
392 static int sk_enable_cred_pass(int sk)
393 {
394 	int on = 0;
395 
396 	on = 1;
397 	if (setsockopt(sk, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) {
398 		log_err("Failed to set SO_PASSCRED");
399 		return 1;
400 	}
401 
402 	if (setsockopt(sk, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) {
403 		log_err("Failed to set SO_PASSPIDFD");
404 		return 1;
405 	}
406 
407 	return 0;
408 }
409 
/*
 * Body of the forked client process.
 *
 * Connects to the server, enables credential/pidfd passing, signals readiness
 * by closing its end of the startup pipe, validates the message received from
 * the parent and answers with a one-byte message of its own. For stream
 * sockets it additionally cross-checks SO_PEERCRED and SO_PEERPIDFD against
 * the parent's pid. Any failure terminates the process via child_die();
 * returning normally means every check passed.
 */
static void client(FIXTURE_DATA(scm_pidfd) *self,
		   const FIXTURE_VARIANT(scm_pidfd) *variant)
{
	int cfd;
	socklen_t len;
	struct ucred peer_cred;
	int peer_pidfd;
	pid_t peer_pid;

	cfd = socket(AF_UNIX, variant->type, 0);
	if (cfd < 0) {
		log_err("socket");
		child_die();
	}

	/*
	 * A datagram client binds to an address of its own; the parent uses
	 * that address (from the shared client_addr) as the sendto() target.
	 */
	if (variant->type == SOCK_DGRAM) {
		fill_sockaddr(self->client_addr, variant->abstract);

		if (bind(cfd, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen)) {
			log_err("bind");
			child_die();
		}
	}

	if (connect(cfd, (struct sockaddr *)&self->server_addr.listen_addr,
		    self->server_addr.addrlen) != 0) {
		log_err("connect");
		child_die();
	}

	if (sk_enable_cred_pass(cfd)) {
		log_err("sk_enable_cred_pass() failed");
		child_die();
	}

	/*
	 * Checkpoint: closing the write end makes the parent's read() on the
	 * startup pipe return, so it only sends once the options above are set.
	 */
	close(self->startup_pipe[1]);

	if (cmsg_check(cfd)) {
		log_err("cmsg_check failed");
		child_die();
	}

	/* send something to the parent so it can receive SCM_PIDFD too and validate it */
	if (send(cfd, "y", sizeof(char), 0) == -1) {
		log_err("Failed to send(cfd, \"y\", sizeof(char), 0)");
		child_die();
	}

	/* skip further for SOCK_DGRAM as it's not applicable */
	if (variant->type == SOCK_DGRAM)
		return;

	len = sizeof(peer_cred);
	if (getsockopt(cfd, SOL_SOCKET, SO_PEERCRED, &peer_cred, &len)) {
		log_err("Failed to get SO_PEERCRED");
		child_die();
	}

	len = sizeof(peer_pidfd);
	if (getsockopt(cfd, SOL_SOCKET, SO_PEERPIDFD, &peer_pidfd, &len)) {
		log_err("Failed to get SO_PEERPIDFD");
		child_die();
	}

	/* pid from SO_PEERCRED should point to the parent process PID */
	if (peer_cred.pid != getppid()) {
		log_err("peer_cred.pid != getppid(): %d != %d", peer_cred.pid, getppid());
		child_die();
	}

	/* the pidfd from SO_PEERPIDFD must refer to the same process */
	peer_pid = get_pid_from_fdinfo_file(peer_pidfd,
					    "Pid:", sizeof("Pid:") - 1);
	if (peer_pid != peer_cred.pid) {
		log_err("peer_pid != peer_cred.pid: %d != %d", peer_pid, peer_cred.pid);
		child_die();
	}
}
487 
488 TEST_F(scm_pidfd, test)
489 {
490 	int err;
491 	int pfd;
492 	int child_status = 0;
493 
494 	self->server = socket(AF_UNIX, variant->type, 0);
495 	ASSERT_NE(-1, self->server);
496 
497 	fill_sockaddr(&self->server_addr, variant->abstract);
498 
499 	err = bind(self->server, (struct sockaddr *)&self->server_addr.listen_addr, self->server_addr.addrlen);
500 	ASSERT_EQ(0, err);
501 
502 	if (variant->type == SOCK_STREAM) {
503 		err = listen(self->server, 1);
504 		ASSERT_EQ(0, err);
505 	}
506 
507 	err = pipe(self->startup_pipe);
508 	ASSERT_NE(-1, err);
509 
510 	self->client_pid = fork();
511 	ASSERT_NE(-1, self->client_pid);
512 	if (self->client_pid == 0) {
513 		close(self->server);
514 		close(self->startup_pipe[0]);
515 		client(self, variant);
516 
517 		/*
518 		 * It's a bit unusual, but in case of success we return non-zero
519 		 * exit code (CHILD_EXIT_CODE_OK) and then we expect to read it
520 		 * from ioctl(PIDFD_GET_INFO) in cmsg_check_dead().
521 		 */
522 		exit(CHILD_EXIT_CODE_OK);
523 	}
524 	close(self->startup_pipe[1]);
525 
526 	if (variant->type == SOCK_STREAM) {
527 		pfd = accept(self->server, NULL, NULL);
528 		ASSERT_NE(-1, pfd);
529 	} else {
530 		pfd = self->server;
531 	}
532 
533 	/* wait until the child arrives at checkpoint */
534 	read(self->startup_pipe[0], &err, sizeof(int));
535 	close(self->startup_pipe[0]);
536 
537 	if (variant->type == SOCK_DGRAM) {
538 		err = sendto(pfd, "x", sizeof(char), 0, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen);
539 		ASSERT_NE(-1, err);
540 	} else {
541 		err = send(pfd, "x", sizeof(char), 0);
542 		ASSERT_NE(-1, err);
543 	}
544 
545 	waitpid(self->client_pid, &child_status, 0);
546 	/* see comment before exit(CHILD_EXIT_CODE_OK) */
547 	ASSERT_EQ(CHILD_EXIT_CODE_OK, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1);
548 
549 	err = sk_enable_cred_pass(pfd);
550 	ASSERT_EQ(0, err);
551 
552 	err = cmsg_check_dead(pfd, self->client_pid);
553 	ASSERT_EQ(0, err);
554 
555 	close(pfd);
556 }
557 
558 TEST_HARNESS_MAIN
559