1 // SPDX-License-Identifier: GPL-2.0-only 2 /* MSG_ZEROCOPY feature tests for vsock 3 * 4 * Copyright (C) 2023 SberDevices. 5 * 6 * Author: Arseniy Krasnov <avkrasnov@salutedevices.com> 7 */ 8 9 #include <stdio.h> 10 #include <stdlib.h> 11 #include <string.h> 12 #include <sys/ioctl.h> 13 #include <sys/mman.h> 14 #include <unistd.h> 15 #include <poll.h> 16 #include <linux/errqueue.h> 17 #include <linux/kernel.h> 18 #include <linux/sockios.h> 19 #include <linux/time64.h> 20 #include <errno.h> 21 22 #include "control.h" 23 #include "timeout.h" 24 #include "vsock_test_zerocopy.h" 25 #include "msg_zerocopy_common.h" 26 27 #ifndef PAGE_SIZE 28 #define PAGE_SIZE 4096 29 #endif 30 31 #define VSOCK_TEST_DATA_MAX_IOV 3 32 33 struct vsock_test_data { 34 /* This test case if for SOCK_STREAM only. */ 35 bool stream_only; 36 /* Data must be zerocopied. This field is checked against 37 * field 'ee_code' of the 'struct sock_extended_err', which 38 * contains bit to detect that zerocopy transmission was 39 * fallbacked to copy mode. 40 */ 41 bool zerocopied; 42 /* Enable SO_ZEROCOPY option on the socket. Without enabled 43 * SO_ZEROCOPY, every MSG_ZEROCOPY transmission will behave 44 * like without MSG_ZEROCOPY flag. 45 */ 46 bool so_zerocopy; 47 /* 'errno' after 'sendmsg()' call. */ 48 int sendmsg_errno; 49 /* Number of valid elements in 'vecs'. */ 50 int vecs_cnt; 51 struct iovec vecs[VSOCK_TEST_DATA_MAX_IOV]; 52 }; 53 54 static struct vsock_test_data test_data_array[] = { 55 /* Last element has non-page aligned size. */ 56 { 57 .zerocopied = true, 58 .so_zerocopy = true, 59 .sendmsg_errno = 0, 60 .vecs_cnt = 3, 61 { 62 { NULL, PAGE_SIZE }, 63 { NULL, PAGE_SIZE }, 64 { NULL, 200 } 65 } 66 }, 67 /* All elements have page aligned base and size. */ 68 { 69 .zerocopied = true, 70 .so_zerocopy = true, 71 .sendmsg_errno = 0, 72 .vecs_cnt = 3, 73 { 74 { NULL, PAGE_SIZE }, 75 { NULL, PAGE_SIZE * 2 }, 76 { NULL, PAGE_SIZE * 3 } 77 } 78 }, 79 /* All elements have page aligned base and size. But 80 * data length is bigger than 64Kb. 81 */ 82 { 83 .zerocopied = true, 84 .so_zerocopy = true, 85 .sendmsg_errno = 0, 86 .vecs_cnt = 3, 87 { 88 { NULL, PAGE_SIZE * 16 }, 89 { NULL, PAGE_SIZE * 16 }, 90 { NULL, PAGE_SIZE * 16 } 91 } 92 }, 93 /* Middle element has both non-page aligned base and size. */ 94 { 95 .zerocopied = true, 96 .so_zerocopy = true, 97 .sendmsg_errno = 0, 98 .vecs_cnt = 3, 99 { 100 { NULL, PAGE_SIZE }, 101 { (void *)1, 100 }, 102 { NULL, PAGE_SIZE } 103 } 104 }, 105 /* Middle element is unmapped. */ 106 { 107 .zerocopied = false, 108 .so_zerocopy = true, 109 .sendmsg_errno = ENOMEM, 110 .vecs_cnt = 3, 111 { 112 { NULL, PAGE_SIZE }, 113 { MAP_FAILED, PAGE_SIZE }, 114 { NULL, PAGE_SIZE } 115 } 116 }, 117 /* Valid data, but SO_ZEROCOPY is off. This 118 * will trigger fallback to copy. 119 */ 120 { 121 .zerocopied = false, 122 .so_zerocopy = false, 123 .sendmsg_errno = 0, 124 .vecs_cnt = 1, 125 { 126 { NULL, PAGE_SIZE } 127 } 128 }, 129 /* Valid data, but message is bigger than peer's 130 * buffer, so this will trigger fallback to copy. 131 * This test is for SOCK_STREAM only, because 132 * for SOCK_SEQPACKET, 'sendmsg()' returns EMSGSIZE. 133 */ 134 { 135 .stream_only = true, 136 .zerocopied = false, 137 .so_zerocopy = true, 138 .sendmsg_errno = 0, 139 .vecs_cnt = 1, 140 { 141 { NULL, 100 * PAGE_SIZE } 142 } 143 }, 144 }; 145 146 #define POLL_TIMEOUT_MS 100 147 148 static void test_client(const struct test_opts *opts, 149 const struct vsock_test_data *test_data, 150 bool sock_seqpacket) 151 { 152 struct pollfd fds = { 0 }; 153 struct msghdr msg = { 0 }; 154 ssize_t sendmsg_res; 155 struct iovec *iovec; 156 int fd; 157 158 if (sock_seqpacket) 159 fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port); 160 else 161 fd = vsock_stream_connect(opts->peer_cid, opts->peer_port); 162 163 if (fd < 0) { 164 perror("connect"); 165 exit(EXIT_FAILURE); 166 } 167 168 if (test_data->so_zerocopy) 169 enable_so_zerocopy_check(fd); 170 171 iovec = alloc_test_iovec(test_data->vecs, test_data->vecs_cnt); 172 173 msg.msg_iov = iovec; 174 msg.msg_iovlen = test_data->vecs_cnt; 175 176 errno = 0; 177 178 sendmsg_res = sendmsg(fd, &msg, MSG_ZEROCOPY); 179 if (errno != test_data->sendmsg_errno) { 180 fprintf(stderr, "expected 'errno' == %i, got %i\n", 181 test_data->sendmsg_errno, errno); 182 exit(EXIT_FAILURE); 183 } 184 185 if (!errno) { 186 if (sendmsg_res != iovec_bytes(iovec, test_data->vecs_cnt)) { 187 fprintf(stderr, "expected 'sendmsg()' == %li, got %li\n", 188 iovec_bytes(iovec, test_data->vecs_cnt), 189 sendmsg_res); 190 exit(EXIT_FAILURE); 191 } 192 } 193 194 fds.fd = fd; 195 fds.events = 0; 196 197 if (poll(&fds, 1, POLL_TIMEOUT_MS) < 0) { 198 perror("poll"); 199 exit(EXIT_FAILURE); 200 } 201 202 if (fds.revents & POLLERR) { 203 vsock_recv_completion(fd, &test_data->zerocopied); 204 } else if (test_data->so_zerocopy && !test_data->sendmsg_errno) { 205 /* If we don't have data in the error queue, but 206 * SO_ZEROCOPY was enabled and 'sendmsg()' was 207 * successful - this is an error. 208 */ 209 fprintf(stderr, "POLLERR expected\n"); 210 exit(EXIT_FAILURE); 211 } 212 213 if (!test_data->sendmsg_errno) 214 control_writeulong(iovec_hash_djb2(iovec, test_data->vecs_cnt)); 215 else 216 control_writeulong(0); 217 218 control_writeln("DONE"); 219 free_test_iovec(test_data->vecs, iovec, test_data->vecs_cnt); 220 close(fd); 221 } 222 223 void test_stream_msgzcopy_client(const struct test_opts *opts) 224 { 225 int i; 226 227 for (i = 0; i < ARRAY_SIZE(test_data_array); i++) 228 test_client(opts, &test_data_array[i], false); 229 } 230 231 void test_seqpacket_msgzcopy_client(const struct test_opts *opts) 232 { 233 int i; 234 235 for (i = 0; i < ARRAY_SIZE(test_data_array); i++) { 236 if (test_data_array[i].stream_only) 237 continue; 238 239 test_client(opts, &test_data_array[i], true); 240 } 241 } 242 243 static void test_server(const struct test_opts *opts, 244 const struct vsock_test_data *test_data, 245 bool sock_seqpacket) 246 { 247 unsigned long remote_hash; 248 unsigned long local_hash; 249 ssize_t total_bytes_rec; 250 unsigned char *data; 251 size_t data_len; 252 int fd; 253 254 if (sock_seqpacket) 255 fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL); 256 else 257 fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL); 258 259 if (fd < 0) { 260 perror("accept"); 261 exit(EXIT_FAILURE); 262 } 263 264 data_len = iovec_bytes(test_data->vecs, test_data->vecs_cnt); 265 266 data = malloc(data_len); 267 if (!data) { 268 perror("malloc"); 269 exit(EXIT_FAILURE); 270 } 271 272 total_bytes_rec = 0; 273 274 while (total_bytes_rec != data_len) { 275 ssize_t bytes_rec; 276 277 bytes_rec = read(fd, data + total_bytes_rec, 278 data_len - total_bytes_rec); 279 if (bytes_rec <= 0) 280 break; 281 282 total_bytes_rec += bytes_rec; 283 } 284 285 if (test_data->sendmsg_errno == 0) 286 local_hash = hash_djb2(data, data_len); 287 else 288 local_hash = 0; 289 290 free(data); 291 292 /* Waiting for some result. */ 293 remote_hash = control_readulong(); 294 if (remote_hash != local_hash) { 295 fprintf(stderr, "hash mismatch\n"); 296 exit(EXIT_FAILURE); 297 } 298 299 control_expectln("DONE"); 300 close(fd); 301 } 302 303 void test_stream_msgzcopy_server(const struct test_opts *opts) 304 { 305 int i; 306 307 for (i = 0; i < ARRAY_SIZE(test_data_array); i++) 308 test_server(opts, &test_data_array[i], false); 309 } 310 311 void test_seqpacket_msgzcopy_server(const struct test_opts *opts) 312 { 313 int i; 314 315 for (i = 0; i < ARRAY_SIZE(test_data_array); i++) { 316 if (test_data_array[i].stream_only) 317 continue; 318 319 test_server(opts, &test_data_array[i], true); 320 } 321 } 322 323 void test_stream_msgzcopy_empty_errq_client(const struct test_opts *opts) 324 { 325 struct msghdr msg = { 0 }; 326 char cmsg_data[128]; 327 ssize_t res; 328 int fd; 329 330 fd = vsock_stream_connect(opts->peer_cid, opts->peer_port); 331 if (fd < 0) { 332 perror("connect"); 333 exit(EXIT_FAILURE); 334 } 335 336 msg.msg_control = cmsg_data; 337 msg.msg_controllen = sizeof(cmsg_data); 338 339 res = recvmsg(fd, &msg, MSG_ERRQUEUE); 340 if (res != -1) { 341 fprintf(stderr, "expected 'recvmsg(2)' failure, got %zi\n", 342 res); 343 exit(EXIT_FAILURE); 344 } 345 346 control_writeln("DONE"); 347 close(fd); 348 } 349 350 void test_stream_msgzcopy_empty_errq_server(const struct test_opts *opts) 351 { 352 int fd; 353 354 fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL); 355 if (fd < 0) { 356 perror("accept"); 357 exit(EXIT_FAILURE); 358 } 359 360 control_expectln("DONE"); 361 close(fd); 362 } 363 364 #define GOOD_COPY_LEN 128 /* net/vmw_vsock/virtio_transport_common.c */ 365 366 void test_stream_msgzcopy_mangle_client(const struct test_opts *opts) 367 { 368 char sbuf1[PAGE_SIZE + 1], sbuf2[GOOD_COPY_LEN]; 369 unsigned long hash; 370 struct pollfd fds; 371 int fd, i; 372 373 fd = vsock_stream_connect(opts->peer_cid, opts->peer_port); 374 if (fd < 0) { 375 perror("connect"); 376 exit(EXIT_FAILURE); 377 } 378 379 enable_so_zerocopy_check(fd); 380 381 memset(sbuf1, 'x', sizeof(sbuf1)); 382 send_buf(fd, sbuf1, sizeof(sbuf1), 0, sizeof(sbuf1)); 383 384 for (i = 0; i < sizeof(sbuf2); i++) 385 sbuf2[i] = rand() & 0xff; 386 387 send_buf(fd, sbuf2, sizeof(sbuf2), MSG_ZEROCOPY, sizeof(sbuf2)); 388 389 hash = hash_djb2(sbuf2, sizeof(sbuf2)); 390 control_writeulong(hash); 391 392 fds.fd = fd; 393 fds.events = 0; 394 395 if (poll(&fds, 1, TIMEOUT * MSEC_PER_SEC) != 1 || 396 !(fds.revents & POLLERR)) { 397 perror("poll"); 398 exit(EXIT_FAILURE); 399 } 400 401 close(fd); 402 } 403 404 void test_stream_msgzcopy_mangle_server(const struct test_opts *opts) 405 { 406 unsigned long local_hash, remote_hash; 407 char rbuf[PAGE_SIZE + 1]; 408 int fd; 409 410 fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL); 411 if (fd < 0) { 412 perror("accept"); 413 exit(EXIT_FAILURE); 414 } 415 416 /* Wait, don't race the (buggy) skbs coalescence. */ 417 vsock_ioctl_int(fd, SIOCINQ, PAGE_SIZE + 1 + GOOD_COPY_LEN); 418 419 /* Discard the first packet. */ 420 recv_buf(fd, rbuf, PAGE_SIZE + 1, 0, PAGE_SIZE + 1); 421 422 recv_buf(fd, rbuf, GOOD_COPY_LEN, 0, GOOD_COPY_LEN); 423 remote_hash = control_readulong(); 424 local_hash = hash_djb2(rbuf, GOOD_COPY_LEN); 425 426 if (local_hash != remote_hash) { 427 fprintf(stderr, "Data received corrupted\n"); 428 exit(EXIT_FAILURE); 429 } 430 431 close(fd); 432 } 433