1 // SPDX-License-Identifier: GPL-2.0-or-later 2 3 #define _GNU_SOURCE 4 5 #include <assert.h> 6 #include <fcntl.h> 7 #include <limits.h> 8 #include <sched.h> 9 #include <stdlib.h> 10 #include <sys/mount.h> 11 #include <sys/stat.h> 12 #include <sys/wait.h> 13 #include <linux/nsfs.h> 14 #include <linux/stat.h> 15 16 #include "statmount.h" 17 #include "../../kselftest.h" 18 19 #define NSID_PASS 0 20 #define NSID_FAIL 1 21 #define NSID_SKIP 2 22 #define NSID_ERROR 3 23 24 static void handle_result(int ret, const char *testname) 25 { 26 if (ret == NSID_PASS) 27 ksft_test_result_pass("%s\n", testname); 28 else if (ret == NSID_FAIL) 29 ksft_test_result_fail("%s\n", testname); 30 else if (ret == NSID_ERROR) 31 ksft_exit_fail_msg("%s\n", testname); 32 else 33 ksft_test_result_skip("%s\n", testname); 34 } 35 36 static inline int wait_for_pid(pid_t pid) 37 { 38 int status, ret; 39 40 again: 41 ret = waitpid(pid, &status, 0); 42 if (ret == -1) { 43 if (errno == EINTR) 44 goto again; 45 46 ksft_print_msg("waitpid returned -1, errno=%d\n", errno); 47 return -1; 48 } 49 50 if (!WIFEXITED(status)) { 51 ksft_print_msg( 52 "waitpid !WIFEXITED, WIFSIGNALED=%d, WTERMSIG=%d\n", 53 WIFSIGNALED(status), WTERMSIG(status)); 54 return -1; 55 } 56 57 ret = WEXITSTATUS(status); 58 return ret; 59 } 60 61 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id) 62 { 63 int fd = open(mnt_ns, O_RDONLY); 64 65 if (fd < 0) { 66 ksft_print_msg("failed to open for ns %s: %s\n", 67 mnt_ns, strerror(errno)); 68 sleep(60); 69 return NSID_ERROR; 70 } 71 72 if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) { 73 ksft_print_msg("failed to get the nsid for ns %s: %s\n", 74 mnt_ns, strerror(errno)); 75 return NSID_ERROR; 76 } 77 close(fd); 78 return NSID_PASS; 79 } 80 81 static int get_mnt_id(const char *path, uint64_t *mnt_id) 82 { 83 struct statx sx; 84 int ret; 85 86 ret = statx(AT_FDCWD, path, 0, STATX_MNT_ID_UNIQUE, &sx); 87 if (ret == -1) { 88 ksft_print_msg("retrieving unique mount ID for %s: %s\n", path, 89 strerror(errno)); 90 return NSID_ERROR; 91 } 92 93 if (!(sx.stx_mask & STATX_MNT_ID_UNIQUE)) { 94 ksft_print_msg("no unique mount ID available for %s\n", path); 95 return NSID_ERROR; 96 } 97 98 *mnt_id = sx.stx_mnt_id; 99 return NSID_PASS; 100 } 101 102 static int write_file(const char *path, const char *val) 103 { 104 int fd = open(path, O_WRONLY); 105 size_t len = strlen(val); 106 int ret; 107 108 if (fd == -1) { 109 ksft_print_msg("opening %s for write: %s\n", path, strerror(errno)); 110 return NSID_ERROR; 111 } 112 113 ret = write(fd, val, len); 114 if (ret == -1) { 115 ksft_print_msg("writing to %s: %s\n", path, strerror(errno)); 116 return NSID_ERROR; 117 } 118 if (ret != len) { 119 ksft_print_msg("short write to %s\n", path); 120 return NSID_ERROR; 121 } 122 123 ret = close(fd); 124 if (ret == -1) { 125 ksft_print_msg("closing %s\n", path); 126 return NSID_ERROR; 127 } 128 129 return NSID_PASS; 130 } 131 132 static int setup_namespace(void) 133 { 134 int ret; 135 char buf[32]; 136 uid_t uid = getuid(); 137 gid_t gid = getgid(); 138 139 ret = unshare(CLONE_NEWNS|CLONE_NEWUSER|CLONE_NEWPID); 140 if (ret == -1) 141 ksft_exit_fail_msg("unsharing mountns and userns: %s\n", 142 strerror(errno)); 143 144 sprintf(buf, "0 %d 1", uid); 145 ret = write_file("/proc/self/uid_map", buf); 146 if (ret != NSID_PASS) 147 return ret; 148 ret = write_file("/proc/self/setgroups", "deny"); 149 if (ret != NSID_PASS) 150 return ret; 151 sprintf(buf, "0 %d 1", gid); 152 ret = write_file("/proc/self/gid_map", buf); 153 if (ret != NSID_PASS) 154 return ret; 155 156 ret = mount("", "/", NULL, MS_REC|MS_PRIVATE, NULL); 157 if (ret == -1) { 158 ksft_print_msg("making mount tree private: %s\n", 159 strerror(errno)); 160 return NSID_ERROR; 161 } 162 163 return NSID_PASS; 164 } 165 166 static int _test_statmount_mnt_ns_id(void) 167 { 168 struct statmount sm; 169 uint64_t mnt_ns_id; 170 uint64_t root_id; 171 int ret; 172 173 ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id); 174 if (ret != NSID_PASS) 175 return ret; 176 177 ret = get_mnt_id("/", &root_id); 178 if (ret != NSID_PASS) 179 return ret; 180 181 ret = statmount(root_id, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0); 182 if (ret == -1) { 183 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno)); 184 return NSID_ERROR; 185 } 186 187 if (sm.size != sizeof(sm)) { 188 ksft_print_msg("unexpected size: %u != %u\n", sm.size, 189 (uint32_t)sizeof(sm)); 190 return NSID_FAIL; 191 } 192 if (sm.mask != STATMOUNT_MNT_NS_ID) { 193 ksft_print_msg("statmount mnt ns id unavailable\n"); 194 return NSID_SKIP; 195 } 196 197 if (sm.mnt_ns_id != mnt_ns_id) { 198 ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n", 199 (unsigned long long)sm.mnt_ns_id, 200 (unsigned long long)mnt_ns_id); 201 return NSID_FAIL; 202 } 203 204 return NSID_PASS; 205 } 206 207 static void test_statmount_mnt_ns_id(void) 208 { 209 pid_t pid; 210 int ret; 211 212 pid = fork(); 213 if (pid < 0) 214 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno)); 215 216 /* We're the original pid, wait for the result. */ 217 if (pid != 0) { 218 ret = wait_for_pid(pid); 219 handle_result(ret, "test statmount ns id"); 220 return; 221 } 222 223 ret = setup_namespace(); 224 if (ret != NSID_PASS) 225 exit(ret); 226 ret = _test_statmount_mnt_ns_id(); 227 exit(ret); 228 } 229 230 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts) 231 { 232 uint64_t list[256]; 233 uint64_t mnt_ns_id; 234 uint64_t nr_mounts; 235 char buf[256]; 236 int ret; 237 238 /* Get the mount ns id for our child. */ 239 snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid); 240 ret = get_mnt_ns_id(buf, &mnt_ns_id); 241 242 nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0); 243 if (nr_mounts == (uint64_t)-1) { 244 ksft_print_msg("listmount: %s\n", strerror(errno)); 245 return NSID_ERROR; 246 } 247 248 if (nr_mounts != child_nr_mounts) { 249 ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts, 250 child_nr_mounts); 251 return NSID_FAIL; 252 } 253 254 /* Validate that all of our entries match our mnt_ns_id. */ 255 for (int i = 0; i < nr_mounts; i++) { 256 struct statmount sm; 257 258 ret = statmount(list[i], mnt_ns_id, STATMOUNT_MNT_NS_ID, &sm, 259 sizeof(sm), 0); 260 if (ret < 0) { 261 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno)); 262 return NSID_ERROR; 263 } 264 265 if (sm.mask != STATMOUNT_MNT_NS_ID) { 266 ksft_print_msg("statmount mnt ns id unavailable\n"); 267 return NSID_SKIP; 268 } 269 270 if (sm.mnt_ns_id != mnt_ns_id) { 271 ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n", 272 (unsigned long long)sm.mnt_ns_id, 273 (unsigned long long)mnt_ns_id); 274 return NSID_FAIL; 275 } 276 } 277 278 return NSID_PASS; 279 } 280 281 static void test_listmount_ns(void) 282 { 283 uint64_t nr_mounts; 284 char pval; 285 int child_ready_pipe[2]; 286 int parent_ready_pipe[2]; 287 pid_t pid; 288 int ret, child_ret; 289 290 if (pipe(child_ready_pipe) < 0) 291 ksft_exit_fail_msg("failed to create the child pipe: %s\n", 292 strerror(errno)); 293 if (pipe(parent_ready_pipe) < 0) 294 ksft_exit_fail_msg("failed to create the parent pipe: %s\n", 295 strerror(errno)); 296 297 pid = fork(); 298 if (pid < 0) 299 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno)); 300 301 if (pid == 0) { 302 char cval; 303 uint64_t list[256]; 304 305 close(child_ready_pipe[0]); 306 close(parent_ready_pipe[1]); 307 308 ret = setup_namespace(); 309 if (ret != NSID_PASS) 310 exit(ret); 311 312 nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0); 313 if (nr_mounts == (uint64_t)-1) { 314 ksft_print_msg("listmount: %s\n", strerror(errno)); 315 exit(NSID_FAIL); 316 } 317 318 /* 319 * Tell our parent how many mounts we have, and then wait for it 320 * to tell us we're done. 321 */ 322 if (write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts)) != 323 sizeof(nr_mounts)) 324 ret = NSID_ERROR; 325 if (read(parent_ready_pipe[0], &cval, sizeof(cval)) != sizeof(cval)) 326 ret = NSID_ERROR; 327 exit(NSID_PASS); 328 } 329 330 close(child_ready_pipe[1]); 331 close(parent_ready_pipe[0]); 332 333 /* Wait until the child has created everything. */ 334 if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) != 335 sizeof(nr_mounts)) 336 ret = NSID_ERROR; 337 338 ret = validate_external_listmount(pid, nr_mounts); 339 340 if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval)) 341 ret = NSID_ERROR; 342 343 child_ret = wait_for_pid(pid); 344 if (child_ret != NSID_PASS) 345 ret = child_ret; 346 handle_result(ret, "test listmount ns id"); 347 } 348 349 int main(void) 350 { 351 int ret; 352 353 ksft_print_header(); 354 ret = statmount(0, 0, 0, NULL, 0, 0); 355 assert(ret == -1); 356 if (errno == ENOSYS) 357 ksft_exit_skip("statmount() syscall not supported\n"); 358 359 ksft_set_plan(2); 360 test_statmount_mnt_ns_id(); 361 test_listmount_ns(); 362 363 if (ksft_get_fail_cnt() + ksft_get_error_cnt() > 0) 364 ksft_exit_fail(); 365 else 366 ksft_exit_pass(); 367 } 368