xref: /linux/tools/testing/selftests/filesystems/statmount/statmount_test_ns.c (revision 23b0f90ba871f096474e1c27c3d14f455189d2d9)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 
3 #define _GNU_SOURCE
4 
5 #include <assert.h>
6 #include <fcntl.h>
7 #include <limits.h>
8 #include <sched.h>
9 #include <stdlib.h>
10 #include <sys/mount.h>
11 #include <sys/stat.h>
12 #include <sys/wait.h>
13 #include <linux/nsfs.h>
14 #include <linux/stat.h>
15 
16 #include "statmount.h"
17 #include "../utils.h"
18 #include "kselftest.h"
19 
/*
 * Result codes used by every helper in this file. They are also used as
 * child-process exit statuses (passed to exit() and read back through
 * wait_for_pid()), so the values are pinned explicitly.
 */
enum {
	NSID_PASS = 0,
	NSID_FAIL = 1,
	NSID_SKIP = 2,
	NSID_ERROR = 3,
};
24 
25 static void handle_result(int ret, const char *testname)
26 {
27 	if (ret == NSID_PASS)
28 		ksft_test_result_pass("%s\n", testname);
29 	else if (ret == NSID_FAIL)
30 		ksft_test_result_fail("%s\n", testname);
31 	else if (ret == NSID_ERROR)
32 		ksft_exit_fail_msg("%s\n", testname);
33 	else
34 		ksft_test_result_skip("%s\n", testname);
35 }
36 
/*
 * Reap @pid and return its exit status.
 *
 * waitpid() is retried while it is interrupted by a signal (EINTR).
 * Returns WEXITSTATUS() when the child exited normally, or -1 when
 * waitpid() failed or the child did not exit normally.
 */
