xref: /linux/tools/testing/selftests/net/af_unix/scm_pidfd.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
1 // SPDX-License-Identifier: GPL-2.0 OR MIT
2 #define _GNU_SOURCE
3 #include <error.h>
4 #include <limits.h>
5 #include <stddef.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <sys/socket.h>
9 #include <linux/socket.h>
10 #include <unistd.h>
11 #include <string.h>
12 #include <errno.h>
13 #include <sys/un.h>
14 #include <sys/signal.h>
15 #include <sys/types.h>
16 #include <sys/wait.h>
17 
18 #include "../../pidfd/pidfd.h"
19 #include "../../kselftest_harness.h"
20 
/* errno as a string for log_err(): "None" when errno is 0. */
#define clean_errno() (errno == 0 ? "None" : strerror(errno))
/*
 * log_err() - print a printf-style message to stderr, prefixed with the
 * source file, line number and current errno string. MSG must be a string
 * literal because it is pasted between other literals.
 */
#define log_err(MSG, ...)                                                   \
	fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", __FILE__, __LINE__, \
		clean_errno(), ##__VA_ARGS__)

/* Fallback for system headers that do not define SCM_PIDFD yet. */
#ifndef SCM_PIDFD
#define SCM_PIDFD 0x04
#endif

/*
 * Exit status the client child uses to report success. The parent checks it
 * via waitpid() and again via ioctl(PIDFD_GET_INFO) in cmsg_check_dead().
 */
#define CHILD_EXIT_CODE_OK 123
31 
/*
 * child_die() - terminate the forked client with a failure status.
 *
 * Any exit status other than CHILD_EXIT_CODE_OK makes the parent's
 * waitpid() check fail, so a generic failure code is sufficient.
 * Declared with (void): an empty parameter list is not a prototype before C23.
 */
static void child_die(void)
{
	exit(EXIT_FAILURE);
}
36 
/*
 * safe_int() - strictly convert a whole string to an int.
 * @numstr:    text to convert; base is auto-detected by strtol() (0x.., 0..)
 * @converted: receives the value, written only on success
 *
 * Return: 0 on success, -ERANGE if the value does not fit in a long or an
 * int, -EINVAL for empty input, trailing characters or other strtol() errors.
 */
static int safe_int(const char *numstr, int *converted)
{
	char *tail = NULL;
	long parsed;
	int rc = 0;

	errno = 0;
	parsed = strtol(numstr, &tail, 0);

	if (errno == ERANGE && (parsed == LONG_MIN || parsed == LONG_MAX))
		rc = -ERANGE;		/* saturated: overflowed long */
	else if (errno != 0 && parsed == 0)
		rc = -EINVAL;		/* other failure reported by strtol() */
	else if (tail == numstr || *tail != '\0')
		rc = -EINVAL;		/* nothing parsed, or trailing junk */
	else if (parsed > INT_MAX || parsed < INT_MIN)
		rc = -ERANGE;		/* fits in long but not in int */
	else
		*converted = (int)parsed;

	return rc;
}
59 
/*
 * char_left_gc() - count the leading blanks (spaces and tabs) in @buffer.
 * @buffer: characters to scan; need not be NUL-terminated
 * @len:    number of characters of @buffer to examine
 *
 * Return: index of the first character that is neither ' ' nor '\t', or
 * @len when all examined characters are blanks. The previous version
 * returned 0 in that case, which is not the number of leading blanks.
 */
static int char_left_gc(const char *buffer, size_t len)
{
	size_t i = 0;

	while (i < len && (buffer[i] == ' ' || buffer[i] == '\t'))
		i++;

	return i;
}
73 
/*
 * char_right_gc() - find where trailing whitespace starts in @buffer.
 * @buffer: characters to scan
 * @len:    number of characters of @buffer to examine
 *
 * Trailing ' ', '\t', '\n' and '\0' characters are skipped.
 *
 * Return: length of @buffer with that trailing run removed (0 if every
 * examined character is whitespace or @len is 0). Counts down with an
 * unsigned index instead of converting len - 1 to int, which relied on
 * implementation-defined behaviour for len == 0 and truncated large lengths.
 */
