1 // SPDX-License-Identifier: GPL-2.0 OR MIT
2 #define _GNU_SOURCE
3 #include <error.h>
4 #include <limits.h>
5 #include <stddef.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <sys/socket.h>
9 #include <linux/socket.h>
10 #include <unistd.h>
11 #include <string.h>
12 #include <errno.h>
13 #include <sys/un.h>
14 #include <sys/signal.h>
15 #include <sys/types.h>
16 #include <sys/wait.h>
17
18 #include "../../pidfd/pidfd.h"
19 #include "kselftest_harness.h"
20
/* Description of the current errno, or "None" when errno is 0. */
#define clean_errno() (errno ? strerror(errno) : "None")

/*
 * Print MSG (a printf-style string literal) plus optional arguments to
 * stderr, prefixed with the source location and the current errno text.
 */
#define log_err(MSG, ...)                                               \
	fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n",                 \
		__FILE__, __LINE__, clean_errno(), ##__VA_ARGS__)

/* Fallback definition for headers that do not provide SCM_PIDFD. */
#ifndef SCM_PIDFD
#define SCM_PIDFD 0x04
#endif

/* Exit status the client uses to report success (read back via PIDFD_GET_INFO). */
#define CHILD_EXIT_CODE_OK 123
31
/*
 * Terminate the forked client process with exit status 1, which the parent
 * treats as failure (it expects CHILD_EXIT_CODE_OK).
 *
 * Declared with an explicit (void) parameter list: an empty () in C is an
 * old-style, unprototyped declaration that accepts any arguments.
 */
static void child_die(void)
{
	exit(1);
}
36
/*
 * Parse numstr as an int, accepting any base strtol() recognises with base 0
 * (decimal, 0x-prefixed hex, 0-prefixed octal).
 *
 * The whole string must be consumed. *converted is written only on success.
 *
 * Return: 0 on success, -ERANGE if the value does not fit in a long or an int,
 * -EINVAL if the string is empty, has trailing characters, or strtol()
 * reported an error with a zero result.
 */
static int safe_int(const char *numstr, int *converted)
{
	char *endp = NULL;
	long value;
	int rc = 0;

	errno = 0;
	value = strtol(numstr, &endp, 0);

	if (errno == ERANGE && (value == LONG_MIN || value == LONG_MAX))
		rc = -ERANGE;		/* overflowed long */
	else if (errno != 0 && value == 0)
		rc = -EINVAL;		/* strtol() reported a conversion error */
	else if (endp == numstr || *endp != '\0')
		rc = -EINVAL;		/* no digits, or trailing garbage */
	else if (value < INT_MIN || value > INT_MAX)
		rc = -ERANGE;		/* fits in long but not in int */
	else
		*converted = (int)value;

	return rc;
}
59
/*
 * Count the leading spaces and tabs in buffer[0..len).
 *
 * Previously an all-blank buffer reported 0 leading blanks. It now reports len,
 * so the result is always the offset of the first non-blank character, or len
 * if there is none.
 *
 * Return: number of leading ' ' / '\t' characters.
 */
static int char_left_gc(const char *buffer, size_t len)
{
	size_t i = 0;

	while (i < len && (buffer[i] == ' ' || buffer[i] == '\t'))
		i++;

	return (int)i;
}
73
/*
 * Find where trailing blanks start in buffer[0..len).
 *
 * ' ', '\t', '\n' and '\0' count as trailing blanks. Counting down with a
 * size_t avoids converting len - 1 to int, which is implementation-defined
 * for len == 0 and wrong for len > INT_MAX.
 *
 * Return: length of buffer with trailing blanks removed (0 if all blank).
 */
static int char_right_gc(const char *buffer, size_t len)
{
	size_t end = len;

	while (end > 0) {
		char c = buffer[end - 1];

		if (c != ' ' && c != '\t' && c != '\n' && c != '\0')
			break;
		end--;
	}

	return (int)end;
}
88
/*
 * Trim blanks from both ends of the NUL-terminated string in buffer.
 *
 * Leading blanks are skipped by returning a pointer past them. Trailing blanks
 * are removed by writing a NUL terminator into buffer. What counts as a blank
 * on each side is decided by char_left_gc() / char_right_gc().
 *
 * Return: pointer into buffer at the first kept character.
 */
