1 // SPDX-License-Identifier: GPL-2.0-or-later
2
3 #define _GNU_SOURCE
4
5 #include <assert.h>
6 #include <fcntl.h>
7 #include <limits.h>
8 #include <sched.h>
9 #include <stdlib.h>
10 #include <sys/mount.h>
11 #include <sys/stat.h>
12 #include <sys/wait.h>
13 #include <linux/nsfs.h>
14 #include <linux/stat.h>
15
16 #include "statmount.h"
17 #include "../../kselftest.h"
18
/*
 * Result codes returned by the test helpers. Forked children also use them
 * as their exit status, which the parent passes to handle_result().
 */
enum {
	NSID_PASS = 0,
	NSID_FAIL = 1,
	NSID_SKIP = 2,
	NSID_ERROR = 3,
};
23
handle_result(int ret,const char * testname)24 static void handle_result(int ret, const char *testname)
25 {
26 if (ret == NSID_PASS)
27 ksft_test_result_pass("%s\n", testname);
28 else if (ret == NSID_FAIL)
29 ksft_test_result_fail("%s\n", testname);
30 else if (ret == NSID_ERROR)
31 ksft_exit_fail_msg("%s\n", testname);
32 else
33 ksft_test_result_skip("%s\n", testname);
34 }
35
/*
 * Reap @pid and return its exit status.
 *
 * waitpid() is retried while it fails with EINTR. Returns -1 (after printing
 * a diagnostic) if waitpid() fails for any other reason or if the child did
 * not terminate through exit(). Otherwise returns WEXITSTATUS() of the child.
 */
