xref: /linux/tools/testing/selftests/filesystems/statmount/statmount_test_ns.c (revision d53b8e36925256097a08d7cb749198d85cbf9b2b)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 
3 #define _GNU_SOURCE
4 
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <sched.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/mount.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <linux/nsfs.h>
#include <linux/stat.h>

#include "statmount.h"
#include "../../kselftest.h"
18 
/*
 * Result codes returned by the test helpers. They are also passed to exit()
 * by forked children and decoded again by handle_result().
 */
enum {
	NSID_PASS = 0,	/* check succeeded */
	NSID_FAIL = 1,	/* check ran and found a mismatch */
	NSID_SKIP = 2,	/* required statmount feature unavailable */
	NSID_ERROR = 3,	/* setup/syscall error; handle_result() aborts the run */
};
23 
24 static void handle_result(int ret, const char *testname)
25 {
26 	if (ret == NSID_PASS)
27 		ksft_test_result_pass("%s\n", testname);
28 	else if (ret == NSID_FAIL)
29 		ksft_test_result_fail("%s\n", testname);
30 	else if (ret == NSID_ERROR)
31 		ksft_exit_fail_msg("%s\n", testname);
32 	else
33 		ksft_test_result_skip("%s\n", testname);
34 }
35 
/*
 * Reap @pid and return its exit status.
 *
 * waitpid() is retried for as long as it is interrupted by a signal (EINTR).
 * Returns -1, after logging the reason, if waitpid() fails for any other
 * reason or if the child did not terminate via a normal exit.
 */