static int char_right_gc(const char *buffer, size_t len)
{
	size_t end = len;

	while (end > 0) {
		char c = buffer[end - 1];

		if (c != ' ' && c != '\t' && c != '\n' && c != '\0')
			break;
		end--;
	}

	return end;
}
88 
/*
 * trim_whitespace_in_place() - strip surrounding whitespace from a string.
 * @buffer: NUL-terminated, writable string
 *
 * Terminates the string after its last non-whitespace character (writing a
 * '\0' into @buffer) and returns a pointer past the leading blanks. No memory
 * is moved or allocated.
 */
static char *trim_whitespace_in_place(char *buffer)
{
	char *start = buffer + char_left_gc(buffer, strlen(buffer));
	size_t end = char_right_gc(start, strlen(start));

	start[end] = '\0';
	return start;
}
95 
96 /* borrowed (with all helpers) from pidfd/pidfd_open_test.c */
97 static pid_t get_pid_from_fdinfo_file(int pidfd, const char *key, size_t keylen)
98 {
99 	int ret;
100 	char path[512];
101 	FILE *f;
102 	size_t n = 0;
103 	pid_t result = -1;
104 	char *line = NULL;
105 
106 	snprintf(path, sizeof(path), "/proc/self/fdinfo/%d", pidfd);
107 
108 	f = fopen(path, "re");
109 	if (!f)
110 		return -1;
111 
112 	while (getline(&line, &n, f) != -1) {
113 		char *numstr;
114 
115 		if (strncmp(line, key, keylen))
116 			continue;
117 
118 		numstr = trim_whitespace_in_place(line + 4);
119 		ret = safe_int(numstr, &result);
120 		if (ret < 0)
121 			goto out;
122 
123 		break;
124 	}
125 
126 out:
127 	free(line);
128 	fclose(f);
129 	return result;
130 }
131 
/*
 * Pointers into a received control buffer, set by parse_cmsg() to the
 * CMSG_DATA() payloads of the SCM_CREDENTIALS and SCM_PIDFD messages.
 * They are only valid while that control buffer is alive.
 */
