xref: /linux/tools/testing/selftests/net/psock_snd.c (revision c8bfe3fad4f86a029da7157bae9699c816f0c309)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <limits.h>
10 #include <linux/filter.h>
11 #include <linux/bpf.h>
12 #include <linux/if_packet.h>
13 #include <linux/if_vlan.h>
14 #include <linux/virtio_net.h>
15 #include <net/if.h>
16 #include <net/ethernet.h>
17 #include <netinet/ip.h>
18 #include <netinet/udp.h>
19 #include <poll.h>
20 #include <sched.h>
21 #include <stdbool.h>
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <sys/mman.h>
27 #include <sys/socket.h>
28 #include <sys/stat.h>
29 #include <sys/types.h>
30 #include <unistd.h>
31 
32 #include "psock_lib.h"
33 
34 static bool	cfg_use_bind;
35 static bool	cfg_use_csum_off;
36 static bool	cfg_use_csum_off_bad;
37 static bool	cfg_use_dgram;
38 static bool	cfg_use_gso;
39 static bool	cfg_use_qdisc_bypass;
40 static bool	cfg_use_vlan;
41 static bool	cfg_use_vnet;
42 
43 static char	*cfg_ifname = "lo";
44 static int	cfg_mtu	= 1500;
45 static int	cfg_payload_len = DATA_LEN;
46 static int	cfg_truncate_len = INT_MAX;
47 static uint16_t	cfg_port = 8000;
48 
49 /* test sending up to max mtu + 1 */
50 #define TEST_SZ	(sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)
51 
52 static char tbuf[TEST_SZ], rbuf[TEST_SZ];
53 
/*
 * Add up @num_u16 consecutive 16-bit words starting at @start.
 * The result is the raw (unfolded) sum; carries are folded by the caller.
 */
static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
{
	unsigned long acc = 0;

	while (num_u16-- > 0)
		acc += *start++;

	return acc;
}
64 
/*
 * Compute an Internet-style checksum over @num_u16 16-bit words at @start,
 * seeded with the partial sum @sum. Carries are folded back into the low
 * 16 bits and the ones' complement of the result is returned.
 */
static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
			      unsigned long sum)
{
	unsigned long acc = sum + add_csum_hword(start, num_u16);

	/* fold carries out of the low 16 bits until none remain */
	while (acc > 0xffff)
		acc = (acc >> 16) + (acc & 0xffff);

	return (uint16_t)~acc;
}
75 
76 static int build_vnet_header(void *header)
77 {
78 	struct virtio_net_hdr *vh = header;
79 
80 	vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
81 
82 	if (cfg_use_csum_off) {
83 		vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
84 		vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
85 		vh->csum_offset = __builtin_offsetof(struct udphdr, check);
86 
87 		/* position check field exactly one byte beyond end of packet */
88 		if (cfg_use_csum_off_bad)
89 			vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
90 					  vh->csum_offset - 1;
91 	}
92 
93 	if (cfg_use_gso) {
94 		vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
95 		vh->gso_size = cfg_mtu - sizeof(struct iphdr);
96 	}
97 
98 	return sizeof(*vh);
99 }
100 
101 static int build_eth_header(void *header)
102 {
103 	struct ethhdr *eth = header;
104 
105 	if (cfg_use_vlan) {
106 		uint16_t *tag = header + ETH_HLEN;
107 
108 		eth->h_proto = htons(ETH_P_8021Q);
109 		tag[1] = htons(ETH_P_IP);
110 		return ETH_HLEN + 4;
111 	}
112 
113 	eth->h_proto = htons(ETH_P_IP);
114 	return ETH_HLEN;
115 }
116 
/*
 * Write a 20-byte IPv4 header (no options) at @header for a UDP datagram
 * carrying @payload_len bytes from 172.17.0.2 to 172.17.0.1, including a
 * valid header checksum. Returns the header length in bytes.
 */
