xref: /linux/tools/testing/selftests/filesystems/utils.c (revision 804382d59b81b331735d37a18149ea0d36d5936a)
1 // SPDX-License-Identifier: GPL-2.0
2 #ifndef _GNU_SOURCE
3 #define _GNU_SOURCE
4 #endif
5 #include <fcntl.h>
6 #include <sys/types.h>
7 #include <dirent.h>
8 #include <grp.h>
9 #include <linux/limits.h>
10 #include <sched.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <sys/eventfd.h>
14 #include <sys/fsuid.h>
15 #include <sys/prctl.h>
16 #include <sys/socket.h>
17 #include <sys/stat.h>
18 #include <sys/types.h>
19 #include <sys/wait.h>
20 #include <sys/xattr.h>
21 
22 #include "utils.h"
23 
/*
 * Sentinel nesting level: create_userns_hierarchy() returns without doing
 * anything when it is handed an entry whose level equals this value.
 */
#define MAX_USERNS_LEVEL 32
25 
/*
 * Print a formatted message to stderr, prefixed with the textual description
 * of the current errno ("%m" is a glibc printf extension), followed by a
 * newline. The expression evaluates to -errno so it can be used directly as
 * a return value.
 */
#define syserror(format, ...)                           \
	({                                              \
		fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__); \
		(-errno);                               \
	})
31 
/*
 * Set errno to the absolute value of __ret__, print a formatted message to
 * stderr prefixed with the matching error description ("%m" is a glibc
 * printf extension), and evaluate to __ret__.
 *
 * __ret__ is evaluated exactly once, so passing an expression with side
 * effects is safe.
 */
#define syserror_set(__ret__, format, ...)                           \
	({                                                           \
		__typeof__(__ret__) __internal_ret__ = (__ret__);    \
		errno = labs(__internal_ret__);                      \
		fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__); \
		__internal_ret__;                                    \
	})
39 
/*
 * Compile-time length of a string literal, excluding the terminating NUL.
 * The ""x"" concatenation only compiles when x is itself a string literal.
 */
#define STRLITERALLEN(x) (sizeof(""x"") - 1)
41 
/*
 * Compile-time upper bound on the number of characters needed to print a
 * value of the given integer type in decimal: 3, 5, 10 or 20 digits for
 * 1-, 2-, 4- or 8-byte types, plus 2 extra characters (presumably room for a
 * sign and a terminator - confirm before relying on it).
 *
 * Types wider than 8 bytes expand to an array with a negative size, which
 * turns misuse into a compile error.
 */
#define INTTYPE_TO_STRLEN(type)             \
	(2 + (sizeof(type) <= 1             \
		  ? 3                       \
		  : sizeof(type) <= 2       \
			? 5                 \
			: sizeof(type) <= 4 \
			      ? 10          \
			      : sizeof(type) <= 8 ? 20 : sizeof(int[-2 * (sizeof(type) > 8)])))
50 
/*
 * Iterate over every node of a circular list, starting at the node after the
 * head and stopping once the head is reached again (the head itself is not
 * visited).
 *
 * Both arguments are fully parenthesized so that expressions such as an
 * assignment or a conditional can be passed as the list head. Note that
 * __list is evaluated on every iteration.
 */
#define list_for_each(__iterator, __list)                              \
	for ((__iterator) = (__list)->next; (__iterator) != (__list);  \
	     (__iterator) = (__iterator)->next)
53 
/* Selects which of a process's ID map files a mapping entry is written to. */
typedef enum idmap_type_t {
	/* Entry is written to /proc/<pid>/uid_map. */
	ID_TYPE_UID,
	/* Entry is written to /proc/<pid>/gid_map. */
	ID_TYPE_GID
} idmap_type_t;
58 
/*
 * One ID mapping entry. map_ids_from_idmap() formats each entry as
 * "<nsid> <hostid> <range>\n" and writes it to the uid_map or gid_map file
 * selected by map_type.
 */
