xref: /linux/tools/testing/selftests/liveupdate/luo_test_utils.c (revision 509d3f45847627f4c5cdce004c3ec79262b5239c)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 /*
4  * Copyright (c) 2025, Google LLC.
5  * Pasha Tatashin <pasha.tatashin@soleen.com>
6  */
7 
8 #define _GNU_SOURCE
9 
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <string.h>
13 #include <getopt.h>
14 #include <fcntl.h>
15 #include <unistd.h>
16 #include <sys/ioctl.h>
17 #include <sys/syscall.h>
18 #include <sys/mman.h>
19 #include <sys/types.h>
20 #include <sys/stat.h>
21 #include <errno.h>
22 #include <stdarg.h>
23 
24 #include "luo_test_utils.h"
25 
luo_open_device(void)26 int luo_open_device(void)
27 {
28 	return open(LUO_DEVICE, O_RDWR);
29 }
30 
luo_create_session(int luo_fd,const char * name)31 int luo_create_session(int luo_fd, const char *name)
32 {
33 	struct liveupdate_ioctl_create_session arg = { .size = sizeof(arg) };
34 
35 	snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s",
36 		 LIVEUPDATE_SESSION_NAME_LENGTH - 1, name);
37 
38 	if (ioctl(luo_fd, LIVEUPDATE_IOCTL_CREATE_SESSION, &arg) < 0)
39 		return -errno;
40 
41 	return arg.fd;
42 }
43 
luo_retrieve_session(int luo_fd,const char * name)44 int luo_retrieve_session(int luo_fd, const char *name)
45 {
46 	struct liveupdate_ioctl_retrieve_session arg = { .size = sizeof(arg) };
47 
48 	snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s",
49 		 LIVEUPDATE_SESSION_NAME_LENGTH - 1, name);
50 
51 	if (ioctl(luo_fd, LIVEUPDATE_IOCTL_RETRIEVE_SESSION, &arg) < 0)
52 		return -errno;
53 
54 	return arg.fd;
55 }
56 
create_and_preserve_memfd(int session_fd,int token,const char * data)57 int create_and_preserve_memfd(int session_fd, int token, const char *data)
58 {
59 	struct liveupdate_session_preserve_fd arg = { .size = sizeof(arg) };
60 	long page_size = sysconf(_SC_PAGE_SIZE);
61 	void *map = MAP_FAILED;
62 	int mfd = -1, ret = -1;
63 
64 	mfd = memfd_create("test_mfd", 0);
65 	if (mfd < 0)
66 		return -errno;
67 
68 	if (ftruncate(mfd, page_size) != 0)
69 		goto out;
70 
71 	map = mmap(NULL, page_size, PROT_WRITE, MAP_SHARED, mfd, 0);
72 	if (map == MAP_FAILED)
73 		goto out;
74 
75 	snprintf(map, page_size, "%s", data);
76 	munmap(map, page_size);
77 
78 	arg.fd = mfd;
79 	arg.token = token;
80 	if (ioctl(session_fd, LIVEUPDATE_SESSION_PRESERVE_FD, &arg) < 0)
81 		goto out;
82 
83 	ret = 0;
84 out:
85 	if (ret != 0 && errno != 0)
86 		ret = -errno;
87 	if (mfd >= 0)
88 		close(mfd);
89 	return ret;
90 }
91 
restore_and_verify_memfd(int session_fd,int token,const char * expected_data)92 int restore_and_verify_memfd(int session_fd, int token,
93 			     const char *expected_data)
94 {
95 	struct liveupdate_session_retrieve_fd arg = { .size = sizeof(arg) };
96 	long page_size = sysconf(_SC_PAGE_SIZE);
97 	void *map = MAP_FAILED;
98 	int mfd = -1, ret = -1;
99 
100 	arg.token = token;
101 	if (ioctl(session_fd, LIVEUPDATE_SESSION_RETRIEVE_FD, &arg) < 0)
102 		return -errno;
103 	mfd = arg.fd;
104 
105 	map = mmap(NULL, page_size, PROT_READ, MAP_SHARED, mfd, 0);
106 	if (map == MAP_FAILED)
107 		goto out;
108 
109 	if (expected_data && strcmp(expected_data, map) != 0) {
110 		ksft_print_msg("Data mismatch! Expected '%s', Got '%s'\n",
111 			       expected_data, (char *)map);
112 		ret = -EINVAL;
113 		goto out_munmap;
114 	}
115 
116 	ret = mfd;
117 out_munmap:
118 	munmap(map, page_size);
119 out:
120 	if (ret < 0 && errno != 0)
121 		ret = -errno;
122 	if (ret < 0 && mfd >= 0)
123 		close(mfd);
124 	return ret;
125 }
126 
luo_session_finish(int session_fd)127 int luo_session_finish(int session_fd)
128 {
129 	struct liveupdate_session_finish arg = { .size = sizeof(arg) };
130 
131 	if (ioctl(session_fd, LIVEUPDATE_SESSION_FINISH, &arg) < 0)
132 		return -errno;
133 
134 	return 0;
135 }
136 
/*
 * Record @next_stage in a dedicated state session so that the test can tell,
 * after a live update, which stage it should run next.
 *
 * A new session named @session_name is created on @luo_fd. A memfd holding
 * the decimal text of @next_stage is preserved in it under @token. Any
 * failure terminates the test through fail_exit().
 */