static int build_ipv4_header(void *header, int payload_len)
{
	struct iphdr *iph = header;

	/* tos, frag_off and check are not set below and must be zero;
	 * in particular the checksum is computed over the check field
	 */
	memset(iph, 0, sizeof(*iph));

	iph->ihl = 5;
	iph->version = 4;
	iph->ttl = 8;
	iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
	iph->id = htons(1337);
	iph->protocol = IPPROTO_UDP;
	iph->saddr = htonl((172 << 24) | (17 << 16) | 2);
	iph->daddr = htonl((172 << 24) | (17 << 16) | 1);
	/* ihl counts 32-bit words; the checksum helper wants 16-bit words */
	iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);

	return iph->ihl << 2;
}
133 
134 static int build_udp_header(void *header, int payload_len)
135 {
136 	const int alen = sizeof(uint32_t);
137 	struct udphdr *udph = header;
138 	int len = sizeof(*udph) + payload_len;
139 
140 	udph->source = htons(9);
141 	udph->dest = htons(cfg_port);
142 	udph->len = htons(len);
143 
144 	if (cfg_use_csum_off)
145 		udph->check = build_ip_csum(header - (2 * alen), alen,
146 					    htons(IPPROTO_UDP) + udph->len);
147 	else
148 		udph->check = 0;
149 
150 	return sizeof(*udph);
151 }
152 
153 static int build_packet(int payload_len)
154 {
155 	int off = 0;
156 
157 	off += build_vnet_header(tbuf);
158 	off += build_eth_header(tbuf + off);
159 	off += build_ipv4_header(tbuf + off, payload_len);
160 	off += build_udp_header(tbuf + off, payload_len);
161 
162 	if (off + payload_len > sizeof(tbuf))
163 		error(1, 0, "payload length exceeds max");
164 
165 	memset(tbuf + off, DATA_CHAR, payload_len);
166 
167 	return off + payload_len;
168 }
169 
170 static void do_bind(int fd)
171 {
172 	struct sockaddr_ll laddr = {0};
173 
174 	laddr.sll_family = AF_PACKET;
175 	laddr.sll_protocol = htons(ETH_P_IP);
176 	laddr.sll_ifindex = if_nametoindex(cfg_ifname);
177 	if (!laddr.sll_ifindex)
178 		error(1, errno, "if_nametoindex");
179 
180 	if (bind(fd, (void *)&laddr, sizeof(laddr)))
181 		error(1, errno, "bind");
182 }
183 
184 static void do_send(int fd, char *buf, int len)
185 {
186 	int ret;
187 
188 	if (!cfg_use_vnet) {
189 		buf += sizeof(struct virtio_net_hdr);
190 		len -= sizeof(struct virtio_net_hdr);
191 	}
192 	if (cfg_use_dgram) {
193 		buf += ETH_HLEN;
194 		len -= ETH_HLEN;
195 	}
196 
197 	if (cfg_use_bind) {
198 		ret = write(fd, buf, len);
199 	} else {
200 		struct sockaddr_ll laddr = {0};
201 
202 		laddr.sll_protocol = htons(ETH_P_IP);
203 		laddr.sll_ifindex = if_nametoindex(cfg_ifname);
204 		if (!laddr.sll_ifindex)
205 			error(1, errno, "if_nametoindex");
206 
207 		ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
208 	}
209 
210 	if (ret == -1)
211 		error(1, errno, "write");
212 	if (ret != len)
213 		error(1, 0, "write: %u %u", ret, len);
214 
215 	fprintf(stderr, "tx: %u\n", ret);
216 }
217 
218 static int do_tx(void)
219 {
220 	const int one = 1;
221 	int fd, len;
222 
223 	fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
224 	if (fd == -1)
225 		error(1, errno, "socket t");
226 
227 	if (cfg_use_bind)
228 		do_bind(fd);
229 
230 	if (cfg_use_qdisc_bypass &&
231 	    setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
232 		error(1, errno, "setsockopt qdisc bypass");
233 
234 	if (cfg_use_vnet &&
235 	    setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
236 		error(1, errno, "setsockopt vnet");
237 
238 	len = build_packet(cfg_payload_len);
239 
240 	if (cfg_truncate_len < len)
241 		len = cfg_truncate_len;
242 
243 	do_send(fd, tbuf, len);
244 
245 	if (close(fd))
246 		error(1, errno, "close t");
247 
248 	return len;
249 }
250 
251 static int setup_rx(void)
252 {
253 	struct timeval tv = { .tv_usec = 100 * 1000 };
254 	struct sockaddr_in raddr = {0};
255 	int fd;
256 
257 	fd = socket(PF_INET, SOCK_DGRAM, 0);
258 	if (fd == -1)
259 		error(1, errno, "socket r");
260 
261 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
262 		error(1, errno, "setsockopt rcv timeout");
263 
264 	raddr.sin_family = AF_INET;
265 	raddr.sin_port = htons(cfg_port);
266 	raddr.sin_addr.s_addr = htonl(INADDR_ANY);
267 
268 	if (bind(fd, (void *)&raddr, sizeof(raddr)))
269 		error(1, errno, "bind r");
270 
271 	return fd;
272 }
273 
274 static void do_rx(int fd, int expected_len, char *expected)
275 {
276 	int ret;
277 
278 	ret = recv(fd, rbuf, sizeof(rbuf), 0);
279 	if (ret == -1)
280 		error(1, errno, "recv");
281 	if (ret != expected_len)
282 		error(1, 0, "recv: %u != %u", ret, expected_len);
283 
284 	if (memcmp(rbuf, expected, ret))
285 		error(1, 0, "recv: data mismatch");
286 
287 	fprintf(stderr, "rx: %u\n", ret);
288 }
289 
/*
 * Open a raw packet socket that observes the test traffic. It has a 100 ms
 * receive timeout, gets the BPF filter installed by pair_udp_setfilter(), and
 * is bound to cfg_ifname via do_bind(). Returns the socket; exits on error.
 */
static int setup_sniffer(void)
{
	const struct timeval timeout = { .tv_usec = 100 * 1000 };
	int fd = socket(PF_PACKET, SOCK_RAW, 0);

	if (fd == -1)
		error(1, errno, "socket p");

	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)))
		error(1, errno, "setsockopt rcv timeout");

	/* filter first, then bind; presumably so no unfiltered frames are
	 * queued once the socket starts receiving (confirm)
	 */
	pair_udp_setfilter(fd);
	do_bind(fd);

	return fd;
}
307 
308 static void parse_opts(int argc, char **argv)
309 {
310 	int c;
311 
312 	while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
313 		switch (c) {
314 		case 'b':
315 			cfg_use_bind = true;
316 			break;
317 		case 'c':
318 			cfg_use_csum_off = true;
319 			break;
320 		case 'C':
321 			cfg_use_csum_off_bad = true;
322 			break;
323 		case 'd':
324 			cfg_use_dgram = true;
325 			break;
326 		case 'g':
327 			cfg_use_gso = true;
328 			break;
329 		case 'l':
330 			cfg_payload_len = strtoul(optarg, NULL, 0);
331 			break;
332 		case 'q':
333 			cfg_use_qdisc_bypass = true;
334 			break;
335 		case 't':
336 			cfg_truncate_len = strtoul(optarg, NULL, 0);
337 			break;
338 		case 'v':
339 			cfg_use_vnet = true;
340 			break;
341 		case 'V':
342 			cfg_use_vlan = true;
343 			break;
344 		default:
345 			error(1, 0, "%s: parse error", argv[0]);
346 		}
347 	}
348 
349 	if (cfg_use_vlan && cfg_use_dgram)
350 		error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
351 
352 	if (cfg_use_csum_off && !cfg_use_vnet)
353 		error(1, 0, "option csum offload (-c) requires vnet (-v)");
354 
355 	if (cfg_use_csum_off_bad && !cfg_use_csum_off)
356 		error(1, 0, "option csum bad (-C) requires csum offload (-c)");
357 
358 	if (cfg_use_gso && !cfg_use_csum_off)
359 		error(1, 0, "option gso (-g) requires csum offload (-c)");
360 }
361 
362 static void run_test(void)
363 {
364 	int fdr, fds, total_len;
365 
366 	fdr = setup_rx();
367 	fds = setup_sniffer();
368 
369 	total_len = do_tx();
370 
371 	/* BPF filter accepts only this length, vlan changes MAC */
372 	if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
373 		do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
374 		      tbuf + sizeof(struct virtio_net_hdr));
375 
376 	do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
377 
378 	if (close(fds))
379 		error(1, errno, "close s");
380 	if (close(fdr))
381 		error(1, errno, "close r");
382 }
383 
/*
 * Run shell command @cmd via system() and exit with @what as the message
 * unless it succeeds. errno is reported only when system() itself fails
 * (returns -1). A non-zero status from the command is not described by
 * errno, so it is reported together with the wait status.
 */
static void run_cmd(const char *cmd, const char *what)
{
	int status = system(cmd);

	if (status == -1)
		error(1, errno, "%s", what);
	if (status != 0)
		error(1, 0, "%s: command failed (wait status %d)", what, status);
}

/*
 * Configure loopback for the test (MTU 1500, address 172.17.0.1/24,
 * accept_local=1), then run a single send/receive test.
 */
int main(int argc, char **argv)
{
	parse_opts(argc, argv);

	run_cmd("ip link set dev lo mtu 1500", "ip link set mtu");
	run_cmd("ip addr add dev lo 172.17.0.1/24", "ip addr add");
	run_cmd("sysctl -w net.ipv4.conf.lo.accept_local=1",
		"sysctl lo.accept_local");

	run_test();

	fprintf(stderr, "OK\n\n");
	return 0;
}
400