static inline int wait_for_pid(pid_t pid)
{
	int wstatus;
	pid_t reaped;

	do {
		reaped = waitpid(pid, &wstatus, 0);
	} while (reaped == -1 && errno == EINTR);

	if (reaped == -1) {
		ksft_print_msg("waitpid returned -1, errno=%d\n", errno);
		return -1;
	}

	if (WIFEXITED(wstatus))
		return WEXITSTATUS(wstatus);

	ksft_print_msg(
	       "waitpid !WIFEXITED, WIFSIGNALED=%d, WTERMSIG=%d\n",
	       WIFSIGNALED(wstatus), WTERMSIG(wstatus));
	return -1;
}
61 
62 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id)
63 {
64 	int fd = open(mnt_ns, O_RDONLY);
65 
66 	if (fd < 0) {
67 		ksft_print_msg("failed to open for ns %s: %s\n",
68 			       mnt_ns, strerror(errno));
69 		sleep(60);
70 		return NSID_ERROR;
71 	}
72 
73 	if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) {
74 		ksft_print_msg("failed to get the nsid for ns %s: %s\n",
75 			       mnt_ns, strerror(errno));
76 		return NSID_ERROR;
77 	}
78 	close(fd);
79 	return NSID_PASS;
80 }
81 
82 static int setup_namespace(void)
83 {
84 	if (setup_userns() != 0)
85 		return NSID_ERROR;
86 
87 	return NSID_PASS;
88 }
89 
90 static int _test_statmount_mnt_ns_id(void)
91 {
92 	struct statmount sm;
93 	uint64_t mnt_ns_id;
94 	uint64_t root_id;
95 	int ret;
96 
97 	ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id);
98 	if (ret != NSID_PASS)
99 		return ret;
100 
101 	root_id = get_unique_mnt_id("/");
102 	if (!root_id)
103 		return NSID_ERROR;
104 
105 	ret = statmount(root_id, 0, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0);
106 	if (ret == -1) {
107 		ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
108 		return NSID_ERROR;
109 	}
110 
111 	if (sm.size != sizeof(sm)) {
112 		ksft_print_msg("unexpected size: %u != %u\n", sm.size,
113 			       (uint32_t)sizeof(sm));
114 		return NSID_FAIL;
115 	}
116 	if (sm.mask != STATMOUNT_MNT_NS_ID) {
117 		ksft_print_msg("statmount mnt ns id unavailable\n");
118 		return NSID_SKIP;
119 	}
120 
121 	if (sm.mnt_ns_id != mnt_ns_id) {
122 		ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n",
123 			       (unsigned long long)sm.mnt_ns_id,
124 			       (unsigned long long)mnt_ns_id);
125 		return NSID_FAIL;
126 	}
127 
128 	return NSID_PASS;
129 }
130 
131 static int _test_statmount_mnt_ns_id_by_fd(void)
132 {
133 	struct statmount sm;
134 	uint64_t mnt_ns_id;
135 	int ret, fd, mounted = 1, status = NSID_ERROR;
136 	char mnt[] = "/statmount.fd.XXXXXX";
137 
138 	ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id);
139 	if (ret != NSID_PASS)
140 		return ret;
141 
142 	if (!mkdtemp(mnt)) {
143 		ksft_print_msg("statmount by fd mnt ns id mkdtemp: %s\n", strerror(errno));
144 		return NSID_ERROR;
145 	}
146 
147 	if (mount(mnt, mnt, NULL, MS_BIND, 0)) {
148 		ksft_print_msg("statmount by fd mnt ns id mount: %s\n", strerror(errno));
149 		status = NSID_ERROR;
150 		goto err;
151 	}
152 
153 	fd = open(mnt, O_PATH);
154 	if (fd < 0) {
155 		ksft_print_msg("statmount by fd mnt ns id open: %s\n", strerror(errno));
156 		goto err;
157 	}
158 
159 	ret = statmount(0, 0, fd, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), STATMOUNT_BY_FD);
160 	if (ret == -1) {
161 		ksft_print_msg("statmount mnt ns id statmount: %s\n", strerror(errno));
162 		status = NSID_ERROR;
163 		goto out;
164 	}
165 
166 	if (sm.size != sizeof(sm)) {
167 		ksft_print_msg("unexpected size: %u != %u\n", sm.size,
168 			       (uint32_t)sizeof(sm));
169 		status = NSID_FAIL;
170 		goto out;
171 	}
172 	if (sm.mask != STATMOUNT_MNT_NS_ID) {
173 		ksft_print_msg("statmount mnt ns id unavailable\n");
174 		status = NSID_SKIP;
175 		goto out;
176 	}
177 
178 	if (sm.mnt_ns_id != mnt_ns_id) {
179 		ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n",
180 			       (unsigned long long)sm.mnt_ns_id,
181 			       (unsigned long long)mnt_ns_id);
182 		status = NSID_FAIL;
183 		goto out;
184 	}
185 
186 	mounted = 0;
187 	if (umount2(mnt, MNT_DETACH)) {
188 		ksft_print_msg("statmount by fd mnt ns id umount2: %s\n", strerror(errno));
189 		goto out;
190 	}
191 
192 	ret = statmount(0, 0, fd, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), STATMOUNT_BY_FD);
193 	if (ret == -1) {
194 		ksft_print_msg("statmount mnt ns id statmount: %s\n", strerror(errno));
195 		status = NSID_ERROR;
196 		goto out;
197 	}
198 
199 	if (sm.size != sizeof(sm)) {
200 		ksft_print_msg("unexpected size: %u != %u\n", sm.size,
201 			       (uint32_t)sizeof(sm));
202 		status = NSID_FAIL;
203 		goto out;
204 	}
205 
206 	if (sm.mask == STATMOUNT_MNT_NS_ID) {
207 		ksft_print_msg("unexpected STATMOUNT_MNT_NS_ID in mask\n");
208 		status = NSID_FAIL;
209 		goto out;
210 	}
211 
212 	status = NSID_PASS;
213 out:
214 	close(fd);
215 	if (mounted)
216 		umount2(mnt, MNT_DETACH);
217 err:
218 	rmdir(mnt);
219 	return status;
220 }
221 
222 
223 static void test_statmount_mnt_ns_id(void)
224 {
225 	pid_t pid;
226 	int ret;
227 
228 	pid = fork();
229 	if (pid < 0)
230 		ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
231 
232 	/* We're the original pid, wait for the result. */
233 	if (pid != 0) {
234 		ret = wait_for_pid(pid);
235 		handle_result(ret, "test statmount ns id");
236 		return;
237 	}
238 
239 	ret = setup_namespace();
240 	if (ret != NSID_PASS)
241 		exit(ret);
242 	ret = _test_statmount_mnt_ns_id();
243 	if (ret != NSID_PASS)
244 		exit(ret);
245 	ret = _test_statmount_mnt_ns_id_by_fd();
246 	exit(ret);
247 }
248 
249 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts)
250 {
251 	uint64_t list[256];
252 	uint64_t mnt_ns_id;
253 	uint64_t nr_mounts;
254 	char buf[256];
255 	int ret;
256 
257 	/* Get the mount ns id for our child. */
258 	snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid);
259 	ret = get_mnt_ns_id(buf, &mnt_ns_id);
260 
261 	nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0);
262 	if (nr_mounts == (uint64_t)-1) {
263 		ksft_print_msg("listmount: %s\n", strerror(errno));
264 		return NSID_ERROR;
265 	}
266 
267 	if (nr_mounts != child_nr_mounts) {
268 		ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts,
269 			       child_nr_mounts);
270 		return NSID_FAIL;
271 	}
272 
273 	/* Validate that all of our entries match our mnt_ns_id. */
274 	for (int i = 0; i < nr_mounts; i++) {
275 		struct statmount sm;
276 
277 		ret = statmount(list[i], mnt_ns_id, 0, STATMOUNT_MNT_NS_ID, &sm,
278 				sizeof(sm), 0);
279 		if (ret < 0) {
280 			ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
281 			return NSID_ERROR;
282 		}
283 
284 		if (sm.mask != STATMOUNT_MNT_NS_ID) {
285 			ksft_print_msg("statmount mnt ns id unavailable\n");
286 			return NSID_SKIP;
287 		}
288 
289 		if (sm.mnt_ns_id != mnt_ns_id) {
290 			ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n",
291 				       (unsigned long long)sm.mnt_ns_id,
292 				       (unsigned long long)mnt_ns_id);
293 			return NSID_FAIL;
294 		}
295 	}
296 
297 	return NSID_PASS;
298 }
299 
300 static void test_listmount_ns(void)
301 {
302 	uint64_t nr_mounts;
303 	char pval;
304 	int child_ready_pipe[2];
305 	int parent_ready_pipe[2];
306 	pid_t pid;
307 	int ret, child_ret;
308 
309 	if (pipe(child_ready_pipe) < 0)
310 		ksft_exit_fail_msg("failed to create the child pipe: %s\n",
311 				   strerror(errno));
312 	if (pipe(parent_ready_pipe) < 0)
313 		ksft_exit_fail_msg("failed to create the parent pipe: %s\n",
314 				   strerror(errno));
315 
316 	pid = fork();
317 	if (pid < 0)
318 		ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
319 
320 	if (pid == 0) {
321 		char cval;
322 		uint64_t list[256];
323 
324 		close(child_ready_pipe[0]);
325 		close(parent_ready_pipe[1]);
326 
327 		ret = setup_namespace();
328 		if (ret != NSID_PASS)
329 			exit(ret);
330 
331 		nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0);
332 		if (nr_mounts == (uint64_t)-1) {
333 			ksft_print_msg("listmount: %s\n", strerror(errno));
334 			exit(NSID_FAIL);
335 		}
336 
337 		/*
338 		 * Tell our parent how many mounts we have, and then wait for it
339 		 * to tell us we're done.
340 		 */
341 		if (write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts)) !=
342 					sizeof(nr_mounts))
343 			ret = NSID_ERROR;
344 		if (read(parent_ready_pipe[0], &cval, sizeof(cval)) != sizeof(cval))
345 			ret = NSID_ERROR;
346 		exit(NSID_PASS);
347 	}
348 
349 	close(child_ready_pipe[1]);
350 	close(parent_ready_pipe[0]);
351 
352 	/* Wait until the child has created everything. */
353 	if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) !=
354 	    sizeof(nr_mounts))
355 		ret = NSID_ERROR;
356 
357 	ret = validate_external_listmount(pid, nr_mounts);
358 
359 	if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval))
360 		ret = NSID_ERROR;
361 
362 	child_ret = wait_for_pid(pid);
363 	if (child_ret != NSID_PASS)
364 		ret = child_ret;
365 	handle_result(ret, "test listmount ns id");
366 }
367 
/*
 * Entry point: skip everything when the statmount() syscall is missing,
 * otherwise run both namespace tests and exit with the aggregate result.
 */
int main(void)
{
	unsigned int problems;
	int ret;

	ksft_print_header();

	/*
	 * Probe with all-zero arguments and no buffer: the call is expected
	 * to fail, and ENOSYS means the syscall is not implemented at all.
	 */
	ret = statmount(0, 0, 0, 0, NULL, 0, 0);
	assert(ret == -1);
	if (errno == ENOSYS)
		ksft_exit_skip("statmount() syscall not supported\n");

	ksft_set_plan(2);
	test_statmount_mnt_ns_id();
	test_listmount_ns();

	problems = ksft_get_fail_cnt() + ksft_get_error_cnt();
	if (problems > 0)
		ksft_exit_fail();
	else
		ksft_exit_pass();
}
387