1 // SPDX-License-Identifier: GPL-2.0-only
2
3 /*
4 * Copyright (c) 2025, Google LLC.
5 * Pasha Tatashin <pasha.tatashin@soleen.com>
6 */
7
#define _GNU_SOURCE

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <getopt.h>
#include <fcntl.h>
#include <limits.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/syscall.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <errno.h>
#include <stdarg.h>

#include "luo_test_utils.h"
25
luo_open_device(void)26 int luo_open_device(void)
27 {
28 return open(LUO_DEVICE, O_RDWR);
29 }
30
luo_create_session(int luo_fd,const char * name)31 int luo_create_session(int luo_fd, const char *name)
32 {
33 struct liveupdate_ioctl_create_session arg = { .size = sizeof(arg) };
34
35 snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s",
36 LIVEUPDATE_SESSION_NAME_LENGTH - 1, name);
37
38 if (ioctl(luo_fd, LIVEUPDATE_IOCTL_CREATE_SESSION, &arg) < 0)
39 return -errno;
40
41 return arg.fd;
42 }
43
luo_retrieve_session(int luo_fd,const char * name)44 int luo_retrieve_session(int luo_fd, const char *name)
45 {
46 struct liveupdate_ioctl_retrieve_session arg = { .size = sizeof(arg) };
47
48 snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s",
49 LIVEUPDATE_SESSION_NAME_LENGTH - 1, name);
50
51 if (ioctl(luo_fd, LIVEUPDATE_IOCTL_RETRIEVE_SESSION, &arg) < 0)
52 return -errno;
53
54 return arg.fd;
55 }
56
create_and_preserve_memfd(int session_fd,int token,const char * data)57 int create_and_preserve_memfd(int session_fd, int token, const char *data)
58 {
59 struct liveupdate_session_preserve_fd arg = { .size = sizeof(arg) };
60 long page_size = sysconf(_SC_PAGE_SIZE);
61 void *map = MAP_FAILED;
62 int mfd = -1, ret = -1;
63
64 mfd = memfd_create("test_mfd", 0);
65 if (mfd < 0)
66 return -errno;
67
68 if (ftruncate(mfd, page_size) != 0)
69 goto out;
70
71 map = mmap(NULL, page_size, PROT_WRITE, MAP_SHARED, mfd, 0);
72 if (map == MAP_FAILED)
73 goto out;
74
75 snprintf(map, page_size, "%s", data);
76 munmap(map, page_size);
77
78 arg.fd = mfd;
79 arg.token = token;
80 if (ioctl(session_fd, LIVEUPDATE_SESSION_PRESERVE_FD, &arg) < 0)
81 goto out;
82
83 ret = 0;
84 out:
85 if (ret != 0 && errno != 0)
86 ret = -errno;
87 if (mfd >= 0)
88 close(mfd);
89 return ret;
90 }
91
restore_and_verify_memfd(int session_fd,int token,const char * expected_data)92 int restore_and_verify_memfd(int session_fd, int token,
93 const char *expected_data)
94 {
95 struct liveupdate_session_retrieve_fd arg = { .size = sizeof(arg) };
96 long page_size = sysconf(_SC_PAGE_SIZE);
97 void *map = MAP_FAILED;
98 int mfd = -1, ret = -1;
99
100 arg.token = token;
101 if (ioctl(session_fd, LIVEUPDATE_SESSION_RETRIEVE_FD, &arg) < 0)
102 return -errno;
103 mfd = arg.fd;
104
105 map = mmap(NULL, page_size, PROT_READ, MAP_SHARED, mfd, 0);
106 if (map == MAP_FAILED)
107 goto out;
108
109 if (expected_data && strcmp(expected_data, map) != 0) {
110 ksft_print_msg("Data mismatch! Expected '%s', Got '%s'\n",
111 expected_data, (char *)map);
112 ret = -EINVAL;
113 goto out_munmap;
114 }
115
116 ret = mfd;
117 out_munmap:
118 munmap(map, page_size);
119 out:
120 if (ret < 0 && errno != 0)
121 ret = -errno;
122 if (ret < 0 && mfd >= 0)
123 close(mfd);
124 return ret;
125 }
126
luo_session_finish(int session_fd)127 int luo_session_finish(int session_fd)
128 {
129 struct liveupdate_session_finish arg = { .size = sizeof(arg) };
130
131 if (ioctl(session_fd, LIVEUPDATE_SESSION_FINISH, &arg) < 0)
132 return -errno;
133
134 return 0;
135 }
136
/*
 * Create the state-tracking session @session_name. Preserve in it a memfd,
 * stored under @token, whose contents are @next_stage formatted as a
 * decimal string.
 *
 * Failures are reported through fail_exit().
 *
 * The session descriptor is intentionally never closed: closing it would
 * unpreserve the session.
 */
void create_state_file(int luo_fd, const char *session_name, int token,
		       int next_stage)
{
	char stage_str[32];
	int fd;

	fd = luo_create_session(luo_fd, session_name);
	if (fd < 0)
		fail_exit("luo_create_session for state tracking");

	snprintf(stage_str, sizeof(stage_str), "%d", next_stage);
	if (create_and_preserve_memfd(fd, token, stage_str) < 0)
		fail_exit("create_and_preserve_memfd for state tracking");

	/*
	 * DO NOT close session FD, otherwise it is going to be unpreserved
	 */
}
155
/*
 * Retrieve the state memfd preserved under @token in @state_session_fd,
 * parse its contents as a decimal integer and store it in *@stage.
 *
 * Failures are reported through fail_exit(). A failure is any of: the
 * memfd cannot be retrieved, it cannot be read, or its contents are not a
 * plain decimal integer that fits in an int.
 */