struct cmsg_data {
	struct ucred *ucred;	/* SCM_CREDENTIALS payload */
	int *pidfd;		/* SCM_PIDFD payload: pidfd received with the message */
};
136 
137 static int parse_cmsg(struct msghdr *msg, struct cmsg_data *res)
138 {
139 	struct cmsghdr *cmsg;
140 
141 	if (msg->msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
142 		log_err("recvmsg: truncated");
143 		return 1;
144 	}
145 
146 	for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
147 	     cmsg = CMSG_NXTHDR(msg, cmsg)) {
148 		if (cmsg->cmsg_level == SOL_SOCKET &&
149 		    cmsg->cmsg_type == SCM_PIDFD) {
150 			if (cmsg->cmsg_len < sizeof(*res->pidfd)) {
151 				log_err("CMSG parse: SCM_PIDFD wrong len");
152 				return 1;
153 			}
154 
155 			res->pidfd = (void *)CMSG_DATA(cmsg);
156 		}
157 
158 		if (cmsg->cmsg_level == SOL_SOCKET &&
159 		    cmsg->cmsg_type == SCM_CREDENTIALS) {
160 			if (cmsg->cmsg_len < sizeof(*res->ucred)) {
161 				log_err("CMSG parse: SCM_CREDENTIALS wrong len");
162 				return 1;
163 			}
164 
165 			res->ucred = (void *)CMSG_DATA(cmsg);
166 		}
167 	}
168 
169 	if (!res->pidfd) {
170 		log_err("CMSG parse: SCM_PIDFD not found");
171 		return 1;
172 	}
173 
174 	if (!res->ucred) {
175 		log_err("CMSG parse: SCM_CREDENTIALS not found");
176 		return 1;
177 	}
178 
179 	return 0;
180 }
181 
182 static int cmsg_check(int fd)
183 {
184 	struct msghdr msg = { 0 };
185 	struct cmsg_data res;
186 	struct iovec iov;
187 	int data = 0;
188 	char control[CMSG_SPACE(sizeof(struct ucred)) +
189 		     CMSG_SPACE(sizeof(int))] = { 0 };
190 	pid_t parent_pid;
191 	int err;
192 
193 	iov.iov_base = &data;
194 	iov.iov_len = sizeof(data);
195 
196 	msg.msg_iov = &iov;
197 	msg.msg_iovlen = 1;
198 	msg.msg_control = control;
199 	msg.msg_controllen = sizeof(control);
200 
201 	err = recvmsg(fd, &msg, 0);
202 	if (err < 0) {
203 		log_err("recvmsg");
204 		return 1;
205 	}
206 
207 	if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
208 		log_err("recvmsg: truncated");
209 		return 1;
210 	}
211 
212 	/* send(pfd, "x", sizeof(char), 0) */
213 	if (data != 'x') {
214 		log_err("recvmsg: data corruption");
215 		return 1;
216 	}
217 
218 	if (parse_cmsg(&msg, &res)) {
219 		log_err("CMSG parse: parse_cmsg() failed");
220 		return 1;
221 	}
222 
223 	/* pidfd from SCM_PIDFD should point to the parent process PID */
224 	parent_pid =
225 		get_pid_from_fdinfo_file(*res.pidfd, "Pid:", sizeof("Pid:") - 1);
226 	if (parent_pid != getppid()) {
227 		log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid());
228 		close(*res.pidfd);
229 		return 1;
230 	}
231 
232 	close(*res.pidfd);
233 	return 0;
234 }
235 
236 static int cmsg_check_dead(int fd, int expected_pid)
237 {
238 	int err;
239 	struct msghdr msg = { 0 };
240 	struct cmsg_data res;
241 	struct iovec iov;
242 	int data = 0;
243 	char control[CMSG_SPACE(sizeof(struct ucred)) +
244 		     CMSG_SPACE(sizeof(int))] = { 0 };
245 	struct pidfd_info info = {
246 		.mask = PIDFD_INFO_EXIT,
247 	};
248 
249 	iov.iov_base = &data;
250 	iov.iov_len = sizeof(data);
251 
252 	msg.msg_iov = &iov;
253 	msg.msg_iovlen = 1;
254 	msg.msg_control = control;
255 	msg.msg_controllen = sizeof(control);
256 
257 	err = recvmsg(fd, &msg, 0);
258 	if (err < 0) {
259 		log_err("recvmsg");
260 		return 1;
261 	}
262 
263 	if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
264 		log_err("recvmsg: truncated");
265 		return 1;
266 	}
267 
268 	/* send(cfd, "y", sizeof(char), 0) */
269 	if (data != 'y') {
270 		log_err("recvmsg: data corruption");
271 		return 1;
272 	}
273 
274 	if (parse_cmsg(&msg, &res)) {
275 		log_err("CMSG parse: parse_cmsg() failed");
276 		return 1;
277 	}
278 
279 	/*
280 	 * pidfd from SCM_PIDFD should point to the client_pid.
281 	 * Let's read exit information and check if it's what
282 	 * we expect to see.
283 	 */
284 	if (ioctl(*res.pidfd, PIDFD_GET_INFO, &info)) {
285 		log_err("%s: ioctl(PIDFD_GET_INFO) failed", __func__);
286 		close(*res.pidfd);
287 		return 1;
288 	}
289 
290 	if (!(info.mask & PIDFD_INFO_EXIT)) {
291 		log_err("%s: No exit information from ioctl(PIDFD_GET_INFO)", __func__);
292 		close(*res.pidfd);
293 		return 1;
294 	}
295 
296 	err = WIFEXITED(info.exit_code) ? WEXITSTATUS(info.exit_code) : 1;
297 	if (err != CHILD_EXIT_CODE_OK) {
298 		log_err("%s: wrong exit_code %d != %d", __func__, err, CHILD_EXIT_CODE_OK);
299 		close(*res.pidfd);
300 		return 1;
301 	}
302 
303 	close(*res.pidfd);
304 	return 0;
305 }
306 
/* A unix socket address together with the name it was built from (see fill_sockaddr()). */
struct sock_addr {
	char sock_name[32];		/* "scm_pidfd_<pid>"; also the filesystem path for pathname sockets */
	struct sockaddr_un listen_addr;	/* address passed to bind()/connect()/sendto() */
	socklen_t addrlen;		/* number of bytes of listen_addr in use */
};
312 
/* Per-test state shared by setup, the test body and teardown. */
FIXTURE(scm_pidfd)
{
	int server;			/* listening (SOCK_STREAM) or bound (SOCK_DGRAM) server socket */
	pid_t client_pid;		/* PID of the forked client */
	int startup_pipe[2];		/* the client closes its write end once it is ready */
	struct sock_addr server_addr;
	struct sock_addr *client_addr;	/* MAP_SHARED so the parent sees the address the client binds */
};
321 
/* Parameters that distinguish the test variants registered below. */
FIXTURE_VARIANT(scm_pidfd)
{
	int type;	/* socket type for socket(): SOCK_STREAM or SOCK_DGRAM */
	bool abstract;	/* non-zero: abstract-namespace address; zero: filesystem path */
};