struct id_map {
	idmap_type_t map_type;
	__u32 nsid;
	__u32 hostid;
	__u32 range;
};
65 
/*
 * Node of a circular doubly linked list. A list head is a node initialized
 * with list_init() so that next and prev point back at the head itself; the
 * head's elem is NULL, member nodes carry their payload in elem.
 */
struct list {
	void *elem;
	struct list *next;
	struct list *prev;
};
71 
/*
 * Description of one level in a chain of nested user namespaces, consumed by
 * create_userns_hierarchy(). The levels are laid out as consecutive array
 * elements: the child created for one level advances the pointer by one to
 * set up the next level.
 */
struct userns_hierarchy {
	/* Receives the fd opened on /proc/<pid>/ns/user for this level. */
	int fd_userns;
	/* Socketpair end the cloned child uses to synchronize with its parent. */
	int fd_event;
	/* Nesting level; MAX_USERNS_LEVEL terminates the recursion. */
	unsigned int level;
	/* List of struct id_map entries to write for this level; may be empty. */
	struct list id_map;
};
78 
79 static inline void list_init(struct list *list)
80 {
81 	list->elem = NULL;
82 	list->next = list->prev = list;
83 }
84 
85 static inline int list_empty(const struct list *list)
86 {
87 	return list == list->next;
88 }
89 
/*
 * Link @new between two nodes that are currently adjacent: @prev ends up
 * immediately before @new and @next immediately after it.
 */
static inline void __list_add(struct list *new, struct list *prev, struct list *next)
{
	next->prev = new;
	new->next = next;
	new->prev = prev;
	prev->next = new;
}
97 
98 static inline void list_add_tail(struct list *head, struct list *list)
99 {
100 	__list_add(list, head->prev, head);
101 }
102 
103 static inline void list_del(struct list *list)
104 {
105 	struct list *next, *prev;
106 
107 	next = list->next;
108 	prev = list->prev;
109 	next->prev = prev;
110 	prev->next = next;
111 }
112 
/*
 * read(2) wrapper that restarts the call whenever it fails with EINTR.
 *
 * Returns whatever the final read() returned: the number of bytes read,
 * 0 on end-of-file, or -1 with errno set for any error other than EINTR.
 */
static ssize_t read_nointr(int fd, void *buf, size_t count)
{
	for (;;) {
		ssize_t nread = read(fd, buf, count);

		if (nread >= 0 || errno != EINTR)
			return nread;
	}
}
123 
/*
 * write(2) wrapper that restarts the call whenever it fails with EINTR.
 *
 * Returns whatever the final write() returned: the number of bytes written
 * (which may be fewer than @count), or -1 with errno set for any error other
 * than EINTR.
 */
static ssize_t write_nointr(int fd, const void *buf, size_t count)
{
	for (;;) {
		ssize_t nwritten = write(fd, buf, count);

		if (nwritten >= 0 || errno != EINTR)
			return nwritten;
	}
}
134 
/* Size of the stack handed to every child created by do_clone(). */
#define __STACK_SIZE (8 * 1024 * 1024)

/*
 * Start @fn(@arg) in a new child created with clone(2).
 *
 * SIGCHLD is always OR-ed into @flags. The child runs on a freshly
 * malloc()ed stack of __STACK_SIZE bytes; this function never frees that
 * stack.
 *
 * Returns the child's pid as reported by clone(), clone()'s negative return
 * value on failure, or -ENOMEM when the stack cannot be allocated.
 */
static pid_t do_clone(int (*fn)(void *), void *arg, int flags)
{
	char *stack_base = malloc(__STACK_SIZE);

	if (stack_base == NULL)
		return -ENOMEM;

#ifdef __ia64__
	/* __clone2() takes the lowest address of the stack area plus its size. */
	return __clone2(fn, stack_base, __STACK_SIZE, flags | SIGCHLD, arg, NULL);
#else
	/* clone() takes the topmost address of the stack area. */
	return clone(fn, stack_base + __STACK_SIZE, flags | SIGCHLD, arg, NULL);
#endif
}
150 
/*
 * Child entry point used by get_userns_fd_from_idmap(): do nothing but sleep
 * in pause() forever. @data is unused.
 */