static char *trim_whitespace_in_place(char *buffer)
{
	char *start;
	int end;

	start = buffer + char_left_gc(buffer, strlen(buffer));

	end = char_right_gc(start, strlen(start));
	start[end] = '\0';

	return start;
}
95
96 /* borrowed (with all helpers) from pidfd/pidfd_open_test.c */
/*
 * Look up an integer field in /proc/self/fdinfo/<pidfd>.
 * @pidfd:  file descriptor whose fdinfo is read
 * @key:    line prefix to search for, e.g. "Pid:"
 * @keylen: number of characters of @key to match
 *
 * The value is the text following the first @keylen characters of the first
 * matching line, with surrounding blanks trimmed.
 *
 * Return: the parsed value, or -1 if the file cannot be opened, no line
 * starts with @key, or the value is not a valid int.
 */
static pid_t get_pid_from_fdinfo_file(int pidfd, const char *key, size_t keylen)
{
	char path[512];
	char *line = NULL;
	size_t n = 0;
	pid_t result = -1;
	FILE *f;

	snprintf(path, sizeof(path), "/proc/self/fdinfo/%d", pidfd);

	f = fopen(path, "re");
	if (!f)
		return -1;

	while (getline(&line, &n, f) != -1) {
		int value;

		if (strncmp(line, key, keylen))
			continue;

		/* Skip exactly the matched key rather than a hard-coded length. */
		if (safe_int(trim_whitespace_in_place(line + keylen), &value) == 0)
			result = value;
		break;
	}

	free(line);
	fclose(f);
	return result;
}
131
/*
 * Control-message payloads located by parse_cmsg(). Both pointers point into
 * the control buffer of the message that was parsed (they are set from
 * CMSG_DATA()), so they are only valid while that buffer is.
 */