void create_state_file(int luo_fd, const char *session_name, int token,
		       int next_stage)
{
	char stage_str[32];
	int sess_fd = luo_create_session(luo_fd, session_name);

	if (sess_fd < 0)
		fail_exit("luo_create_session for state tracking");

	snprintf(stage_str, sizeof(stage_str), "%d", next_stage);

	if (create_and_preserve_memfd(sess_fd, token, stage_str) < 0)
		fail_exit("create_and_preserve_memfd for state tracking");

	/*
	 * sess_fd is deliberately never closed: closing the session FD would
	 * unpreserve it.
	 */
}
155 
/*
 * Read back the stage number stored by create_state_file().
 *
 * Retrieves the memfd preserved under @token in @state_session_fd, parses
 * its contents as a decimal integer and stores the result in *@stage. Any
 * failure, including contents that are not a number, terminates the test
 * through fail_exit().
 */
void restore_and_read_stage(int state_session_fd, int token, int *stage)
{
	char buf[32] = {0};
	char *end;
	long val;
	int mfd;

	mfd = restore_and_verify_memfd(state_session_fd, token, NULL);
	if (mfd < 0)
		fail_exit("failed to restore state memfd");

	/*
	 * Read from offset 0 explicitly instead of relying on the file
	 * position of the retrieved descriptor. NOTE(review): the position of
	 * a freshly retrieved fd is assumed to be 0 but is not guaranteed by
	 * anything visible here.
	 */
	if (pread(mfd, buf, sizeof(buf) - 1, 0) < 0)
		fail_exit("failed to read state mfd");

	/* Unlike atoi(), strtol() lets us reject a corrupted state file. */
	errno = 0;
	val = strtol(buf, &end, 10);
	if (end == buf || errno == ERANGE)
		fail_exit("invalid stage value in state mfd");

	*stage = (int)val;

	close(mfd);
}
172 
/*
 * Fork a long-lived child that keeps this process's open session FDs alive
 * until the machine is rebooted via kexec.
 *
 * The parent prints the child PID and exits with EXIT_SUCCESS. The child
 * starts a new session, closes stdin/stdout/stderr, moves to "/" and then
 * sleeps forever. This function never returns to its caller.
 */