static int get_userns_fd_cb(void *data)
{
	(void)data;

	while (1)
		pause();

	/* Not reached: the loop above never terminates. */
	_exit(0);
}
157 
/*
 * Wait for child @pid to terminate, retrying when waitpid() is interrupted.
 *
 * Returns the child's exit status when it exited normally, and -1 when
 * waitpid() failed or the child did not exit normally.
 */
static int wait_for_pid(pid_t pid)
{
	int status;
	pid_t waited;

	do {
		waited = waitpid(pid, &status, 0);
	} while (waited == -1 && errno == EINTR);

	if (waited == -1)
		return -1;

	return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
}
176 
177 static int write_id_mapping(idmap_type_t map_type, pid_t pid, const char *buf, size_t buf_size)
178 {
179 	int fd = -EBADF, setgroups_fd = -EBADF;
180 	int fret = -1;
181 	int ret;
182 	char path[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
183 		  STRLITERALLEN("/setgroups") + 1];
184 
185 	if (geteuid() != 0 && map_type == ID_TYPE_GID) {
186 		ret = snprintf(path, sizeof(path), "/proc/%d/setgroups", pid);
187 		if (ret < 0 || ret >= sizeof(path))
188 			goto out;
189 
190 		setgroups_fd = open(path, O_WRONLY | O_CLOEXEC);
191 		if (setgroups_fd < 0 && errno != ENOENT) {
192 			syserror("Failed to open \"%s\"", path);
193 			goto out;
194 		}
195 
196 		if (setgroups_fd >= 0) {
197 			ret = write_nointr(setgroups_fd, "deny\n", STRLITERALLEN("deny\n"));
198 			if (ret != STRLITERALLEN("deny\n")) {
199 				syserror("Failed to write \"deny\" to \"/proc/%d/setgroups\"", pid);
200 				goto out;
201 			}
202 		}
203 	}
204 
205 	ret = snprintf(path, sizeof(path), "/proc/%d/%cid_map", pid, map_type == ID_TYPE_UID ? 'u' : 'g');
206 	if (ret < 0 || ret >= sizeof(path))
207 		goto out;
208 
209 	fd = open(path, O_WRONLY | O_CLOEXEC);
210 	if (fd < 0) {
211 		syserror("Failed to open \"%s\"", path);
212 		goto out;
213 	}
214 
215 	ret = write_nointr(fd, buf, buf_size);
216 	if (ret != buf_size) {
217 		syserror("Failed to write %cid mapping to \"%s\"",
218 			 map_type == ID_TYPE_UID ? 'u' : 'g', path);
219 		goto out;
220 	}
221 
222 	fret = 0;
223 out:
224 	close(fd);
225 	close(setgroups_fd);
226 
227 	return fret;
228 }
229 
230 static int map_ids_from_idmap(struct list *idmap, pid_t pid)
231 {
232 	int fill, left;
233 	char mapbuf[4096] = {};
234 	bool had_entry = false;
235 	idmap_type_t map_type, u_or_g;
236 
237 	if (list_empty(idmap))
238 		return 0;
239 
240 	for (map_type = ID_TYPE_UID, u_or_g = 'u';
241 	     map_type <= ID_TYPE_GID; map_type++, u_or_g = 'g') {
242 		char *pos = mapbuf;
243 		int ret;
244 		struct list *iterator;
245 
246 
247 		list_for_each(iterator, idmap) {
248 			struct id_map *map = iterator->elem;
249 			if (map->map_type != map_type)
250 				continue;
251 
252 			had_entry = true;
253 
254 			left = 4096 - (pos - mapbuf);
255 			fill = snprintf(pos, left, "%u %u %u\n", map->nsid, map->hostid, map->range);
256 			/*
257 			 * The kernel only takes <= 4k for writes to
258 			 * /proc/<pid>/{g,u}id_map
259 			 */
260 			if (fill <= 0 || fill >= left)
261 				return syserror_set(-E2BIG, "Too many %cid mappings defined", u_or_g);
262 
263 			pos += fill;
264 		}
265 		if (!had_entry)
266 			continue;
267 
268 		ret = write_id_mapping(map_type, pid, mapbuf, pos - mapbuf);
269 		if (ret < 0)
270 			return syserror("Failed to write mapping: %s", mapbuf);
271 
272 		memset(mapbuf, 0, sizeof(mapbuf));
273 	}
274 
275 	return 0;
276 }
277 
278 static int get_userns_fd_from_idmap(struct list *idmap)
279 {
280 	int ret;
281 	pid_t pid;
282 	char path_ns[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
283 		     STRLITERALLEN("/ns/user") + 1];
284 
285 	pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER | CLONE_NEWNS);
286 	if (pid < 0)
287 		return -errno;
288 
289 	ret = map_ids_from_idmap(idmap, pid);
290 	if (ret < 0)
291 		return ret;
292 
293 	ret = snprintf(path_ns, sizeof(path_ns), "/proc/%d/ns/user", pid);
294 	if (ret < 0 || (size_t)ret >= sizeof(path_ns))
295 		ret = -EIO;
296 	else
297 		ret = open(path_ns, O_RDONLY | O_CLOEXEC | O_NOCTTY);
298 
299 	(void)kill(pid, SIGKILL);
300 	(void)wait_for_pid(pid);
301 	return ret;
302 }
303 
304 int get_userns_fd(unsigned long nsid, unsigned long hostid, unsigned long range)
305 {
306 	struct list head, uid_mapl, gid_mapl;
307 	struct id_map uid_map = {
308 		.map_type	= ID_TYPE_UID,
309 		.nsid		= nsid,
310 		.hostid		= hostid,
311 		.range		= range,
312 	};
313 	struct id_map gid_map = {
314 		.map_type	= ID_TYPE_GID,
315 		.nsid		= nsid,
316 		.hostid		= hostid,
317 		.range		= range,
318 	};
319 
320 	list_init(&head);
321 	uid_mapl.elem = &uid_map;
322 	gid_mapl.elem = &gid_map;
323 	list_add_tail(&head, &uid_mapl);
324 	list_add_tail(&head, &gid_mapl);
325 
326 	return get_userns_fd_from_idmap(&head);
327 }
328 
/*
 * Drop all supplementary groups, switch the real, effective and saved group
 * and user IDs to @gid and @uid, and mark the process dumpable again.
 *
 * Returns true on success. On the first failing step an error is logged to
 * stderr and false is returned; later steps are not attempted.
 */