struct cmsg_data {
	struct ucred *ucred;	/* SCM_CREDENTIALS payload */
	int *pidfd;		/* SCM_PIDFD payload: the received pidfd */
};
136
parse_cmsg(struct msghdr * msg,struct cmsg_data * res)137 static int parse_cmsg(struct msghdr *msg, struct cmsg_data *res)
138 {
139 struct cmsghdr *cmsg;
140
141 if (msg->msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
142 log_err("recvmsg: truncated");
143 return 1;
144 }
145
146 for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
147 cmsg = CMSG_NXTHDR(msg, cmsg)) {
148 if (cmsg->cmsg_level == SOL_SOCKET &&
149 cmsg->cmsg_type == SCM_PIDFD) {
150 if (cmsg->cmsg_len < sizeof(*res->pidfd)) {
151 log_err("CMSG parse: SCM_PIDFD wrong len");
152 return 1;
153 }
154
155 res->pidfd = (void *)CMSG_DATA(cmsg);
156 }
157
158 if (cmsg->cmsg_level == SOL_SOCKET &&
159 cmsg->cmsg_type == SCM_CREDENTIALS) {
160 if (cmsg->cmsg_len < sizeof(*res->ucred)) {
161 log_err("CMSG parse: SCM_CREDENTIALS wrong len");
162 return 1;
163 }
164
165 res->ucred = (void *)CMSG_DATA(cmsg);
166 }
167 }
168
169 if (!res->pidfd) {
170 log_err("CMSG parse: SCM_PIDFD not found");
171 return 1;
172 }
173
174 if (!res->ucred) {
175 log_err("CMSG parse: SCM_CREDENTIALS not found");
176 return 1;
177 }
178
179 return 0;
180 }
181
cmsg_check(int fd)182 static int cmsg_check(int fd)
183 {
184 struct msghdr msg = { 0 };
185 struct cmsg_data res;
186 struct iovec iov;
187 int data = 0;
188 char control[CMSG_SPACE(sizeof(struct ucred)) +
189 CMSG_SPACE(sizeof(int))] = { 0 };
190 pid_t parent_pid;
191 int err;
192
193 iov.iov_base = &data;
194 iov.iov_len = sizeof(data);
195
196 msg.msg_iov = &iov;
197 msg.msg_iovlen = 1;
198 msg.msg_control = control;
199 msg.msg_controllen = sizeof(control);
200
201 err = recvmsg(fd, &msg, 0);
202 if (err < 0) {
203 log_err("recvmsg");
204 return 1;
205 }
206
207 if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
208 log_err("recvmsg: truncated");
209 return 1;
210 }
211
212 /* send(pfd, "x", sizeof(char), 0) */
213 if (data != 'x') {
214 log_err("recvmsg: data corruption");
215 return 1;
216 }
217
218 if (parse_cmsg(&msg, &res)) {
219 log_err("CMSG parse: parse_cmsg() failed");
220 return 1;
221 }
222
223 /* pidfd from SCM_PIDFD should point to the parent process PID */
224 parent_pid =
225 get_pid_from_fdinfo_file(*res.pidfd, "Pid:", sizeof("Pid:") - 1);
226 if (parent_pid != getppid()) {
227 log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid());
228 close(*res.pidfd);
229 return 1;
230 }
231
232 close(*res.pidfd);
233 return 0;
234 }
235
cmsg_check_dead(int fd,int expected_pid)236 static int cmsg_check_dead(int fd, int expected_pid)
237 {
238 int err;
239 struct msghdr msg = { 0 };
240 struct cmsg_data res;
241 struct iovec iov;
242 int data = 0;
243 char control[CMSG_SPACE(sizeof(struct ucred)) +
244 CMSG_SPACE(sizeof(int))] = { 0 };
245 struct pidfd_info info = {
246 .mask = PIDFD_INFO_EXIT,
247 };
248
249 iov.iov_base = &data;
250 iov.iov_len = sizeof(data);
251
252 msg.msg_iov = &iov;
253 msg.msg_iovlen = 1;
254 msg.msg_control = control;
255 msg.msg_controllen = sizeof(control);
256
257 err = recvmsg(fd, &msg, 0);
258 if (err < 0) {
259 log_err("recvmsg");
260 return 1;
261 }
262
263 if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
264 log_err("recvmsg: truncated");
265 return 1;
266 }
267
268 /* send(cfd, "y", sizeof(char), 0) */
269 if (data != 'y') {
270 log_err("recvmsg: data corruption");
271 return 1;
272 }
273
274 if (parse_cmsg(&msg, &res)) {
275 log_err("CMSG parse: parse_cmsg() failed");
276 return 1;
277 }
278
279 /*
280 * pidfd from SCM_PIDFD should point to the client_pid.
281 * Let's read exit information and check if it's what
282 * we expect to see.
283 */
284 if (ioctl(*res.pidfd, PIDFD_GET_INFO, &info)) {
285 log_err("%s: ioctl(PIDFD_GET_INFO) failed", __func__);
286 close(*res.pidfd);
287 return 1;
288 }
289
290 if (!(info.mask & PIDFD_INFO_EXIT)) {
291 log_err("%s: No exit information from ioctl(PIDFD_GET_INFO)", __func__);
292 close(*res.pidfd);
293 return 1;
294 }
295
296 err = WIFEXITED(info.exit_code) ? WEXITSTATUS(info.exit_code) : 1;
297 if (err != CHILD_EXIT_CODE_OK) {
298 log_err("%s: wrong exit_code %d != %d", __func__, err, CHILD_EXIT_CODE_OK);
299 close(*res.pidfd);
300 return 1;
301 }
302
303 close(*res.pidfd);
304 return 0;
305 }
306
/* One endpoint's socket address, as built by fill_sockaddr(). */
struct sock_addr {
	/* "scm_pidfd_<pid>"; also the filesystem path for pathname sockets */
	char sock_name[32];
	/* address passed to bind()/connect()/sendto() */
	struct sockaddr_un listen_addr;
	/* number of meaningful bytes in listen_addr */
	socklen_t addrlen;
};
312
/* Per-test state used by the setup, the test body, the forked client and teardown. */
FIXTURE(scm_pidfd)
{
	/* server socket: listening for SOCK_STREAM, bound for SOCK_DGRAM */
	int server;
	/* pid of the forked client process */
	pid_t client_pid;
	/* the client closes its write end to tell the parent it is ready */
	int startup_pipe[2];
	struct sock_addr server_addr;
	/*
	 * MAP_SHARED anonymous mapping (see FIXTURE_SETUP) so the parent can
	 * read the address the client binds to.
	 */
	struct sock_addr *client_addr;
};
321
/* Per-variant parameters: socket type and address namespace. */
FIXTURE_VARIANT(scm_pidfd)
{
	int type;	/* SOCK_STREAM or SOCK_DGRAM */
	bool abstract;	/* abstract namespace (leading NUL) instead of a filesystem path */
};
327
FIXTURE_VARIANT_ADD(scm_pidfd,stream_pathname)328 FIXTURE_VARIANT_ADD(scm_pidfd, stream_pathname)
329 {
330 .type = SOCK_STREAM,
331 .abstract = 0,
332 };
333
FIXTURE_VARIANT_ADD(scm_pidfd,stream_abstract)334 FIXTURE_VARIANT_ADD(scm_pidfd, stream_abstract)
335 {
336 .type = SOCK_STREAM,
337 .abstract = 1,
338 };
339
FIXTURE_VARIANT_ADD(scm_pidfd,dgram_pathname)340 FIXTURE_VARIANT_ADD(scm_pidfd, dgram_pathname)
341 {
342 .type = SOCK_DGRAM,
343 .abstract = 0,
344 };
345
FIXTURE_VARIANT_ADD(scm_pidfd,dgram_abstract)346 FIXTURE_VARIANT_ADD(scm_pidfd, dgram_abstract)
347 {
348 .type = SOCK_DGRAM,
349 .abstract = 1,
350 };
351
FIXTURE_SETUP(scm_pidfd)352 FIXTURE_SETUP(scm_pidfd)
353 {
354 self->client_addr = mmap(NULL, sizeof(*self->client_addr), PROT_READ | PROT_WRITE,
355 MAP_SHARED | MAP_ANONYMOUS, -1, 0);
356 ASSERT_NE(MAP_FAILED, self->client_addr);
357 }
358
FIXTURE_TEARDOWN(scm_pidfd)359 FIXTURE_TEARDOWN(scm_pidfd)
360 {
361 close(self->server);
362
363 kill(self->client_pid, SIGKILL);
364 waitpid(self->client_pid, NULL, 0);
365
366 if (!variant->abstract) {
367 unlink(self->server_addr.sock_name);
368 unlink(self->client_addr->sock_name);
369 }
370 }
371
fill_sockaddr(struct sock_addr * addr,bool abstract)372 static void fill_sockaddr(struct sock_addr *addr, bool abstract)
373 {
374 char *sun_path_buf = (char *)&addr->listen_addr.sun_path;
375
376 addr->listen_addr.sun_family = AF_UNIX;
377 addr->addrlen = offsetof(struct sockaddr_un, sun_path);
378 snprintf(addr->sock_name, sizeof(addr->sock_name), "scm_pidfd_%d", getpid());
379 addr->addrlen += strlen(addr->sock_name);
380 if (abstract) {
381 *sun_path_buf = '\0';
382 addr->addrlen++;
383 sun_path_buf++;
384 } else {
385 unlink(addr->sock_name);
386 }
387 memcpy(sun_path_buf, addr->sock_name, strlen(addr->sock_name));
388 }
389
sk_enable_cred_pass(int sk)390 static int sk_enable_cred_pass(int sk)
391 {
392 int on = 0;
393
394 on = 1;
395 if (setsockopt(sk, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) {
396 log_err("Failed to set SO_PASSCRED");
397 return 1;
398 }
399
400 if (setsockopt(sk, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) {
401 log_err("Failed to set SO_PASSPIDFD");
402 return 1;
403 }
404
405 return 0;
406 }
407
/*
 * Body of the forked client process.
 *
 * 1. Creates a socket of variant->type; for SOCK_DGRAM, binds it to the
 *    shared self->client_addr.
 * 2. Connects to the server and enables credential passing.
 * 3. Closes the startup pipe's write end to signal readiness.
 * 4. Validates the parent's message with cmsg_check() and replies "y".
 * 5. For SOCK_STREAM only, checks that SO_PEERCRED reports the parent pid
 *    and that the SO_PEERPIDFD pidfd refers to the same pid.
 *
 * Any failure terminates the process via child_die().
 */
static void client(FIXTURE_DATA(scm_pidfd) *self,
		   const FIXTURE_VARIANT(scm_pidfd) *variant)
{
	int cfd;
	socklen_t len;
	struct ucred peer_cred;
	int peer_pidfd;
	pid_t peer_pid;

	cfd = socket(AF_UNIX, variant->type, 0);
	if (cfd < 0) {
		log_err("socket");
		child_die();
	}

	if (variant->type == SOCK_DGRAM) {
		fill_sockaddr(self->client_addr, variant->abstract);

		if (bind(cfd, (struct sockaddr *)&self->client_addr->listen_addr,
			 self->client_addr->addrlen)) {
			log_err("bind");
			child_die();
		}
	}

	if (connect(cfd, (struct sockaddr *)&self->server_addr.listen_addr,
		    self->server_addr.addrlen) != 0) {
		log_err("connect");
		child_die();
	}

	if (sk_enable_cred_pass(cfd)) {
		log_err("sk_enable_cred_pass() failed");
		child_die();
	}

	close(self->startup_pipe[1]);

	if (cmsg_check(cfd)) {
		log_err("cmsg_check failed");
		child_die();
	}

	/* send something to the parent so it can receive SCM_PIDFD too and validate it */
	if (send(cfd, "y", sizeof(char), 0) == -1) {
		log_err("Failed to send(cfd, \"y\", sizeof(char), 0)");
		child_die();
	}

	/* skip further for SOCK_DGRAM as it's not applicable */
	if (variant->type == SOCK_DGRAM)
		return;

	len = sizeof(peer_cred);
	if (getsockopt(cfd, SOL_SOCKET, SO_PEERCRED, &peer_cred, &len)) {
		log_err("Failed to get SO_PEERCRED");
		child_die();
	}

	len = sizeof(peer_pidfd);
	if (getsockopt(cfd, SOL_SOCKET, SO_PEERPIDFD, &peer_pidfd, &len)) {
		log_err("Failed to get SO_PEERPIDFD");
		child_die();
	}

	/*
	 * The pidfd is only needed for the fdinfo lookup. Resolve it and close it
	 * right away so it is released on every path below.
	 */
	peer_pid = get_pid_from_fdinfo_file(peer_pidfd,
					    "Pid:", sizeof("Pid:") - 1);
	close(peer_pidfd);

	/* pid from SO_PEERCRED should point to the parent process PID */
	if (peer_cred.pid != getppid()) {
		log_err("peer_cred.pid != getppid(): %d != %d", peer_cred.pid, getppid());
		child_die();
	}

	if (peer_pid != peer_cred.pid) {
		log_err("peer_pid != peer_cred.pid: %d != %d", peer_pid, peer_cred.pid);
		child_die();
	}
}
485
TEST_F(scm_pidfd,test)486 TEST_F(scm_pidfd, test)
487 {
488 int err;
489 int pfd;
490 int child_status = 0;
491
492 self->server = socket(AF_UNIX, variant->type, 0);
493 ASSERT_NE(-1, self->server);
494
495 fill_sockaddr(&self->server_addr, variant->abstract);
496
497 err = bind(self->server, (struct sockaddr *)&self->server_addr.listen_addr, self->server_addr.addrlen);
498 ASSERT_EQ(0, err);
499
500 if (variant->type == SOCK_STREAM) {
501 err = listen(self->server, 1);
502 ASSERT_EQ(0, err);
503 }
504
505 err = pipe(self->startup_pipe);
506 ASSERT_NE(-1, err);
507
508 self->client_pid = fork();
509 ASSERT_NE(-1, self->client_pid);
510 if (self->client_pid == 0) {
511 close(self->server);
512 close(self->startup_pipe[0]);
513 client(self, variant);
514
515 /*
516 * It's a bit unusual, but in case of success we return non-zero
517 * exit code (CHILD_EXIT_CODE_OK) and then we expect to read it
518 * from ioctl(PIDFD_GET_INFO) in cmsg_check_dead().
519 */
520 exit(CHILD_EXIT_CODE_OK);
521 }
522 close(self->startup_pipe[1]);
523
524 if (variant->type == SOCK_STREAM) {
525 pfd = accept(self->server, NULL, NULL);
526 ASSERT_NE(-1, pfd);
527 } else {
528 pfd = self->server;
529 }
530
531 /* wait until the child arrives at checkpoint */
532 read(self->startup_pipe[0], &err, sizeof(int));
533 close(self->startup_pipe[0]);
534
535 if (variant->type == SOCK_DGRAM) {
536 err = sendto(pfd, "x", sizeof(char), 0, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen);
537 ASSERT_NE(-1, err);
538 } else {
539 err = send(pfd, "x", sizeof(char), 0);
540 ASSERT_NE(-1, err);
541 }
542
543 waitpid(self->client_pid, &child_status, 0);
544 /* see comment before exit(CHILD_EXIT_CODE_OK) */
545 ASSERT_EQ(CHILD_EXIT_CODE_OK, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1);
546
547 err = sk_enable_cred_pass(pfd);
548 ASSERT_EQ(0, err);
549
550 err = cmsg_check_dead(pfd, self->client_pid);
551 ASSERT_EQ(0, err);
552
553 close(pfd);
554 }
555
/* kselftest harness entry point; this file defines no main() of its own. */
TEST_HARNESS_MAIN
557