void daemonize_and_wait(void)
{
	pid_t pid;

	ksft_print_msg("[STAGE 1] Forking persistent child to hold sessions...\n");

	/*
	 * Flush stdio before fork(). Otherwise output still sitting in the
	 * buffers is duplicated into the child, which could print it a second
	 * time if it exits through a flushing path (fail_exit()/exit() below).
	 */
	fflush(stdout);
	fflush(stderr);

	pid = fork();
	if (pid < 0)
		fail_exit("fork failed");

	if (pid > 0) {
		ksft_print_msg("[STAGE 1] Child PID: %d. Resources are pinned.\n", pid);
		ksft_print_msg("[STAGE 1] You may now perform kexec reboot.\n");
		exit(EXIT_SUCCESS);
	}

	/* Detach from terminal so closing the window doesn't kill us */
	if (setsid() < 0)
		fail_exit("setsid failed");

	close(STDIN_FILENO);
	close(STDOUT_FILENO);
	close(STDERR_FILENO);

	/* Change dir to root to avoid locking filesystems */
	if (chdir("/") < 0)
		exit(EXIT_FAILURE);

	while (1)
		sleep(60);
}
204 
/*
 * Parse the command line for "-s N" / "--stage N".
 *
 * Returns the requested stage (1 or 2), defaulting to 1 when the option is
 * absent. An unknown option, or a stage value that is not exactly the
 * integer 1 or 2, terminates the test through fail_exit().
 */
static int parse_stage_args(int argc, char *argv[])
{
	static const struct option long_options[] = {
		{"stage", required_argument, NULL, 's'},
		{NULL, 0, NULL, 0}
	};
	int stage = 1;
	int opt;

	/* Restart scanning from argv[1] in case getopt was used earlier. */
	optind = 1;
	while ((opt = getopt_long(argc, argv, "s:", long_options, NULL)) != -1) {
		switch (opt) {
		case 's': {
			char *end;
			long val;

			/*
			 * strtol() with a full-string check rejects inputs such
			 * as "1abc" or "" that atoi() would silently accept.
			 */
			errno = 0;
			val = strtol(optarg, &end, 10);
			if (end == optarg || *end != '\0' || errno != 0 ||
			    (val != 1 && val != 2))
				fail_exit("Invalid stage argument");
			stage = (int)val;
			break;
		}
		default:
			fail_exit("Unknown argument");
		}
	}
	return stage;
}
229 
/*
 * Common driver for two-stage live update tests.
 *
 * Parses the requested stage from the command line and opens the LUO device,
 * skipping the test if the device cannot be opened. It then infers the
 * current stage from whether the session @state_session_name exists:
 *   - absent (-ENOENT) => stage 1
 *   - present          => stage 2
 * A mismatch between the requested and detected stage fails the test.
 * Otherwise @stage1 is called with the device fd, or @stage2 with the device
 * fd and the retrieved state session fd.
 *
 * Returns 0 once the selected stage callback returns.
 */
int luo_test(int argc, char *argv[],
	     const char *state_session_name,
	     luo_test_stage1_fn stage1,
	     luo_test_stage2_fn stage2)
{
	int requested = parse_stage_args(argc, argv);
	int luo_fd = luo_open_device();
	int state_fd;
	int actual;

	if (luo_fd < 0)
		ksft_exit_skip("Failed to open %s. Is the luo module loaded?\n",
			       LUO_DEVICE);

	/* The presence of the state session tells us which stage we are in. */
	state_fd = luo_retrieve_session(luo_fd, state_session_name);
	if (state_fd >= 0)
		actual = 2;
	else if (state_fd == -ENOENT)
		actual = 1;
	else
		fail_exit("Failed to check for state session");

	if (requested != actual)
		ksft_exit_fail_msg("Stage mismatch Requested --stage %d, but system is in stage %d.\n"
				   "(State session %s: %s)\n",
				   requested, actual, state_session_name,
				   (actual == 2) ? "EXISTS" : "MISSING");

	if (requested != 1)
		stage2(luo_fd, state_fd);
	else
		stage1(luo_fd);

	return 0;
}
267