/* Stream sockets bound to a filesystem path. */
FIXTURE_VARIANT_ADD(scm_pidfd, stream_pathname)
{
	.type = SOCK_STREAM,
	.abstract = 0,
};

/* Stream sockets bound to an abstract-namespace address. */
FIXTURE_VARIANT_ADD(scm_pidfd, stream_abstract)
{
	.type = SOCK_STREAM,
	.abstract = 1,
};

/* Datagram sockets bound to a filesystem path. */
FIXTURE_VARIANT_ADD(scm_pidfd, dgram_pathname)
{
	.type = SOCK_DGRAM,
	.abstract = 0,
};

/* Datagram sockets bound to an abstract-namespace address. */
FIXTURE_VARIANT_ADD(scm_pidfd, dgram_abstract)
{
	.type = SOCK_DGRAM,
	.abstract = 1,
};
351 
/*
 * client_addr lives in a MAP_SHARED anonymous mapping. This lets the address
 * that the forked client fills in (SOCK_DGRAM variants) be read by the parent,
 * which uses it as the sendto() destination.
 */
FIXTURE_SETUP(scm_pidfd)
{
	self->client_addr = mmap(NULL, sizeof(*self->client_addr), PROT_READ | PROT_WRITE,
				 MAP_SHARED | MAP_ANONYMOUS, -1, 0);
	ASSERT_NE(MAP_FAILED, self->client_addr);
}
358 
359 FIXTURE_TEARDOWN(scm_pidfd)
360 {
361 	close(self->server);
362 
363 	kill(self->client_pid, SIGKILL);
364 	waitpid(self->client_pid, NULL, 0);
365 
366 	if (!variant->abstract) {
367 		unlink(self->server_addr.sock_name);
368 		unlink(self->client_addr->sock_name);
369 	}
370 }
371 
372 static void fill_sockaddr(struct sock_addr *addr, bool abstract)
373 {
374 	char *sun_path_buf = (char *)&addr->listen_addr.sun_path;
375 
376 	addr->listen_addr.sun_family = AF_UNIX;
377 	addr->addrlen = offsetof(struct sockaddr_un, sun_path);
378 	snprintf(addr->sock_name, sizeof(addr->sock_name), "scm_pidfd_%d", getpid());
379 	addr->addrlen += strlen(addr->sock_name);
380 	if (abstract) {
381 		*sun_path_buf = '\0';
382 		addr->addrlen++;
383 		sun_path_buf++;
384 	} else {
385 		unlink(addr->sock_name);
386 	}
387 	memcpy(sun_path_buf, addr->sock_name, strlen(addr->sock_name));
388 }
389 
390 static int sk_enable_cred_pass(int sk)
391 {
392 	int on = 0;
393 
394 	on = 1;
395 	if (setsockopt(sk, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) {
396 		log_err("Failed to set SO_PASSCRED");
397 		return 1;
398 	}
399 
400 	if (setsockopt(sk, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) {
401 		log_err("Failed to set SO_PASSPIDFD");
402 		return 1;
403 	}
404 
405 	return 0;
406 }
407 
408 static void client(FIXTURE_DATA(scm_pidfd) *self,
409 		   const FIXTURE_VARIANT(scm_pidfd) *variant)
410 {
411 	int cfd;
412 	socklen_t len;
413 	struct ucred peer_cred;
414 	int peer_pidfd;
415 	pid_t peer_pid;
416 
417 	cfd = socket(AF_UNIX, variant->type, 0);
418 	if (cfd < 0) {
419 		log_err("socket");
420 		child_die();
421 	}
422 
423 	if (variant->type == SOCK_DGRAM) {
424 		fill_sockaddr(self->client_addr, variant->abstract);
425 
426 		if (bind(cfd, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen)) {
427 			log_err("bind");
428 			child_die();
429 		}
430 	}
431 
432 	if (connect(cfd, (struct sockaddr *)&self->server_addr.listen_addr,
433 		    self->server_addr.addrlen) != 0) {
434 		log_err("connect");
435 		child_die();
436 	}
437 
438 	if (sk_enable_cred_pass(cfd)) {
439 		log_err("sk_enable_cred_pass() failed");
440 		child_die();
441 	}
442 
443 	close(self->startup_pipe[1]);
444 
445 	if (cmsg_check(cfd)) {
446 		log_err("cmsg_check failed");
447 		child_die();
448 	}
449 
450 	/* send something to the parent so it can receive SCM_PIDFD too and validate it */
451 	if (send(cfd, "y", sizeof(char), 0) == -1) {
452 		log_err("Failed to send(cfd, \"y\", sizeof(char), 0)");
453 		child_die();
454 	}
455 
456 	/* skip further for SOCK_DGRAM as it's not applicable */
457 	if (variant->type == SOCK_DGRAM)
458 		return;
459 
460 	len = sizeof(peer_cred);
461 	if (getsockopt(cfd, SOL_SOCKET, SO_PEERCRED, &peer_cred, &len)) {
462 		log_err("Failed to get SO_PEERCRED");
463 		child_die();
464 	}
465 
466 	len = sizeof(peer_pidfd);
467 	if (getsockopt(cfd, SOL_SOCKET, SO_PEERPIDFD, &peer_pidfd, &len)) {
468 		log_err("Failed to get SO_PEERPIDFD");
469 		child_die();
470 	}
471 
472 	/* pid from SO_PEERCRED should point to the parent process PID */
473 	if (peer_cred.pid != getppid()) {
474 		log_err("peer_cred.pid != getppid(): %d != %d", peer_cred.pid, getppid());
475 		child_die();
476 	}
477 
478 	peer_pid = get_pid_from_fdinfo_file(peer_pidfd,
479 					    "Pid:", sizeof("Pid:") - 1);
480 	if (peer_pid != peer_cred.pid) {
481 		log_err("peer_pid != peer_cred.pid: %d != %d", peer_pid, peer_cred.pid);
482 		child_die();
483 	}
484 }
485 
486 TEST_F(scm_pidfd, test)
487 {
488 	int err;
489 	int pfd;
490 	int child_status = 0;
491 
492 	self->server = socket(AF_UNIX, variant->type, 0);
493 	ASSERT_NE(-1, self->server);
494 
495 	fill_sockaddr(&self->server_addr, variant->abstract);
496 
497 	err = bind(self->server, (struct sockaddr *)&self->server_addr.listen_addr, self->server_addr.addrlen);
498 	ASSERT_EQ(0, err);
499 
500 	if (variant->type == SOCK_STREAM) {
501 		err = listen(self->server, 1);
502 		ASSERT_EQ(0, err);
503 	}
504 
505 	err = pipe(self->startup_pipe);
506 	ASSERT_NE(-1, err);
507 
508 	self->client_pid = fork();
509 	ASSERT_NE(-1, self->client_pid);
510 	if (self->client_pid == 0) {
511 		close(self->server);
512 		close(self->startup_pipe[0]);
513 		client(self, variant);
514 
515 		/*
516 		 * It's a bit unusual, but in case of success we return non-zero
517 		 * exit code (CHILD_EXIT_CODE_OK) and then we expect to read it
518 		 * from ioctl(PIDFD_GET_INFO) in cmsg_check_dead().
519 		 */
520 		exit(CHILD_EXIT_CODE_OK);
521 	}
522 	close(self->startup_pipe[1]);
523 
524 	if (variant->type == SOCK_STREAM) {
525 		pfd = accept(self->server, NULL, NULL);
526 		ASSERT_NE(-1, pfd);
527 	} else {
528 		pfd = self->server;
529 	}
530 
531 	/* wait until the child arrives at checkpoint */
532 	read(self->startup_pipe[0], &err, sizeof(int));
533 	close(self->startup_pipe[0]);
534 
535 	if (variant->type == SOCK_DGRAM) {
536 		err = sendto(pfd, "x", sizeof(char), 0, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen);
537 		ASSERT_NE(-1, err);
538 	} else {
539 		err = send(pfd, "x", sizeof(char), 0);
540 		ASSERT_NE(-1, err);
541 	}
542 
543 	waitpid(self->client_pid, &child_status, 0);
544 	/* see comment before exit(CHILD_EXIT_CODE_OK) */
545 	ASSERT_EQ(CHILD_EXIT_CODE_OK, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1);
546 
547 	err = sk_enable_cred_pass(pfd);
548 	ASSERT_EQ(0, err);
549 
550 	err = cmsg_check_dead(pfd, self->client_pid);
551 	ASSERT_EQ(0, err);
552 
553 	close(pfd);
554 }
555 
/* Test program entry point provided by kselftest_harness.h. */
TEST_HARNESS_MAIN
557