static inline int wait_for_pid(pid_t pid)
{
	int wstatus;
	int rc;

	do {
		rc = waitpid(pid, &wstatus, 0);
	} while (rc == -1 && errno == EINTR);

	if (rc == -1) {
		ksft_print_msg("waitpid returned -1, errno=%d\n", errno);
		return -1;
	}

	if (WIFEXITED(wstatus))
		return WEXITSTATUS(wstatus);

	ksft_print_msg(
		"waitpid !WIFEXITED, WIFSIGNALED=%d, WTERMSIG=%d\n",
		WIFSIGNALED(wstatus), WTERMSIG(wstatus));
	return -1;
}
60
get_mnt_ns_id(const char * mnt_ns,uint64_t * mnt_ns_id)61 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id)
62 {
63 int fd = open(mnt_ns, O_RDONLY);
64
65 if (fd < 0) {
66 ksft_print_msg("failed to open for ns %s: %s\n",
67 mnt_ns, strerror(errno));
68 sleep(60);
69 return NSID_ERROR;
70 }
71
72 if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) {
73 ksft_print_msg("failed to get the nsid for ns %s: %s\n",
74 mnt_ns, strerror(errno));
75 return NSID_ERROR;
76 }
77 close(fd);
78 return NSID_PASS;
79 }
80
get_mnt_id(const char * path,uint64_t * mnt_id)81 static int get_mnt_id(const char *path, uint64_t *mnt_id)
82 {
83 struct statx sx;
84 int ret;
85
86 ret = statx(AT_FDCWD, path, 0, STATX_MNT_ID_UNIQUE, &sx);
87 if (ret == -1) {
88 ksft_print_msg("retrieving unique mount ID for %s: %s\n", path,
89 strerror(errno));
90 return NSID_ERROR;
91 }
92
93 if (!(sx.stx_mask & STATX_MNT_ID_UNIQUE)) {
94 ksft_print_msg("no unique mount ID available for %s\n", path);
95 return NSID_ERROR;
96 }
97
98 *mnt_id = sx.stx_mnt_id;
99 return NSID_PASS;
100 }
101
write_file(const char * path,const char * val)102 static int write_file(const char *path, const char *val)
103 {
104 int fd = open(path, O_WRONLY);
105 size_t len = strlen(val);
106 int ret;
107
108 if (fd == -1) {
109 ksft_print_msg("opening %s for write: %s\n", path, strerror(errno));
110 return NSID_ERROR;
111 }
112
113 ret = write(fd, val, len);
114 if (ret == -1) {
115 ksft_print_msg("writing to %s: %s\n", path, strerror(errno));
116 return NSID_ERROR;
117 }
118 if (ret != len) {
119 ksft_print_msg("short write to %s\n", path);
120 return NSID_ERROR;
121 }
122
123 ret = close(fd);
124 if (ret == -1) {
125 ksft_print_msg("closing %s\n", path);
126 return NSID_ERROR;
127 }
128
129 return NSID_PASS;
130 }
131
setup_namespace(void)132 static int setup_namespace(void)
133 {
134 int ret;
135 char buf[32];
136 uid_t uid = getuid();
137 gid_t gid = getgid();
138
139 ret = unshare(CLONE_NEWNS|CLONE_NEWUSER|CLONE_NEWPID);
140 if (ret == -1)
141 ksft_exit_fail_msg("unsharing mountns and userns: %s\n",
142 strerror(errno));
143
144 sprintf(buf, "0 %d 1", uid);
145 ret = write_file("/proc/self/uid_map", buf);
146 if (ret != NSID_PASS)
147 return ret;
148 ret = write_file("/proc/self/setgroups", "deny");
149 if (ret != NSID_PASS)
150 return ret;
151 sprintf(buf, "0 %d 1", gid);
152 ret = write_file("/proc/self/gid_map", buf);
153 if (ret != NSID_PASS)
154 return ret;
155
156 ret = mount("", "/", NULL, MS_REC|MS_PRIVATE, NULL);
157 if (ret == -1) {
158 ksft_print_msg("making mount tree private: %s\n",
159 strerror(errno));
160 return NSID_ERROR;
161 }
162
163 return NSID_PASS;
164 }
165
_test_statmount_mnt_ns_id(void)166 static int _test_statmount_mnt_ns_id(void)
167 {
168 struct statmount sm;
169 uint64_t mnt_ns_id;
170 uint64_t root_id;
171 int ret;
172
173 ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id);
174 if (ret != NSID_PASS)
175 return ret;
176
177 ret = get_mnt_id("/", &root_id);
178 if (ret != NSID_PASS)
179 return ret;
180
181 ret = statmount(root_id, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0);
182 if (ret == -1) {
183 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
184 return NSID_ERROR;
185 }
186
187 if (sm.size != sizeof(sm)) {
188 ksft_print_msg("unexpected size: %u != %u\n", sm.size,
189 (uint32_t)sizeof(sm));
190 return NSID_FAIL;
191 }
192 if (sm.mask != STATMOUNT_MNT_NS_ID) {
193 ksft_print_msg("statmount mnt ns id unavailable\n");
194 return NSID_SKIP;
195 }
196
197 if (sm.mnt_ns_id != mnt_ns_id) {
198 ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n",
199 (unsigned long long)sm.mnt_ns_id,
200 (unsigned long long)mnt_ns_id);
201 return NSID_FAIL;
202 }
203
204 return NSID_PASS;
205 }
206
test_statmount_mnt_ns_id(void)207 static void test_statmount_mnt_ns_id(void)
208 {
209 pid_t pid;
210 int ret;
211
212 pid = fork();
213 if (pid < 0)
214 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
215
216 /* We're the original pid, wait for the result. */
217 if (pid != 0) {
218 ret = wait_for_pid(pid);
219 handle_result(ret, "test statmount ns id");
220 return;
221 }
222
223 ret = setup_namespace();
224 if (ret != NSID_PASS)
225 exit(ret);
226 ret = _test_statmount_mnt_ns_id();
227 exit(ret);
228 }
229
validate_external_listmount(pid_t pid,uint64_t child_nr_mounts)230 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts)
231 {
232 uint64_t list[256];
233 uint64_t mnt_ns_id;
234 uint64_t nr_mounts;
235 char buf[256];
236 int ret;
237
238 /* Get the mount ns id for our child. */
239 snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid);
240 ret = get_mnt_ns_id(buf, &mnt_ns_id);
241
242 nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0);
243 if (nr_mounts == (uint64_t)-1) {
244 ksft_print_msg("listmount: %s\n", strerror(errno));
245 return NSID_ERROR;
246 }
247
248 if (nr_mounts != child_nr_mounts) {
249 ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts,
250 child_nr_mounts);
251 return NSID_FAIL;
252 }
253
254 /* Validate that all of our entries match our mnt_ns_id. */
255 for (int i = 0; i < nr_mounts; i++) {
256 struct statmount sm;
257
258 ret = statmount(list[i], mnt_ns_id, STATMOUNT_MNT_NS_ID, &sm,
259 sizeof(sm), 0);
260 if (ret < 0) {
261 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
262 return NSID_ERROR;
263 }
264
265 if (sm.mask != STATMOUNT_MNT_NS_ID) {
266 ksft_print_msg("statmount mnt ns id unavailable\n");
267 return NSID_SKIP;
268 }
269
270 if (sm.mnt_ns_id != mnt_ns_id) {
271 ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n",
272 (unsigned long long)sm.mnt_ns_id,
273 (unsigned long long)mnt_ns_id);
274 return NSID_FAIL;
275 }
276 }
277
278 return NSID_PASS;
279 }
280
test_listmount_ns(void)281 static void test_listmount_ns(void)
282 {
283 uint64_t nr_mounts;
284 char pval;
285 int child_ready_pipe[2];
286 int parent_ready_pipe[2];
287 pid_t pid;
288 int ret, child_ret;
289
290 if (pipe(child_ready_pipe) < 0)
291 ksft_exit_fail_msg("failed to create the child pipe: %s\n",
292 strerror(errno));
293 if (pipe(parent_ready_pipe) < 0)
294 ksft_exit_fail_msg("failed to create the parent pipe: %s\n",
295 strerror(errno));
296
297 pid = fork();
298 if (pid < 0)
299 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
300
301 if (pid == 0) {
302 char cval;
303 uint64_t list[256];
304
305 close(child_ready_pipe[0]);
306 close(parent_ready_pipe[1]);
307
308 ret = setup_namespace();
309 if (ret != NSID_PASS)
310 exit(ret);
311
312 nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0);
313 if (nr_mounts == (uint64_t)-1) {
314 ksft_print_msg("listmount: %s\n", strerror(errno));
315 exit(NSID_FAIL);
316 }
317
318 /*
319 * Tell our parent how many mounts we have, and then wait for it
320 * to tell us we're done.
321 */
322 if (write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts)) !=
323 sizeof(nr_mounts))
324 ret = NSID_ERROR;
325 if (read(parent_ready_pipe[0], &cval, sizeof(cval)) != sizeof(cval))
326 ret = NSID_ERROR;
327 exit(NSID_PASS);
328 }
329
330 close(child_ready_pipe[1]);
331 close(parent_ready_pipe[0]);
332
333 /* Wait until the child has created everything. */
334 if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) !=
335 sizeof(nr_mounts))
336 ret = NSID_ERROR;
337
338 ret = validate_external_listmount(pid, nr_mounts);
339
340 if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval))
341 ret = NSID_ERROR;
342
343 child_ret = wait_for_pid(pid);
344 if (child_ret != NSID_PASS)
345 ret = child_ret;
346 handle_result(ret, "test listmount ns id");
347 }
348
/*
 * Entry point. If the statmount() probe fails with ENOSYS, the whole run is
 * skipped. Otherwise it runs both namespace tests and exits with the
 * aggregate kselftest status.
 */
int main(void)
{
	int ret;

	ksft_print_header();

	/* Probe call: it is expected to fail; ENOSYS means no syscall at all. */
	ret = statmount(0, 0, 0, NULL, 0, 0);
	assert(ret == -1);
	if (errno == ENOSYS)
		ksft_exit_skip("statmount() syscall not supported\n");

	ksft_set_plan(2);

	test_statmount_mnt_ns_id();
	test_listmount_ns();

	if (ksft_get_fail_cnt() + ksft_get_error_cnt() == 0)
		ksft_exit_pass();
	ksft_exit_fail();
}
368