static inline int wait_for_pid(pid_t pid)
{
	int status;

	for (;;) {
		if (waitpid(pid, &status, 0) != -1)
			break;
		if (errno != EINTR) {
			ksft_print_msg("waitpid returned -1, errno=%d\n", errno);
			return -1;
		}
	}

	if (WIFEXITED(status))
		return WEXITSTATUS(status);

	ksft_print_msg(
	       "waitpid !WIFEXITED, WIFSIGNALED=%d, WTERMSIG=%d\n",
	       WIFSIGNALED(status), WTERMSIG(status));
	return -1;
}
60 
61 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id)
62 {
63 	int fd = open(mnt_ns, O_RDONLY);
64 
65 	if (fd < 0) {
66 		ksft_print_msg("failed to open for ns %s: %s\n",
67 			       mnt_ns, strerror(errno));
68 		sleep(60);
69 		return NSID_ERROR;
70 	}
71 
72 	if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) {
73 		ksft_print_msg("failed to get the nsid for ns %s: %s\n",
74 			       mnt_ns, strerror(errno));
75 		return NSID_ERROR;
76 	}
77 	close(fd);
78 	return NSID_PASS;
79 }
80 
81 static int get_mnt_id(const char *path, uint64_t *mnt_id)
82 {
83 	struct statx sx;
84 	int ret;
85 
86 	ret = statx(AT_FDCWD, path, 0, STATX_MNT_ID_UNIQUE, &sx);
87 	if (ret == -1) {
88 		ksft_print_msg("retrieving unique mount ID for %s: %s\n", path,
89 			       strerror(errno));
90 		return NSID_ERROR;
91 	}
92 
93 	if (!(sx.stx_mask & STATX_MNT_ID_UNIQUE)) {
94 		ksft_print_msg("no unique mount ID available for %s\n", path);
95 		return NSID_ERROR;
96 	}
97 
98 	*mnt_id = sx.stx_mnt_id;
99 	return NSID_PASS;
100 }
101 
102 static int write_file(const char *path, const char *val)
103 {
104 	int fd = open(path, O_WRONLY);
105 	size_t len = strlen(val);
106 	int ret;
107 
108 	if (fd == -1) {
109 		ksft_print_msg("opening %s for write: %s\n", path, strerror(errno));
110 		return NSID_ERROR;
111 	}
112 
113 	ret = write(fd, val, len);
114 	if (ret == -1) {
115 		ksft_print_msg("writing to %s: %s\n", path, strerror(errno));
116 		return NSID_ERROR;
117 	}
118 	if (ret != len) {
119 		ksft_print_msg("short write to %s\n", path);
120 		return NSID_ERROR;
121 	}
122 
123 	ret = close(fd);
124 	if (ret == -1) {
125 		ksft_print_msg("closing %s\n", path);
126 		return NSID_ERROR;
127 	}
128 
129 	return NSID_PASS;
130 }
131 
132 static int setup_namespace(void)
133 {
134 	int ret;
135 	char buf[32];
136 	uid_t uid = getuid();
137 	gid_t gid = getgid();
138 
139 	ret = unshare(CLONE_NEWNS|CLONE_NEWUSER|CLONE_NEWPID);
140 	if (ret == -1)
141 		ksft_exit_fail_msg("unsharing mountns and userns: %s\n",
142 				   strerror(errno));
143 
144 	sprintf(buf, "0 %d 1", uid);
145 	ret = write_file("/proc/self/uid_map", buf);
146 	if (ret != NSID_PASS)
147 		return ret;
148 	ret = write_file("/proc/self/setgroups", "deny");
149 	if (ret != NSID_PASS)
150 		return ret;
151 	sprintf(buf, "0 %d 1", gid);
152 	ret = write_file("/proc/self/gid_map", buf);
153 	if (ret != NSID_PASS)
154 		return ret;
155 
156 	ret = mount("", "/", NULL, MS_REC|MS_PRIVATE, NULL);
157 	if (ret == -1) {
158 		ksft_print_msg("making mount tree private: %s\n",
159 			       strerror(errno));
160 		return NSID_ERROR;
161 	}
162 
163 	return NSID_PASS;
164 }
165 
166 static int _test_statmount_mnt_ns_id(void)
167 {
168 	struct statmount sm;
169 	uint64_t mnt_ns_id;
170 	uint64_t root_id;
171 	int ret;
172 
173 	ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id);
174 	if (ret != NSID_PASS)
175 		return ret;
176 
177 	ret = get_mnt_id("/", &root_id);
178 	if (ret != NSID_PASS)
179 		return ret;
180 
181 	ret = statmount(root_id, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0);
182 	if (ret == -1) {
183 		ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
184 		return NSID_ERROR;
185 	}
186 
187 	if (sm.size != sizeof(sm)) {
188 		ksft_print_msg("unexpected size: %u != %u\n", sm.size,
189 			       (uint32_t)sizeof(sm));
190 		return NSID_FAIL;
191 	}
192 	if (sm.mask != STATMOUNT_MNT_NS_ID) {
193 		ksft_print_msg("statmount mnt ns id unavailable\n");
194 		return NSID_SKIP;
195 	}
196 
197 	if (sm.mnt_ns_id != mnt_ns_id) {
198 		ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n",
199 			       (unsigned long long)sm.mnt_ns_id,
200 			       (unsigned long long)mnt_ns_id);
201 		return NSID_FAIL;
202 	}
203 
204 	return NSID_PASS;
205 }
206 
207 static void test_statmount_mnt_ns_id(void)
208 {
209 	pid_t pid;
210 	int ret;
211 
212 	pid = fork();
213 	if (pid < 0)
214 		ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
215 
216 	/* We're the original pid, wait for the result. */
217 	if (pid != 0) {
218 		ret = wait_for_pid(pid);
219 		handle_result(ret, "test statmount ns id");
220 		return;
221 	}
222 
223 	ret = setup_namespace();
224 	if (ret != NSID_PASS)
225 		exit(ret);
226 	ret = _test_statmount_mnt_ns_id();
227 	exit(ret);
228 }
229 
230 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts)
231 {
232 	uint64_t list[256];
233 	uint64_t mnt_ns_id;
234 	uint64_t nr_mounts;
235 	char buf[256];
236 	int ret;
237 
238 	/* Get the mount ns id for our child. */
239 	snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid);
240 	ret = get_mnt_ns_id(buf, &mnt_ns_id);
241 
242 	nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0);
243 	if (nr_mounts == (uint64_t)-1) {
244 		ksft_print_msg("listmount: %s\n", strerror(errno));
245 		return NSID_ERROR;
246 	}
247 
248 	if (nr_mounts != child_nr_mounts) {
249 		ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts,
250 			       child_nr_mounts);
251 		return NSID_FAIL;
252 	}
253 
254 	/* Validate that all of our entries match our mnt_ns_id. */
255 	for (int i = 0; i < nr_mounts; i++) {
256 		struct statmount sm;
257 
258 		ret = statmount(list[i], mnt_ns_id, STATMOUNT_MNT_NS_ID, &sm,
259 				sizeof(sm), 0);
260 		if (ret < 0) {
261 			ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
262 			return NSID_ERROR;
263 		}
264 
265 		if (sm.mask != STATMOUNT_MNT_NS_ID) {
266 			ksft_print_msg("statmount mnt ns id unavailable\n");
267 			return NSID_SKIP;
268 		}
269 
270 		if (sm.mnt_ns_id != mnt_ns_id) {
271 			ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n",
272 				       (unsigned long long)sm.mnt_ns_id,
273 				       (unsigned long long)mnt_ns_id);
274 			return NSID_FAIL;
275 		}
276 	}
277 
278 	return NSID_PASS;
279 }
280 
281 static void test_listmount_ns(void)
282 {
283 	uint64_t nr_mounts;
284 	char pval;
285 	int child_ready_pipe[2];
286 	int parent_ready_pipe[2];
287 	pid_t pid;
288 	int ret, child_ret;
289 
290 	if (pipe(child_ready_pipe) < 0)
291 		ksft_exit_fail_msg("failed to create the child pipe: %s\n",
292 				   strerror(errno));
293 	if (pipe(parent_ready_pipe) < 0)
294 		ksft_exit_fail_msg("failed to create the parent pipe: %s\n",
295 				   strerror(errno));
296 
297 	pid = fork();
298 	if (pid < 0)
299 		ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
300 
301 	if (pid == 0) {
302 		char cval;
303 		uint64_t list[256];
304 
305 		close(child_ready_pipe[0]);
306 		close(parent_ready_pipe[1]);
307 
308 		ret = setup_namespace();
309 		if (ret != NSID_PASS)
310 			exit(ret);
311 
312 		nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0);
313 		if (nr_mounts == (uint64_t)-1) {
314 			ksft_print_msg("listmount: %s\n", strerror(errno));
315 			exit(NSID_FAIL);
316 		}
317 
318 		/*
319 		 * Tell our parent how many mounts we have, and then wait for it
320 		 * to tell us we're done.
321 		 */
322 		write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts));
323 		read(parent_ready_pipe[0], &cval, sizeof(cval));
324 		exit(NSID_PASS);
325 	}
326 
327 	close(child_ready_pipe[1]);
328 	close(parent_ready_pipe[0]);
329 
330 	/* Wait until the child has created everything. */
331 	if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) !=
332 	    sizeof(nr_mounts))
333 		ret = NSID_ERROR;
334 
335 	ret = validate_external_listmount(pid, nr_mounts);
336 
337 	if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval))
338 		ret = NSID_ERROR;
339 
340 	child_ret = wait_for_pid(pid);
341 	if (child_ret != NSID_PASS)
342 		ret = child_ret;
343 	handle_result(ret, "test listmount ns id");
344 }
345 
/*
 * Entry point: probe for statmount() support, then run both namespace tests
 * and exit with the aggregated kselftest status.
 */
int main(void)
{
	int ret;

	ksft_print_header();

	/*
	 * Probe call with all-zero arguments: it is expected to fail, and
	 * ENOSYS means the syscall does not exist on this kernel.
	 */
	ret = statmount(0, 0, 0, NULL, 0, 0);
	assert(ret == -1);
	if (errno == ENOSYS)
		ksft_exit_skip("statmount() syscall not supported\n");

	ksft_set_plan(2);
	test_statmount_mnt_ns_id();
	test_listmount_ns();

	if (ksft_get_fail_cnt() + ksft_get_error_cnt() == 0)
		ksft_exit_pass();
	else
		ksft_exit_fail();
}
365