xref: /freebsd/usr.bin/mdo/mdo.c (revision dd21556857e8d40f66bf5ad54754d9d52669ebf7)
1 /*-
2  * Copyright(c) 2024 Baptiste Daroussin <bapt@FreeBSD.org>
3  *
4  * SPDX-License-Identifier: BSD-2-Clause
5  */
6 
7 #include <sys/limits.h>
8 #include <sys/ucred.h>
9 
10 #include <err.h>
11 #include <paths.h>
12 #include <pwd.h>
13 #include <stdbool.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include <unistd.h>
18 
/*
 * Print the command-line synopsis on standard error and terminate the
 * program with a failure status.  This function never returns.
 */
static void
usage(void)
{
	static const char synopsis[] =
	    "usage: mdo [-u username] [-i] [--] [command [args]]\n";

	fputs(synopsis, stderr);
	exit(EXIT_FAILURE);
}
25 
/*
 * mdo: run a command, or an interactive shell when no command is given,
 * with the credentials of another user ("root" by default).
 *
 * Options:
 *   -u username  target user, given as a login name or as a numeric UID
 *   -i           switch the user IDs only; group IDs are left untouched
 *
 * All requested credential changes are applied with a single setcred()
 * call, after which the process image is replaced with exec.  Every failure
 * path terminates the program through usage(), err() or errx().
 */
int
main(int argc, char **argv)
{
	struct passwd *pw;
	const char *username = "root";
	struct setcred wcred = SETCRED_INITIALIZER;
	u_int setcred_flags = 0;
	bool uidonly = false;
	int ch;

	while ((ch = getopt(argc, argv, "u:i")) != -1) {
		switch (ch) {
		case 'u':
			username = optarg;
			break;
		case 'i':
			uidonly = true;
			break;
		default:
			usage();
		}
	}
	argc -= optind;
	argv += optind;

	/* Look the user up by name first, then fall back to a numeric UID. */
	if ((pw = getpwnam(username)) == NULL) {
		if (strspn(username, "0123456789") == strlen(username)) {
			const char *errp = NULL;
			uid_t uid = strtonum(username, 0, UID_MAX, &errp);
			/* strtonum() sets errno on failure, so err() fits. */
			if (errp != NULL)
				err(EXIT_FAILURE, "invalid user ID '%s'",
				    username);
			pw = getpwuid(uid);
		}
		/*
		 * Use errx() rather than err(): a user that simply does not
		 * exist is not a system error, and the lookup functions leave
		 * errno untouched in that case, so err() would append an
		 * unrelated (stale) error string to the message.
		 */
		if (pw == NULL)
			errx(EXIT_FAILURE, "invalid username '%s'", username);
	}

	/* Effective, real and saved user IDs all become the target UID. */
	wcred.sc_uid = wcred.sc_ruid = wcred.sc_svuid = pw->pw_uid;
	setcred_flags |= SETCREDF_UID | SETCREDF_RUID | SETCREDF_SVUID;

	if (!uidonly) {
		const long ngroups_max = sysconf(_SC_NGROUPS_MAX);

		/*
		 * Without a usable limit the buffer below cannot be sized
		 * correctly; a -1 here would yield a one-slot allocation
		 * while "groups + 1" is still handed to setcred().
		 */
		if (ngroups_max < 0)
			errx(EXIT_FAILURE,
			    "cannot determine the maximum number of groups");

		/*
		 * If there are too many groups specified for some UID, setting
		 * the groups will fail.  We preserve this condition by
		 * allocating one more supplementary group slot than allowed,
		 * as getgrouplist() itself is just some getter function and
		 * thus doesn't (and shouldn't) check the limit, and to allow
		 * setcred() to actually check for overflow.  The other extra
		 * slot accounts for the first entry of the list, which is
		 * skipped below when passing the supplementary groups.
		 */
		const long ngroups_alloc = ngroups_max + 2;
		gid_t *const groups = malloc(sizeof(*groups) * ngroups_alloc);
		int ngroups = (int)ngroups_alloc;

		if (groups == NULL)
			err(EXIT_FAILURE, "cannot allocate memory for groups");

		/*
		 * The return value is deliberately not checked: on overflow
		 * the resulting count is passed on so that setcred() rejects
		 * it (see the comment above).
		 */
		getgrouplist(pw->pw_name, pw->pw_gid, groups, &ngroups);

		/* Effective, real and saved group IDs use the primary GID. */
		wcred.sc_gid = wcred.sc_rgid = wcred.sc_svgid = pw->pw_gid;
		/* Everything after the first list entry is supplementary. */
		wcred.sc_supp_groups = groups + 1;
		wcred.sc_supp_groups_nb = ngroups - 1;
		setcred_flags |= SETCREDF_GID | SETCREDF_RGID | SETCREDF_SVGID |
		    SETCREDF_SUPP_GROUPS;
		/* "groups" is not freed: the process execs or exits next. */
	}

	if (setcred(setcred_flags, &wcred, sizeof(wcred)) != 0)
		err(EXIT_FAILURE, "calling setcred() failed");

	if (*argv == NULL) {
		/* No command given: start an interactive shell. */
		const char *sh = getenv("SHELL");
		if (sh == NULL)
			sh = _PATH_BSHELL;
		execlp(sh, sh, "-i", NULL);
	} else {
		execvp(argv[0], argv);
	}
	/* Only reached when the exec call itself failed. */
	err(EXIT_FAILURE, "exec failed");
}
105