void restore_and_read_stage(int state_session_fd, int token, int *stage)
{
	char buf[32] = {0};
	char *end;
	long val;
	int mfd;

	mfd = restore_and_verify_memfd(state_session_fd, token, NULL);
	if (mfd < 0)
		fail_exit("failed to restore state memfd");

	/* Leave room for the terminating NUL provided by the zeroed buffer. */
	if (read(mfd, buf, sizeof(buf) - 1) < 0)
		fail_exit("failed to read state mfd");

	/*
	 * atoi() would silently turn empty or garbled contents into 0. Parse
	 * strictly instead. The writer (create_state_file) stores a bare "%d"
	 * and the rest of the page is zero, so the number must be followed
	 * directly by the NUL byte.
	 */
	errno = 0;
	val = strtol(buf, &end, 10);
	if (end == buf || *end != '\0' || errno == ERANGE ||
	    val < INT_MIN || val > INT_MAX)
		fail_exit("malformed stage in state mfd");

	*stage = (int)val;

	close(mfd);
}
172
/*
 * Fork a persistent child that holds the sessions (and the resources they
 * pin) so they stay alive until the kexec reboot.
 *
 * The parent prints the child's PID and exits with EXIT_SUCCESS.
 *
 * The child then:
 *  - starts a new session with setsid();
 *  - closes its standard streams;
 *  - changes directory to "/";
 *  - sleeps forever.
 *
 * Neither process returns from this function.
 */
void daemonize_and_wait(void)
{
	pid_t child;
	int fd;

	ksft_print_msg("[STAGE 1] Forking persistent child to hold sessions...\n");

	child = fork();
	if (child < 0)
		fail_exit("fork failed");

	if (child > 0) {
		/* Parent: report the child's PID and leave. */
		ksft_print_msg("[STAGE 1] Child PID: %d. Resources are pinned.\n", child);
		ksft_print_msg("[STAGE 1] You may now perform kexec reboot.\n");
		exit(EXIT_SUCCESS);
	}

	/* Child: detach from terminal so closing the window doesn't kill us */
	if (setsid() < 0)
		fail_exit("setsid failed");

	/* Close stdin, stdout and stderr (descriptors 0, 1 and 2). */
	for (fd = STDIN_FILENO; fd <= STDERR_FILENO; fd++)
		close(fd);

	/* Change dir to root to avoid locking filesystems */
	if (chdir("/") < 0)
		exit(EXIT_FAILURE);

	for (;;)
		sleep(60);
}
204
/*
 * Parse "-s N" / "--stage N" from the command line and return the requested
 * stage. Returns 1 when the option is absent.
 *
 * The stage value must be exactly the integer 1 or 2. Trailing characters,
 * as in "2x", are rejected; atoi() would have accepted them. An invalid
 * stage or an unknown option is reported through fail_exit().
 */
static int parse_stage_args(int argc, char *argv[])
{
	static const struct option long_options[] = {
		{"stage", required_argument, 0, 's'},
		{0, 0, 0, 0}
	};
	int option_index = 0;
	int stage = 1;
	int opt;

	/* Reset getopt state so parsing always starts at argv[1]. */
	optind = 1;
	while ((opt = getopt_long(argc, argv, "s:", long_options,
				  &option_index)) != -1) {
		char *end;
		long val;

		switch (opt) {
		case 's':
			val = strtol(optarg, &end, 10);
			if (end == optarg || *end != '\0' ||
			    (val != 1 && val != 2))
				fail_exit("Invalid stage argument");
			stage = (int)val;
			break;
		default:
			fail_exit("Unknown argument");
		}
	}
	return stage;
}
229
/*
 * Common driver for two-stage live-update selftests.
 *
 * The requested stage comes from --stage (default 1). The actual stage is
 * inferred from the LUO device:
 *  - @state_session_name cannot be retrieved (-ENOENT): stage 1;
 *  - it can be retrieved: stage 2;
 *  - any other error is reported through fail_exit().
 * A mismatch between requested and actual stage fails the test. The test
 * is skipped when the LUO device cannot be opened.
 *
 * On success, calls @stage1(luo_fd) or @stage2(luo_fd, state_session_fd)
 * and returns 0.
 */
int luo_test(int argc, char *argv[],
	     const char *state_session_name,
	     luo_test_stage1_fn stage1,
	     luo_test_stage2_fn stage2)
{
	int requested = parse_stage_args(argc, argv);
	int luo_fd = luo_open_device();
	int session_fd;
	int actual;

	if (luo_fd < 0)
		ksft_exit_skip("Failed to open %s. Is the luo module loaded?\n",
			       LUO_DEVICE);

	session_fd = luo_retrieve_session(luo_fd, state_session_name);
	if (session_fd >= 0)
		actual = 2;
	else if (session_fd == -ENOENT)
		actual = 1;
	else
		fail_exit("Failed to check for state session");

	if (requested != actual)
		ksft_exit_fail_msg("Stage mismatch Requested --stage %d, but system is in stage %d.\n"
				   "(State session %s: %s)\n",
				   requested, actual, state_session_name,
				   actual == 2 ? "EXISTS" : "MISSING");

	if (requested != 1)
		stage2(luo_fd, session_fd);
	else
		stage1(luo_fd);

	return 0;
}
267