bool switch_ids(uid_t uid, gid_t gid)
{
	/*
	 * syserror() evaluates to -errno, which is non-zero and would convert
	 * to "true" if returned from a bool function, so log and return false
	 * explicitly.
	 */
	if (setgroups(0, NULL)) {
		syserror("failure: setgroups");
		return false;
	}

	if (setresgid(gid, gid, gid)) {
		syserror("failure: setresgid");
		return false;
	}

	if (setresuid(uid, uid, uid)) {
		syserror("failure: setresuid");
		return false;
	}

	/* Ensure we can access proc files from processes we can ptrace. */
	if (prctl(PR_SET_DUMPABLE, 1, 0, 0, 0)) {
		syserror("failure: make dumpable");
		return false;
	}

	return true;
}
346 
347 static int create_userns_hierarchy(struct userns_hierarchy *h);
348 
349 static int userns_fd_cb(void *data)
350 {
351 	struct userns_hierarchy *h = data;
352 	char c;
353 	int ret;
354 
355 	ret = read_nointr(h->fd_event, &c, 1);
356 	if (ret < 0)
357 		return syserror("failure: read from socketpair");
358 
359 	/* Only switch ids if someone actually wrote a mapping for us. */
360 	if (c == '1') {
361 		if (!switch_ids(0, 0))
362 			return syserror("failure: switch ids to 0");
363 	}
364 
365 	ret = write_nointr(h->fd_event, "1", 1);
366 	if (ret < 0)
367 		return syserror("failure: write to socketpair");
368 
369 	ret = create_userns_hierarchy(++h);
370 	if (ret < 0)
371 		return syserror("failure: userns level %d", h->level);
372 
373 	return 0;
374 }
375 
376 static int create_userns_hierarchy(struct userns_hierarchy *h)
377 {
378 	int fret = -1;
379 	char c;
380 	int fd_socket[2];
381 	int fd_userns = -EBADF, ret = -1;
382 	ssize_t bytes;
383 	pid_t pid;
384 	char path[256];
385 
386 	if (h->level == MAX_USERNS_LEVEL)
387 		return 0;
388 
389 	ret = socketpair(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0, fd_socket);
390 	if (ret < 0)
391 		return syserror("failure: create socketpair");
392 
393 	/* Note the CLONE_FILES | CLONE_VM when mucking with fds and memory. */
394 	h->fd_event = fd_socket[1];
395 	pid = do_clone(userns_fd_cb, h, CLONE_NEWUSER | CLONE_FILES | CLONE_VM);
396 	if (pid < 0) {
397 		syserror("failure: userns level %d", h->level);
398 		goto out_close;
399 	}
400 
401 	ret = map_ids_from_idmap(&h->id_map, pid);
402 	if (ret < 0) {
403 		kill(pid, SIGKILL);
404 		syserror("failure: writing id mapping for userns level %d for %d", h->level, pid);
405 		goto out_wait;
406 	}
407 
408 	if (!list_empty(&h->id_map))
409 		bytes = write_nointr(fd_socket[0], "1", 1); /* Inform the child we wrote a mapping. */
410 	else
411 		bytes = write_nointr(fd_socket[0], "0", 1); /* Inform the child we didn't write a mapping. */
412 	if (bytes < 0) {
413 		kill(pid, SIGKILL);
414 		syserror("failure: write to socketpair");
415 		goto out_wait;
416 	}
417 
418 	/* Wait for child to set*id() and become dumpable. */
419 	bytes = read_nointr(fd_socket[0], &c, 1);
420 	if (bytes < 0) {
421 		kill(pid, SIGKILL);
422 		syserror("failure: read from socketpair");
423 		goto out_wait;
424 	}
425 
426 	snprintf(path, sizeof(path), "/proc/%d/ns/user", pid);
427 	fd_userns = open(path, O_RDONLY | O_CLOEXEC);
428 	if (fd_userns < 0) {
429 		kill(pid, SIGKILL);
430 		syserror("failure: open userns level %d for %d", h->level, pid);
431 		goto out_wait;
432 	}
433 
434 	fret = 0;
435 
436 out_wait:
437 	if (!wait_for_pid(pid) && !fret) {
438 		h->fd_userns = fd_userns;
439 		fd_userns = -EBADF;
440 	}
441 
442 out_close:
443 	if (fd_userns >= 0)
444 		close(fd_userns);
445 	close(fd_socket[0]);
446 	close(fd_socket[1]);
447 	return fret;
448 }
449 
450 /* caps_down - lower all effective caps */
451 int caps_down(void)
452 {
453 	bool fret = false;
454 	cap_t caps = NULL;
455 	int ret = -1;
456 
457 	caps = cap_get_proc();
458 	if (!caps)
459 		goto out;
460 
461 	ret = cap_clear_flag(caps, CAP_EFFECTIVE);
462 	if (ret)
463 		goto out;
464 
465 	ret = cap_set_proc(caps);
466 	if (ret)
467 		goto out;
468 
469 	fret = true;
470 
471 out:
472 	cap_free(caps);
473 	return fret;
474 }
475 
476 /* cap_down - lower an effective cap */
477 int cap_down(cap_value_t down)
478 {
479 	bool fret = false;
480 	cap_t caps = NULL;
481 	cap_value_t cap = down;
482 	int ret = -1;
483 
484 	caps = cap_get_proc();
485 	if (!caps)
486 		goto out;
487 
488 	ret = cap_set_flag(caps, CAP_EFFECTIVE, 1, &cap, 0);
489 	if (ret)
490 		goto out;
491 
492 	ret = cap_set_proc(caps);
493 	if (ret)
494 		goto out;
495 
496 	fret = true;
497 
498 out:
499 	cap_free(caps);
500 	return fret;
501 }
502