xref: /linux/tools/testing/selftests/filesystems/statmount/statmount_test_ns.c (revision 7a5f1cd22d47f8ca4b760b6334378ae42c1bd24b)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 
3 #define _GNU_SOURCE
4 
5 #include <assert.h>
6 #include <fcntl.h>
7 #include <limits.h>
8 #include <sched.h>
9 #include <stdlib.h>
10 #include <sys/mount.h>
11 #include <sys/stat.h>
12 #include <sys/wait.h>
13 #include <linux/nsfs.h>
14 #include <linux/stat.h>
15 
16 #include "statmount.h"
17 #include "../utils.h"
18 #include "kselftest.h"
19 
/*
 * Result codes shared by the helpers in this file. Child processes also use
 * them as their exit status, and handle_result() maps them to kselftest
 * results.
 */
enum {
	NSID_PASS  = 0,	/* check succeeded */
	NSID_FAIL  = 1,	/* check ran and produced a wrong result */
	NSID_SKIP  = 2,	/* feature under test is unavailable */
	NSID_ERROR = 3,	/* the test itself could not be carried out */
};
24 
25 static void handle_result(int ret, const char *testname)
26 {
27 	if (ret == NSID_PASS)
28 		ksft_test_result_pass("%s\n", testname);
29 	else if (ret == NSID_FAIL)
30 		ksft_test_result_fail("%s\n", testname);
31 	else if (ret == NSID_ERROR)
32 		ksft_exit_fail_msg("%s\n", testname);
33 	else
34 		ksft_test_result_skip("%s\n", testname);
35 }
36 
37 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id)
38 {
39 	int fd = open(mnt_ns, O_RDONLY);
40 
41 	if (fd < 0) {
42 		ksft_print_msg("failed to open for ns %s: %s\n",
43 			       mnt_ns, strerror(errno));
44 		sleep(60);
45 		return NSID_ERROR;
46 	}
47 
48 	if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) {
49 		ksft_print_msg("failed to get the nsid for ns %s: %s\n",
50 			       mnt_ns, strerror(errno));
51 		return NSID_ERROR;
52 	}
53 	close(fd);
54 	return NSID_PASS;
55 }
56 
57 static int setup_namespace(void)
58 {
59 	if (setup_userns() != 0)
60 		return NSID_ERROR;
61 
62 	return NSID_PASS;
63 }
64 
65 static int _test_statmount_mnt_ns_id(void)
66 {
67 	struct statmount sm;
68 	uint64_t mnt_ns_id;
69 	uint64_t root_id;
70 	int ret;
71 
72 	ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id);
73 	if (ret != NSID_PASS)
74 		return ret;
75 
76 	root_id = get_unique_mnt_id("/");
77 	if (!root_id)
78 		return NSID_ERROR;
79 
80 	ret = statmount(root_id, 0, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0);
81 	if (ret == -1) {
82 		ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
83 		return NSID_ERROR;
84 	}
85 
86 	if (sm.size != sizeof(sm)) {
87 		ksft_print_msg("unexpected size: %u != %u\n", sm.size,
88 			       (uint32_t)sizeof(sm));
89 		return NSID_FAIL;
90 	}
91 	if (sm.mask != STATMOUNT_MNT_NS_ID) {
92 		ksft_print_msg("statmount mnt ns id unavailable\n");
93 		return NSID_SKIP;
94 	}
95 
96 	if (sm.mnt_ns_id != mnt_ns_id) {
97 		ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n",
98 			       (unsigned long long)sm.mnt_ns_id,
99 			       (unsigned long long)mnt_ns_id);
100 		return NSID_FAIL;
101 	}
102 
103 	return NSID_PASS;
104 }
105 
106 static int _test_statmount_mnt_ns_id_by_fd(void)
107 {
108 	struct statmount sm;
109 	uint64_t mnt_ns_id;
110 	int ret, fd, mounted = 1, status = NSID_ERROR;
111 	char mnt[] = "/statmount.fd.XXXXXX";
112 
113 	ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id);
114 	if (ret != NSID_PASS)
115 		return ret;
116 
117 	if (!mkdtemp(mnt)) {
118 		ksft_print_msg("statmount by fd mnt ns id mkdtemp: %s\n", strerror(errno));
119 		return NSID_ERROR;
120 	}
121 
122 	if (mount(mnt, mnt, NULL, MS_BIND, 0)) {
123 		ksft_print_msg("statmount by fd mnt ns id mount: %s\n", strerror(errno));
124 		status = NSID_ERROR;
125 		goto err;
126 	}
127 
128 	fd = open(mnt, O_PATH);
129 	if (fd < 0) {
130 		ksft_print_msg("statmount by fd mnt ns id open: %s\n", strerror(errno));
131 		goto err;
132 	}
133 
134 	ret = statmount(0, 0, fd, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), STATMOUNT_BY_FD);
135 	if (ret == -1) {
136 		ksft_print_msg("statmount mnt ns id statmount: %s\n", strerror(errno));
137 		status = NSID_ERROR;
138 		goto out;
139 	}
140 
141 	if (sm.size != sizeof(sm)) {
142 		ksft_print_msg("unexpected size: %u != %u\n", sm.size,
143 			       (uint32_t)sizeof(sm));
144 		status = NSID_FAIL;
145 		goto out;
146 	}
147 	if (sm.mask != STATMOUNT_MNT_NS_ID) {
148 		ksft_print_msg("statmount mnt ns id unavailable\n");
149 		status = NSID_SKIP;
150 		goto out;
151 	}
152 
153 	if (sm.mnt_ns_id != mnt_ns_id) {
154 		ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n",
155 			       (unsigned long long)sm.mnt_ns_id,
156 			       (unsigned long long)mnt_ns_id);
157 		status = NSID_FAIL;
158 		goto out;
159 	}
160 
161 	mounted = 0;
162 	if (umount2(mnt, MNT_DETACH)) {
163 		ksft_print_msg("statmount by fd mnt ns id umount2: %s\n", strerror(errno));
164 		goto out;
165 	}
166 
167 	ret = statmount(0, 0, fd, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), STATMOUNT_BY_FD);
168 	if (ret == -1) {
169 		ksft_print_msg("statmount mnt ns id statmount: %s\n", strerror(errno));
170 		status = NSID_ERROR;
171 		goto out;
172 	}
173 
174 	if (sm.size != sizeof(sm)) {
175 		ksft_print_msg("unexpected size: %u != %u\n", sm.size,
176 			       (uint32_t)sizeof(sm));
177 		status = NSID_FAIL;
178 		goto out;
179 	}
180 
181 	if (sm.mask == STATMOUNT_MNT_NS_ID) {
182 		ksft_print_msg("unexpected STATMOUNT_MNT_NS_ID in mask\n");
183 		status = NSID_FAIL;
184 		goto out;
185 	}
186 
187 	status = NSID_PASS;
188 out:
189 	close(fd);
190 	if (mounted)
191 		umount2(mnt, MNT_DETACH);
192 err:
193 	rmdir(mnt);
194 	return status;
195 }
196 
197 
198 static void test_statmount_mnt_ns_id(void)
199 {
200 	pid_t pid;
201 	int ret;
202 
203 	pid = fork();
204 	if (pid < 0)
205 		ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
206 
207 	/* We're the original pid, wait for the result. */
208 	if (pid != 0) {
209 		ret = wait_for_pid(pid);
210 		handle_result(ret, "test statmount ns id");
211 		return;
212 	}
213 
214 	ret = setup_namespace();
215 	if (ret != NSID_PASS)
216 		exit(ret);
217 	ret = _test_statmount_mnt_ns_id();
218 	if (ret != NSID_PASS)
219 		exit(ret);
220 	ret = _test_statmount_mnt_ns_id_by_fd();
221 	exit(ret);
222 }
223 
224 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts)
225 {
226 	uint64_t list[256];
227 	uint64_t mnt_ns_id;
228 	uint64_t nr_mounts;
229 	char buf[256];
230 	int ret;
231 
232 	/* Get the mount ns id for our child. */
233 	snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid);
234 	ret = get_mnt_ns_id(buf, &mnt_ns_id);
235 
236 	nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0);
237 	if (nr_mounts == (uint64_t)-1) {
238 		ksft_print_msg("listmount: %s\n", strerror(errno));
239 		return NSID_ERROR;
240 	}
241 
242 	if (nr_mounts != child_nr_mounts) {
243 		ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts,
244 			       child_nr_mounts);
245 		return NSID_FAIL;
246 	}
247 
248 	/* Validate that all of our entries match our mnt_ns_id. */
249 	for (int i = 0; i < nr_mounts; i++) {
250 		struct statmount sm;
251 
252 		ret = statmount(list[i], mnt_ns_id, 0, STATMOUNT_MNT_NS_ID, &sm,
253 				sizeof(sm), 0);
254 		if (ret < 0) {
255 			ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
256 			return NSID_ERROR;
257 		}
258 
259 		if (sm.mask != STATMOUNT_MNT_NS_ID) {
260 			ksft_print_msg("statmount mnt ns id unavailable\n");
261 			return NSID_SKIP;
262 		}
263 
264 		if (sm.mnt_ns_id != mnt_ns_id) {
265 			ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n",
266 				       (unsigned long long)sm.mnt_ns_id,
267 				       (unsigned long long)mnt_ns_id);
268 			return NSID_FAIL;
269 		}
270 	}
271 
272 	return NSID_PASS;
273 }
274 
275 static void test_listmount_ns(void)
276 {
277 	uint64_t nr_mounts;
278 	char pval;
279 	int child_ready_pipe[2];
280 	int parent_ready_pipe[2];
281 	pid_t pid;
282 	int ret, child_ret;
283 
284 	if (pipe(child_ready_pipe) < 0)
285 		ksft_exit_fail_msg("failed to create the child pipe: %s\n",
286 				   strerror(errno));
287 	if (pipe(parent_ready_pipe) < 0)
288 		ksft_exit_fail_msg("failed to create the parent pipe: %s\n",
289 				   strerror(errno));
290 
291 	pid = fork();
292 	if (pid < 0)
293 		ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
294 
295 	if (pid == 0) {
296 		char cval;
297 		uint64_t list[256];
298 
299 		close(child_ready_pipe[0]);
300 		close(parent_ready_pipe[1]);
301 
302 		ret = setup_namespace();
303 		if (ret != NSID_PASS)
304 			exit(ret);
305 
306 		nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0);
307 		if (nr_mounts == (uint64_t)-1) {
308 			ksft_print_msg("listmount: %s\n", strerror(errno));
309 			exit(NSID_FAIL);
310 		}
311 
312 		/*
313 		 * Tell our parent how many mounts we have, and then wait for it
314 		 * to tell us we're done.
315 		 */
316 		if (write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts)) !=
317 					sizeof(nr_mounts))
318 			ret = NSID_ERROR;
319 		if (read(parent_ready_pipe[0], &cval, sizeof(cval)) != sizeof(cval))
320 			ret = NSID_ERROR;
321 		exit(NSID_PASS);
322 	}
323 
324 	close(child_ready_pipe[1]);
325 	close(parent_ready_pipe[0]);
326 
327 	/* Wait until the child has created everything. */
328 	if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) !=
329 	    sizeof(nr_mounts))
330 		ret = NSID_ERROR;
331 
332 	ret = validate_external_listmount(pid, nr_mounts);
333 
334 	if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval))
335 		ret = NSID_ERROR;
336 
337 	child_ret = wait_for_pid(pid);
338 	if (child_ret != NSID_PASS)
339 		ret = child_ret;
340 	handle_result(ret, "test listmount ns id");
341 }
342 
/*
 * Test entry point: probe for statmount() support, then run the two planned
 * test cases and exit through the kselftest pass/fail helpers.
 */
int main(void)
{
	int probe;

	ksft_print_header();

	/*
	 * A call with all-zero arguments is expected to fail; ENOSYS in
	 * particular means the syscall is missing, so the run is skipped.
	 */
	probe = statmount(0, 0, 0, 0, NULL, 0, 0);
	assert(probe == -1);
	if (errno == ENOSYS)
		ksft_exit_skip("statmount() syscall not supported\n");

	ksft_set_plan(2);
	test_statmount_mnt_ns_id();
	test_listmount_ns();

	if (ksft_get_fail_cnt() + ksft_get_error_cnt() == 0)
		ksft_exit_pass();
	else
		ksft_exit_fail();
}
362