// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * ip_vs_proto_udp.c:	UDP load balancing support for IPVS
 *
 * Authors:     Wensong Zhang <wensong@linuxvirtualserver.org>
 *              Julian Anastasov <ja@ssi.bg>
 *
 * Changes:     Hans Schillstrom <hans.schillstrom@ericsson.com>
 *              Network name space (netns) aware.
 */

#define pr_fmt(fmt) "IPVS: " fmt

#include <linux/in.h>
#include <linux/ip.h>
#include <linux/kernel.h>
#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/udp.h>
#include <linux/indirect_call_wrapper.h>

#include <net/ip_vs.h>
#include <net/ip.h>
#include <net/ip6_checksum.h>

static int
udp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp);

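/*
 * udp_conn_schedule(): decide what to do with a UDP packet that has no
 * IPVS connection yet.  Returns 1 to let the caller continue (the
 * NF_ACCEPT path) and 0 once *verdict has been set.  For ordinary
 * packets the UDP header is read with skb_header_pointer(), which
 * copies it into the local buffer when the skb data is not linear; for
 * headers embedded in ICMP errors only the two port fields are needed.
 */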
static int
udp_conn_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb,
		  struct ip_vs_proto_data *pd,
		  int *verdict, struct ip_vs_conn **cpp,
		  struct ip_vs_iphdr *iph)
{
	struct ip_vs_service *svc;
	struct udphdr _udph, *uh;
	__be16 _ports[2], *ports = NULL;

	if (likely(!ip_vs_iph_icmp(iph))) {
		/* For IPv6 fragments, only the first fragment reaches this */
		uh = skb_header_pointer(skb, iph->len, sizeof(_udph), &_udph);
		if (uh)
			ports = &uh->source;
	} else {
		ports = skb_header_pointer(
			skb, iph->len, sizeof(_ports), &_ports);
	}

	if (!ports) {
		*verdict = NF_DROP;
		return 0;
	}

	if (likely(!ip_vs_iph_inverse(iph)))
		svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol,
					 &iph->daddr, ports[1]);
	else
		svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol,
					 &iph->saddr, ports[0]);

	if (svc) {
		int ignored;

		if (ip_vs_todrop(ipvs)) {
			/*
			 * We are under heavy load,
			 * so this packet has to be dropped :(
			 */
			*verdict = NF_DROP;
			return 0;
		}

		/*
		 * Let the virtual server select a real server for the
		 * incoming connection, and create a connection entry.
		 */
		*cpp = ip_vs_schedule(svc, skb, pd, &ignored, iph);
		if (!*cpp && ignored <= 0) {
			if (!ignored)
				*verdict = ip_vs_leave(svc, skb, pd, iph);
			else
				*verdict = NF_DROP;
			return 0;
		}
	}
	/* NF_ACCEPT */
	return 1;
}


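/*
 * udp_fast_csum_update(): incrementally patch the UDP checksum when
 * only one address and one port have been rewritten, instead of
 * summing the payload again (the usual RFC 1624 style update).  The
 * stored value is unfolded, the port and address deltas are applied,
 * and the sum is folded back.  A result of 0 is sent as CSUM_MANGLED_0
 * (0xffff), since a zero checksum field means "no checksum" in UDP.
 */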
static inline void
udp_fast_csum_update(int af, struct udphdr *uhdr,
		     const union nf_inet_addr *oldip,
		     const union nf_inet_addr *newip,
		     __be16 oldport, __be16 newport)
{
#ifdef CONFIG_IP_VS_IPV6
	if (af == AF_INET6)
		uhdr->check =
			csum_fold(ip_vs_check_diff16(oldip->ip6, newip->ip6,
					 ip_vs_check_diff2(oldport, newport,
						~csum_unfold(uhdr->check))));
	else
#endif
		uhdr->check =
			csum_fold(ip_vs_check_diff4(oldip->ip, newip->ip,
					 ip_vs_check_diff2(oldport, newport,
						~csum_unfold(uhdr->check))));
	if (!uhdr->check)
		uhdr->check = CSUM_MANGLED_0;
}

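/*
 * udp_partial_csum_update(): for CHECKSUM_PARTIAL skbs the device will
 * finish the checksum and udph->check only seeds it with the
 * pseudo-header sum, so just the pieces of the pseudo-header that IPVS
 * may change - one address and the UDP length - are adjusted here.
 */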
static inline void
udp_partial_csum_update(int af, struct udphdr *uhdr,
		     const union nf_inet_addr *oldip,
		     const union nf_inet_addr *newip,
		     __be16 oldlen, __be16 newlen)
{
#ifdef CONFIG_IP_VS_IPV6
	if (af == AF_INET6)
		uhdr->check =
			~csum_fold(ip_vs_check_diff16(oldip->ip6, newip->ip6,
					 ip_vs_check_diff2(oldlen, newlen,
						csum_unfold(uhdr->check))));
	else
#endif
		uhdr->check =
			~csum_fold(ip_vs_check_diff4(oldip->ip, newip->ip,
					 ip_vs_check_diff2(oldlen, newlen,
						csum_unfold(uhdr->check))));
}


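/*
 * udp_snat_handler(): rewrite the source port of a reply from the real
 * server to the virtual service port; the address itself is rewritten
 * by the generic NAT code in the caller.  Three checksum strategies
 * follow: adjust the CHECKSUM_PARTIAL seed, do a fast incremental
 * update when only address and port changed, or recompute the whole
 * checksum when an application helper mangled the payload or the
 * packet carried no checksum at all.
 */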
INDIRECT_CALLABLE_SCOPE int
udp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp,
		 struct ip_vs_conn *cp, struct ip_vs_iphdr *iph)
{
	struct udphdr *udph;
	unsigned int udphoff = iph->len;
	bool payload_csum = false;
	int oldlen;

#ifdef CONFIG_IP_VS_IPV6
	if (cp->af == AF_INET6 && iph->fragoffs)
		return 1;
#endif
	oldlen = skb->len - udphoff;

	/* csum_check requires unshared skb */
	if (skb_ensure_writable(skb, udphoff + sizeof(*udph)))
		return 0;

	if (unlikely(cp->app != NULL)) {
		int ret;

		/* Some checks before mangling */
		if (!udp_csum_check(cp->af, skb, pp))
			return 0;

		/*
		 *	Call application helper if needed
		 */
		ret = ip_vs_app_pkt_out(cp, skb, iph);
		if (!ret)
			return 0;
		/* ret=2: csum update is needed after payload mangling */
		if (ret == 1)
			oldlen = skb->len - udphoff;
		else
			payload_csum = true;
	}

	udph = (void *)skb_network_header(skb) + udphoff;
	udph->source = cp->vport;

	/*
	 *	Adjust UDP checksums
	 */
	if (skb->ip_summed == CHECKSUM_PARTIAL) {
		udp_partial_csum_update(cp->af, udph, &cp->daddr, &cp->vaddr,
					htons(oldlen),
					htons(skb->len - udphoff));
	} else if (!payload_csum && (udph->check != 0)) {
		/* Only port and addr are changed, do fast csum update */
		udp_fast_csum_update(cp->af, udph, &cp->daddr, &cp->vaddr,
				     cp->dport, cp->vport);
		if (skb->ip_summed == CHECKSUM_COMPLETE)
			skb->ip_summed = cp->app ?
					 CHECKSUM_UNNECESSARY : CHECKSUM_NONE;
	} else {
		/* full checksum calculation */
		udph->check = 0;
		skb->csum = skb_checksum(skb, udphoff, skb->len - udphoff, 0);
#ifdef CONFIG_IP_VS_IPV6
		if (cp->af == AF_INET6)
			udph->check = csum_ipv6_magic(&cp->vaddr.in6,
						      &cp->caddr.in6,
						      skb->len - udphoff,
						      cp->protocol, skb->csum);
		else
#endif
			udph->check = csum_tcpudp_magic(cp->vaddr.ip,
							cp->caddr.ip,
							skb->len - udphoff,
							cp->protocol,
							skb->csum);
		if (udph->check == 0)
			udph->check = CSUM_MANGLED_0;
		skb->ip_summed = CHECKSUM_UNNECESSARY;
		IP_VS_DBG(11, "O-pkt: %s O-csum=%d (+%zd)\n",
			  pp->name, udph->check,
			  (char *)&(udph->check) - (char *)udph);
	}
	return 1;
}


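/*
 * udp_dnat_handler(): the client-to-server counterpart of
 * udp_snat_handler() - the destination port is rewritten to the real
 * server's port and the same three checksum cases apply, with the
 * client/virtual/destination addresses swapped accordingly.
 */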
static int
udp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp,
		 struct ip_vs_conn *cp, struct ip_vs_iphdr *iph)
{
	struct udphdr *udph;
	unsigned int udphoff = iph->len;
	bool payload_csum = false;
	int oldlen;

#ifdef CONFIG_IP_VS_IPV6
	if (cp->af == AF_INET6 && iph->fragoffs)
		return 1;
#endif
	oldlen = skb->len - udphoff;

	/* csum_check requires unshared skb */
	if (skb_ensure_writable(skb, udphoff + sizeof(*udph)))
		return 0;

	if (unlikely(cp->app != NULL)) {
		int ret;

		/* Some checks before mangling */
		if (!udp_csum_check(cp->af, skb, pp))
			return 0;

		/*
		 *	Call the application helper; it may adjust the
		 *	ip_vs_conn as needed.
		 */
		ret = ip_vs_app_pkt_in(cp, skb, iph);
		if (!ret)
			return 0;
		/* ret=2: csum update is needed after payload mangling */
		if (ret == 1)
			oldlen = skb->len - udphoff;
		else
			payload_csum = true;
	}

	udph = (void *)skb_network_header(skb) + udphoff;
	udph->dest = cp->dport;

	/*
	 *	Adjust UDP checksums
	 */
	if (skb->ip_summed == CHECKSUM_PARTIAL) {
		udp_partial_csum_update(cp->af, udph, &cp->vaddr, &cp->daddr,
					htons(oldlen),
					htons(skb->len - udphoff));
	} else if (!payload_csum && (udph->check != 0)) {
		/* Only port and addr are changed, do fast csum update */
		udp_fast_csum_update(cp->af, udph, &cp->vaddr, &cp->daddr,
				     cp->vport, cp->dport);
		if (skb->ip_summed == CHECKSUM_COMPLETE)
			skb->ip_summed = cp->app ?
					 CHECKSUM_UNNECESSARY : CHECKSUM_NONE;
	} else {
		/* full checksum calculation */
		udph->check = 0;
		skb->csum = skb_checksum(skb, udphoff, skb->len - udphoff, 0);
#ifdef CONFIG_IP_VS_IPV6
		if (cp->af == AF_INET6)
			udph->check = csum_ipv6_magic(&cp->caddr.in6,
						      &cp->daddr.in6,
						      skb->len - udphoff,
						      cp->protocol, skb->csum);
		else
#endif
			udph->check = csum_tcpudp_magic(cp->caddr.ip,
							cp->daddr.ip,
							skb->len - udphoff,
							cp->protocol,
							skb->csum);
		if (udph->check == 0)
			udph->check = CSUM_MANGLED_0;
		skb->ip_summed = CHECKSUM_UNNECESSARY;
	}
	return 1;
}


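/*
 * udp_csum_check(): verify the UDP checksum before an application
 * helper is allowed to mangle the payload.  A zero checksum field is
 * accepted as-is, since UDP over IPv4 may legitimately omit the
 * checksum (RFC 768).  For CHECKSUM_NONE the sum over the datagram is
 * computed first and then validated against the pseudo-header, just
 * like the CHECKSUM_COMPLETE case.
 */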
static int
udp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp)
{
	struct udphdr _udph, *uh;
	unsigned int udphoff;

#ifdef CONFIG_IP_VS_IPV6
	if (af == AF_INET6)
		udphoff = sizeof(struct ipv6hdr);
	else
#endif
		udphoff = ip_hdrlen(skb);

	uh = skb_header_pointer(skb, udphoff, sizeof(_udph), &_udph);
	if (uh == NULL)
		return 0;

	if (uh->check != 0) {
		switch (skb->ip_summed) {
		case CHECKSUM_NONE:
			skb->csum = skb_checksum(skb, udphoff,
						 skb->len - udphoff, 0);
			fallthrough;
		case CHECKSUM_COMPLETE:
#ifdef CONFIG_IP_VS_IPV6
			if (af == AF_INET6) {
				if (csum_ipv6_magic(&ipv6_hdr(skb)->saddr,
						    &ipv6_hdr(skb)->daddr,
						    skb->len - udphoff,
						    ipv6_hdr(skb)->nexthdr,
						    skb->csum)) {
					IP_VS_DBG_RL_PKT(0, af, pp, skb, 0,
							 "Failed checksum for");
					return 0;
				}
			} else
#endif
				if (csum_tcpudp_magic(ip_hdr(skb)->saddr,
						      ip_hdr(skb)->daddr,
						      skb->len - udphoff,
						      ip_hdr(skb)->protocol,
						      skb->csum)) {
					IP_VS_DBG_RL_PKT(0, af, pp, skb, 0,
							 "Failed checksum for");
					return 0;
				}
			break;
		default:
			/* No need to checksum. */
			break;
		}
	}
	return 1;
}

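/*
 * udp_app_hashkey(): hash a service port into the application-helper
 * table.  The 16-bit value is used as stored (network byte order; any
 * deterministic mix is fine for hashing): the upper bits are XORed
 * onto the lower ones and the result is masked to UDP_APP_TAB_BITS
 * bits.
 */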
static inline __u16 udp_app_hashkey(__be16 port)
{
	return (((__force u16)port >> UDP_APP_TAB_BITS) ^ (__force u16)port)
		& UDP_APP_TAB_MASK;
}


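/*
 * udp_register_app() / udp_unregister_app(): attach or detach an
 * application helper for one UDP port in this netns.  Registration
 * fails with -EEXIST if the port already has a helper; pd->appcnt
 * tracks how many helpers are active so connection setup can skip the
 * helper lookup entirely while it is zero.
 */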
static int udp_register_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc)
{
	struct ip_vs_app *i;
	__u16 hash;
	__be16 port = inc->port;
	int ret = 0;
	struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_UDP);

	hash = udp_app_hashkey(port);

	list_for_each_entry(i, &ipvs->udp_apps[hash], p_list) {
		if (i->port == port) {
			ret = -EEXIST;
			goto out;
		}
	}
	list_add_rcu(&inc->p_list, &ipvs->udp_apps[hash]);
	atomic_inc(&pd->appcnt);

out:
	return ret;
}


static void
udp_unregister_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc)
{
	struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_UDP);

	atomic_dec(&pd->appcnt);
	list_del_rcu(&inc->p_list);
}


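/*
 * udp_app_conn_bind(): when a new connection is created, look up a
 * registered helper for the virtual port and bind it.  Helpers only
 * matter when IPVS rewrites the packets, which is why anything other
 * than NAT (masquerading) forwarding returns early.
 */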
static int udp_app_conn_bind(struct ip_vs_conn *cp)
{
	struct netns_ipvs *ipvs = cp->ipvs;
	int hash;
	struct ip_vs_app *inc;
	int result = 0;

	/* Default binding: bind app only for NAT */
	if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ)
		return 0;

	/* Lookup application incarnations and bind the right one */
	hash = udp_app_hashkey(cp->vport);

	list_for_each_entry_rcu(inc, &ipvs->udp_apps[hash], p_list) {
		if (inc->port == cp->vport) {
			if (unlikely(!ip_vs_app_inc_get(inc)))
				break;

			IP_VS_DBG_BUF(9, "%s(): Binding conn %s:%u->"
				      "%s:%u to app %s on port %u\n",
				      __func__,
				      IP_VS_DBG_ADDR(cp->af, &cp->caddr),
				      ntohs(cp->cport),
				      IP_VS_DBG_ADDR(cp->af, &cp->vaddr),
				      ntohs(cp->vport),
				      inc->name, ntohs(inc->port));

			cp->app = inc;
			if (inc->init_conn)
				result = inc->init_conn(inc, cp);
			break;
		}
	}

	return result;
}


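/*
 * UDP is connectionless, so IPVS keeps a single NORMAL state per
 * connection entry and simply refreshes its timer on every packet.
 * The default below is five minutes; the per-netns copy can be changed
 * at runtime (e.g. with ipvsadm --set).
 */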
static const int udp_timeouts[IP_VS_UDP_S_LAST+1] = {
	[IP_VS_UDP_S_NORMAL]		=	5*60*HZ,
	[IP_VS_UDP_S_LAST]		=	2*HZ,
};

static const char *const udp_state_name_table[IP_VS_UDP_S_LAST+1] = {
	[IP_VS_UDP_S_NORMAL]		=	"UDP",
	[IP_VS_UDP_S_LAST]		=	"BUG!",
};

static const char *udp_state_name(int state)
{
	if (state >= IP_VS_UDP_S_LAST)
		return "ERR!";
	return udp_state_name_table[state] ? udp_state_name_table[state] : "?";
}

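/*
 * udp_state_transition(): there is no real state machine for UDP;
 * every packet just rearms the NORMAL timeout.  A reply seen in the
 * OUTPUT direction also calls ip_vs_control_assure_ct(), which -
 * roughly - marks the controlling persistence template as assured so
 * that entries with real traffic are preferred when IPVS drops
 * connections under memory pressure.
 */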
static void
udp_state_transition(struct ip_vs_conn *cp, int direction,
		     const struct sk_buff *skb,
		     struct ip_vs_proto_data *pd)
{
	if (unlikely(!pd)) {
		pr_err("UDP no ns data\n");
		return;
	}

	cp->timeout = pd->timeout_table[IP_VS_UDP_S_NORMAL];
	if (direction == IP_VS_DIR_OUTPUT)
		ip_vs_control_assure_ct(cp);
}

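/*
 * Per-netns setup and teardown: __udp_init() initialises the
 * application-helper hash table and duplicates the default timeout
 * table so each namespace can tune it independently; __udp_exit()
 * frees that copy.
 */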
static int __udp_init(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd)
{
	ip_vs_init_hash_table(ipvs->udp_apps, UDP_APP_TAB_SIZE);
	pd->timeout_table = ip_vs_create_timeout_table((int *)udp_timeouts,
							sizeof(udp_timeouts));
	if (!pd->timeout_table)
		return -ENOMEM;
	return 0;
}

static void __udp_exit(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd)
{
	kfree(pd->timeout_table);
}


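/*
 * The protocol descriptor below wires the handlers above into the IPVS
 * core; it is registered by ip_vs_protocol_init() when the kernel is
 * built with CONFIG_IP_VS_PROTO_UDP.
 */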
struct ip_vs_protocol ip_vs_protocol_udp = {
	.name =			"UDP",
	.protocol =		IPPROTO_UDP,
	.num_states =		IP_VS_UDP_S_LAST,
	.dont_defrag =		0,
	.init =			NULL,
	.exit =			NULL,
	.init_netns =		__udp_init,
	.exit_netns =		__udp_exit,
	.conn_schedule =	udp_conn_schedule,
	.conn_in_get =		ip_vs_conn_in_get_proto,
	.conn_out_get =		ip_vs_conn_out_get_proto,
	.snat_handler =		udp_snat_handler,
	.dnat_handler =		udp_dnat_handler,
	.state_transition =	udp_state_transition,
	.state_name =		udp_state_name,
	.register_app =		udp_register_app,
	.unregister_app =	udp_unregister_app,
	.app_conn_bind =	udp_app_conn_bind,
	.debug_packet =		ip_vs_tcpudp_debug_packet,
	.timeout_change =	NULL,
};